How to create a database connection engine in each Dask subprocess to parallelize thousands of SQL queries, without recreating the engine for every query

I need to embarrassingly parallelize the fetching of thousands of SQL queries from a database. Here is a simplified example.

Env info: python=3.7, postgresql=10, dask=latest

Generate the example db table:
from sqlalchemy import create_engine
import pandas as pd
import numpy as np

engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')
data = pd.DataFrame(np.random.randint(0,100 , size=(30000,5)),columns=['a','b','c','d','e'])
data.to_sql('tablename',engine,index=True,if_exists='append')

First, here is the basic example without Dask parallelism.

from sqlalchemy import create_engine
import pandas as pd
import numpy as np

engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')

def job(indexstr):
    'send the query, fetch the data, do some calculation and return'
    sql = 'select * from public.tablename where index=' + indexstr
    df = pd.read_sql_query(sql, engine, index_col='index')
    ## get the data and do some analysis.
    return np.sum(df.values)

lists = []
for v in range(1000):
    lists.append(job(str(v)))
### wall time: 17s

It's not as fast as we would like, since both the database query and the data analysis take time while most CPU cores sit idle.

Then I try to use Dask to parallelize it like this.

def jobWithEngine(indexstr):
    'the engine cannot be serialized between processes, so each call creates its own'
    engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')
    sql = 'select * from public.tablename where index=' + indexstr
    df = pd.read_sql_query(sql, engine, index_col='index')
    return np.sum(df.values)

import dask
import dask.bag as db

dask.config.set(scheduler='processes')
dbdata = db.from_sequence([str(v) for v in range(1000)])
dbdata = dbdata.map(lambda x: jobWithEngine(x))
results_bag = dbdata.compute()
### wall time: 1min 8s

The problem is that engine creation takes extra time, and it happens thousands of times.

The engine is recreated for every single SQL query, which is really costly and might even crash the database service!

So I guess there must be a more elegant way, something like this:

import dask
import dask.bag as db

dask.config.set(scheduler='processes')
dbdata = db.from_sequence([str(v) for v in range(1000)])
# hypothetical API: pass an init function so each process builds its engine only once
dbdata = dbdata.map(lambda x: job(x, init=create_engine))
results_bag = dbdata.compute()

1. The main process creates 8 subprocesses.

2. Each subprocess creates its own engine as part of job initialization.

3. The main process then sends them the 1000 jobs and collects the 1000 results.

4. After everything is done, each subprocess's engine is disposed and the subprocess is killed.

Or has Dask already done this, and the additional time comes from communication between processes?
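For reference, here is a minimal sketch of the kind of per-process caching I have in mind, using a lazily created module-level engine (get_engine() is a hypothetical helper, and I am not sure whether the processes scheduler actually reuses it across tasks, which is part of my question):

import dask
import dask.bag as db
import pandas as pd
import numpy as np
from sqlalchemy import create_engine

_engine = None  # one engine per worker process, created on first use

def get_engine():
    'hypothetical helper: build the engine the first time it is needed in this process'
    global _engine
    if _engine is None:
        _engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')
    return _engine

def jobWithCachedEngine(indexstr):
    sql = 'select * from public.tablename where index=' + indexstr
    df = pd.read_sql_query(sql, get_engine(), index_col='index')
    return np.sum(df.values)

dask.config.set(scheduler='processes')
results = db.from_sequence([str(v) for v in range(1000)]).map(jobWithCachedEngine).compute()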

asked by WilsonF

2 Answers

You can do this by storing a connected database as an attribute on each worker, using get_worker:

from dask.distributed import get_worker

def connect_worker_db(db):
    worker = get_worker()
    worker.db = db          # DB settings, password, username etc
    worker.db.connect()     # Function that connects the database, e.g. create_engine()

Then have the client run connect_worker_db on every worker:

from dask.distributed import Client, get_worker
client = Client()
client.run(connect_worker_db, db)

In the function that uses the connection, such as jobWithEngine(), get the worker and use the attribute you saved the connection to:

def jobWithEngine():
    db = get_worker().db

Then make sure to disconnect at the end:

def disconnect_worker_db():
    worker = get_worker()
    worker.db.disconnect()

client.run(disconnect_worker_db)
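Putting this together for the SQLAlchemy engine from the question, a sketch (the connection string, table, and jobWithEngine body are taken from the question; client.map is used here instead of a bag):

from dask.distributed import Client, get_worker
from sqlalchemy import create_engine
import pandas as pd
import numpy as np

def connect_worker_db():
    # create one engine per worker and keep it on the worker object
    get_worker().engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')

def disconnect_worker_db():
    # dispose of the engine's connection pool once all jobs are done
    get_worker().engine.dispose()

def jobWithEngine(indexstr):
    engine = get_worker().engine  # reuse the per-worker engine
    sql = 'select * from public.tablename where index=' + indexstr
    df = pd.read_sql_query(sql, engine, index_col='index')
    return np.sum(df.values)

client = Client()  # local cluster; each worker process runs connect_worker_db once
client.run(connect_worker_db)
results = client.gather(client.map(jobWithEngine, [str(v) for v in range(1000)]))
client.run(disconnect_worker_db)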
answered by AmyChodorowski


Amy's answer has the benefit of being simple, but if Dask starts new workers for any reason, they will not have .db set.

I don't know when it was first introduced, but Dask 1.12.2 has Client.register_worker_callbacks, which takes a function intended for exactly this kind of setup. If the callback accepts a parameter named dask_worker, the worker itself is passed in.

import dask.distributed

def main():

    dask_client = dask.distributed.Client(cluster)  # cluster as created in the later snippet

    db = dict(
        host="db-host",
        username="user",
        # etc etc
    )

    def worker_setup(dask_worker: dask.distributed.Worker):
        dask_worker.db = db

    dask_client.register_worker_callbacks(worker_setup)

https://distributed.dask.org/en/latest/api.html#distributed.Client.register_worker_callbacks
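For the engine in the question, the callback could create it directly instead of only storing settings; a sketch (connection string taken from the question):

import dask.distributed
from sqlalchemy import create_engine

def worker_setup(dask_worker: dask.distributed.Worker):
    # runs once per worker, including any worker started later
    dask_worker.engine = create_engine('postgresql://dbadmin:dbadmin@server:5432/db01')

dask_client = dask.distributed.Client()
dask_client.register_worker_callbacks(worker_setup)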

However, this doesn't close the db connections at the end. You will probably be covered by client.run(disconnect_worker_db), but I have seen some workers not release their resources. Fixing this more comprehensively needs a bit more code, as per https://distributed.dask.org/en/latest/api.html#distributed.Client.register_worker_plugin

import os
import dask.distributed


class MyWorkerPlugin(dask.distributed.WorkerPlugin):
    def __init__(self, *args, **kwargs):
        self.db = kwargs.get("db")
        assert self.db, "no db"

    def setup(self, worker: dask.distributed.Worker):
        worker.db = self.db

    def teardown(self, worker: dask.distributed.Worker):
        print(f"worker {worker.name} teardown")
        # e.g. db.disconnect()


def main():

    cluster = dask.distributed.LocalCluster(
        n_workers=os.cpu_count(),
        threads_per_worker=2,
    )
    dask_client = dask.distributed.Client(cluster)
    db = dict(
        host="db-host",
        username="user",
        # etc etc
    )

    dask_client.register_worker_plugin(MyWorkerPlugin, "set-dbs", db=db)

You can give the plugin a somewhat helpful name, and pass in kwargs to be used in the plugin's __init__.
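As a variation on the same idea, here is a sketch of a plugin that owns the SQLAlchemy engine itself, so setup/teardown create and dispose it per worker (EngineWorkerPlugin is a hypothetical name; the connection string comes from the question):

import dask.distributed
from sqlalchemy import create_engine

class EngineWorkerPlugin(dask.distributed.WorkerPlugin):
    'hypothetical plugin: one SQLAlchemy engine per worker, disposed on teardown'

    def __init__(self, url):
        self.url = url

    def setup(self, worker: dask.distributed.Worker):
        worker.engine = create_engine(self.url)

    def teardown(self, worker: dask.distributed.Worker):
        worker.engine.dispose()

dask_client = dask.distributed.Client()
dask_client.register_worker_plugin(
    EngineWorkerPlugin('postgresql://dbadmin:dbadmin@server:5432/db01'),
    name="db-engine",
)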

answered by Enda Farrell


