multiprocessing / psycopg2 TypeError: can't pickle _thread.RLock objects

I followed the code from the article below in order to implement parallel select queries on a Postgres database:

https://tech.geoblink.com/2017/07/06/parallelizing-queries-in-postgresql-with-python/

My basic problem is that I have ~6k select queries that need to be executed, and I am trying to optimise their execution. Initially it was a single query whose where id in (...) clause contained all 6k predicate IDs, but I ran into issues with that query using > 4GB of RAM on the machine it ran on, so I split it into 6k individual queries, which, when run synchronously, keep memory usage steady. However, it takes a lot longer to run, which is less of an issue for my use case, but I am still trying to reduce the time as much as possible.

This is what my code looks like:

import logging
import multiprocessing
from itertools import chain

import psycopg2
from sqlalchemy import create_engine

LOGGER = logging.getLogger(__name__)


class PostgresConnector(object):
    def __init__(self, db_url):
        self.db_url = db_url
        self.engine = self.init_connection()
        self.pool = self.init_pool()

    def init_pool(self):
        CPUS = multiprocessing.cpu_count()
        return multiprocessing.Pool(CPUS)

    def init_connection(self):
        LOGGER.info('Creating Postgres engine')
        return create_engine(self.db_url)

    def run_parallel_queries(self, queries):
        results = []
        try:
            for i in self.pool.imap_unordered(self.execute_parallel_query, queries):
                results.append(i)
        except Exception as exception:
            LOGGER.error('Error whilst executing %s queries in parallel: %s', len(queries), exception)
            raise
        finally:
            self.pool.close()
            self.pool.join()

        LOGGER.info('Parallel query ran producing %s sets of results of type: %s', len(results), type(results))

        return list(chain.from_iterable(results))

    def execute_parallel_query(self, query):
        con = psycopg2.connect(self.db_url)
        cur = con.cursor()
        cur.execute(query)
        records = cur.fetchall()
        con.close()

        return list(records)

However, whenever this runs, I get the following error:

TypeError: can't pickle _thread.RLock objects

I've read lots of similar questions regarding the use of multiprocessing and pickleable objects, but I can't for the life of me figure out what I am doing wrong.

The pool is one per process (which I believe is best practice), but it is shared per instance of the connector class, so that a new pool is not created for each use of the run_parallel_queries method.

The top answer to a similar question:

Accessing a MySQL connection pool from Python multiprocessing

shows an almost identical implementation to my own, except it uses MySQL instead of Postgres.

Am I doing something wrong?

Thanks!

EDIT:

I've found this answer:

Python Postgres psycopg2 ThreadedConnectionPool exhausted

which is incredibly detailed and suggests I have misunderstood what multiprocessing.Pool gives me versus a connection pool such as ThreadedConnectionPool. However, the first link doesn't mention needing any connection pools. That solution seems good, but it is a lot of code for what I think is a fairly simple problem?

EDIT 2:

So the above link solves another problem, which I would likely have run into anyway, so I'm glad I found it, but it doesn't solve the initial issue of not being able to use imap_unordered because of the pickling error. Very frustrating.

Lastly, I think it's probably worth noting that this runs on Heroku, on a worker dyno, using Redis RQ for scheduling and background tasks, with a hosted instance of Postgres as the database.

Asked Sep 27 '18 by JustinMoser

1 Answer

To put it simply: Postgres connections and SQLAlchemy connection pools are thread-safe, but they are not fork-safe.

If you want to use multiprocessing, you should initialize the engine in each child process after the fork.

If you want to share engines, you should use multithreading instead.

Refer to Thread and process safety in the psycopg2 documentation:

libpq connections shouldn’t be used by a forked processes, so when using a module such as multiprocessing or a forking web deploy method such as FastCGI make sure to create the connections after the fork.
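For comparison, a minimal sketch of the multithreading route might look like the following (DB_URL, run_query, the worker count and the flattening step are illustrative assumptions, not code from the question). Because threads share memory, a single engine and its thread-safe connection pool can be reused by all of them:

from concurrent.futures import ThreadPoolExecutor
from sqlalchemy import create_engine, text

DB_URL = 'postgresql://user:password@localhost:5432/mydb'  # placeholder URL
engine = create_engine(DB_URL)  # one engine, shared by every thread

def run_query(query):
    # each call checks a connection out of the engine's shared pool
    with engine.connect() as conn:
        return conn.execute(text(query)).fetchall()

def run_parallel_queries(queries):
    with ThreadPoolExecutor(max_workers=8) as executor:
        results = list(executor.map(run_query, queries))
    # flatten the per-query row lists into a single list of rows
    return [row for rows in results for row in rows]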

If you are using multiprocessing.Pool, there is a keyword argument initializer which can be used to run code once on each child process. Try this:

class PostgresConnector(object):
    def __init__(self, db_url):
        self.db_url = db_url
        self.pool = self.init_pool()

    def init_pool(self):
        CPUS = multiprocessing.cpu_count()
        # init_connection(db_url) returns a function that each worker process
        # runs once after the fork, creating a fresh engine in that child
        return multiprocessing.Pool(CPUS, initializer=self.init_connection(self.db_url))

    @classmethod
    def init_connection(cls, db_url):
        def _init_connection():
            LOGGER.info('Creating Postgres engine')
            cls.engine = create_engine(db_url)
        return _init_connection

    def run_parallel_queries(self, queries):
        results = []
        try:
            for i in self.pool.imap_unordered(self.execute_parallel_query, queries):
                results.append(i)
        except Exception as exception:
            LOGGER.error('Error whilst executing %s queries in parallel: %s', len(queries), exception)
            raise
        finally:
            # deliberately do not close/join the pool here, so that it can be
            # reused by later calls; close it explicitly when you are finished
            pass

        LOGGER.info('Parallel query ran producing %s sets of results of type: %s', len(results), type(results))

        return list(chain.from_iterable(results))

    def execute_parallel_query(self, query):
        with self.engine.connect() as conn:
            with conn.begin():
                result = conn.execute(query)
                return result.fetchall()

    def __getstate__(self):
        # this is a hack: the multiprocessing.Pool is what can't be pickled,
        # so exclude it when the instance is sent to the worker processes.
        # To remove this method, drop self.pool and pass the pool explicitly.
        self_dict = self.__dict__.copy()
        del self_dict['pool']
        return self_dict
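For completeness, a hypothetical usage of the class above might look like this (DB_URL, my_table and the generated queries are placeholders, and this assumes the default fork start method on Linux):

# Hypothetical usage; DB_URL, my_table and the ID range are placeholders.
if __name__ == '__main__':
    DB_URL = 'postgresql://user:password@localhost:5432/mydb'
    connector = PostgresConnector(DB_URL)
    queries = ['SELECT * FROM my_table WHERE id = {}'.format(i) for i in range(100)]
    rows = connector.run_parallel_queries(queries)
    print('fetched %d rows' % len(rows))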

Now, to address the XY problem.

Initially it was a single query whose where id in (...) clause contained all 6k predicate IDs, but I ran into issues with that query using > 4GB of RAM on the machine it ran on, so I split it into 6k individual queries, which, when run synchronously, keep memory usage steady.

What you may want to do instead is one of these options:

  1. write a subquery that generates all 6000 IDs and use the subquery in your original bulk query.
  2. as above, but write the subquery as a CTE.
  3. if your ID list comes from an external source (i.e. not from the database), then you can create a temporary table containing the 6000 IDs and then run your original bulk query against the temporary table (see the sketch after this list).
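For illustration, option 3 could be sketched with psycopg2 roughly as follows (my_table, target_ids and the id column are invented names for the example):

import psycopg2
from psycopg2.extras import execute_values

def query_via_temp_table(db_url, ids):
    # load the externally-sourced IDs into a temp table, then join against it
    con = psycopg2.connect(db_url)
    cur = con.cursor()
    cur.execute('CREATE TEMP TABLE target_ids (id integer PRIMARY KEY)')
    execute_values(cur, 'INSERT INTO target_ids (id) VALUES %s', [(i,) for i in ids])
    cur.execute('SELECT t.* FROM my_table t JOIN target_ids USING (id)')
    rows = cur.fetchall()
    con.close()
    return rows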

However, if you insist on running the 6000 IDs through Python, then the fastest approach is likely neither to send all 6000 IDs in one go (which will run out of memory) nor to run 6000 individual queries. Instead, you may want to chunk the queries: send 500 IDs at a time, for example. You will have to experiment with the chunk size to determine the largest number of IDs you can send at one time while still staying comfortably within your memory budget.
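A chunked version might look something like this sketch (the chunk size and the table/column names are assumptions to tune against your own data):

import psycopg2

CHUNK_SIZE = 500  # experiment with this value

def fetch_in_chunks(db_url, ids):
    results = []
    con = psycopg2.connect(db_url)
    cur = con.cursor()
    for start in range(0, len(ids), CHUNK_SIZE):
        chunk = ids[start:start + CHUNK_SIZE]
        # psycopg2 adapts a Python list to a Postgres array, so ANY(%s) works
        cur.execute('SELECT * FROM my_table WHERE id = ANY(%s)', (chunk,))
        results.extend(cur.fetchall())
    con.close()
    return results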

Answered Sep 30 '22 by Lie Ryan