 

How to use shared memory instead of passing objects via pickling between multiple processes

I am working on a CPU-intensive ML problem centered around an additive model. Since addition is the main operation, I can divide the input data into pieces and spawn multiple models, which are then merged via the overridden __add__ method.

The code relating to the multiprocessing looks like this:

from functools import partial
import copy
from multiprocessing.pool import ThreadPool

def pool_worker(filename, doshuffle):
    print(f"Processing file: {filename}")
    with open(filename, 'r') as f:
        # build a partial model from this file's slice of the data
        return FragmentModel(order=args.order, indata=f, shuffle=doshuffle)

def generateModel(is_mock=False, save=True):
    model = None
    with ThreadPool(args.nthreads) as pool:
        partial_models = pool.imap_unordered(
            partial(pool_worker, doshuffle=is_mock), args.input)
        for i, m in enumerate(partial_models):
            logger.info(f'Starting to merge model {i}')
            if model is None:
                model = copy.deepcopy(m)
            else:
                model += m
            logger.info('Done merging...')

    return model

The issue is that memory consumption scales exponentially as the model order increases: at order 4, each model instance is about 4-5 GB, which makes the pool crash because the intermediate model objects are no longer pickleable.

I read about this a bit, and it appears that even if the pickling were not an issue, passing data like this would still be extremely inefficient, as noted in the comments to this answer.

There is very little guidance on how to use shared memory for this purpose, however. Is it possible to avoid this problem without having to change the internals of the model object?

asked Dec 04 '18 by posdef

1 Answer

Since Python 3.8, there is multiprocessing.shared_memory, which enables direct memory sharing between processes, similar to "real" multi-threading in C or Java. Direct memory sharing can be significantly faster than sharing via files or sockets, or via copying data through serialization/deserialization.

It works by providing a shared memory buffer on which different processes can claim and declare variables, via either a basic SharedMemory class instance or a more advanced SharedMemoryManager class instance. Variables of basic Python data types can be conveniently declared using the built-in ShareableList. Variables of advanced data types such as numpy.ndarray can be shared by specifying the shared memory buffer when declaring the array.
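
As a minimal sketch of the basic-types case (the values and names here are illustrative, not from the question), a ShareableList can be created under a SharedMemoryManager so the shared blocks are released automatically:

from multiprocessing import shared_memory
from multiprocessing.managers import SharedMemoryManager

# IN THE MAIN PROCESS
with SharedMemoryManager() as smm:
    # a shared list of basic Python values, stored in a named memory block
    sl = smm.ShareableList([0, 1.5, 'label'])
    sl_name = sl.shm.name  # pass this name to the worker processes

    # IN ANOTHER PROCESS
    # attach to the existing shared list by name
    other = shared_memory.ShareableList(name=sl_name)
    other[0] = 42  # updates are visible to every attached process
    other.shm.close()  # detach after using

    # IN THE MAIN PROCESS
    assert sl[0] == 42
# the manager frees all shared blocks on exit; no manual unlink() needed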

Example with numpy.ndarray:

import numpy as np
from multiprocessing import shared_memory

# setting up
data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
d_shape = (len(data),)
d_type = np.int64
d_size = np.dtype(d_type).itemsize * np.prod(d_shape)

# IN THE MAIN PROCESS
# allocate new shared memory
shm = shared_memory.SharedMemory(create=True, size=d_size)
shm_name = shm.name
# numpy array on shared memory buffer
a = np.ndarray(shape=d_shape, dtype=d_type, buffer=shm.buf)
# copy data into shared memory ndarray once
a[:] = data[:]

# IN ANOTHER PROCESS
# reuse existing shared memory
ex_shm = shared_memory.SharedMemory(name=shm_name)
# numpy array b uses the same memory buffer as a
b = np.ndarray(shape=d_shape, dtype=d_type, buffer=ex_shm.buf)
# changes in b will be reflected in a and vice versa...
ex_shm.close()  # close after using

# IN THE MAIN PROCESS
shm.close()  # close after using
shm.unlink()  # free memory

In the above code, the a and b arrays use the same underlying memory and can update it directly, which can be very useful in machine learning. However, you should beware of concurrent-update problems and decide how to handle them: use a Lock, partition the accesses, or simply accept stochastic updates (so-called HogWild-style training).
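
For instance, here is a minimal sketch of the Lock option (the worker function, shapes, and values are illustrative): each process attaches to the same block by name and performs its read-modify-write under the lock, which is how partial results could be merged into one shared accumulator.

import numpy as np
from multiprocessing import Process, Lock, shared_memory

def add_partial(shm_name, shape, dtype, value, lock):
    # attach to the shared block created by the main process
    shm = shared_memory.SharedMemory(name=shm_name)
    acc = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    with lock:
        acc += value  # read-modify-write is only safe inside the lock
    shm.close()

if __name__ == '__main__':
    shape, dtype = (10,), np.int64
    shm = shared_memory.SharedMemory(
        create=True, size=np.dtype(dtype).itemsize * 10)
    acc = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
    acc[:] = 0
    lock = Lock()
    workers = [Process(target=add_partial,
                       args=(shm.name, shape, dtype, v, lock))
               for v in (1, 2, 3)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(acc)  # all partial additions merged in place: [6 6 6 ...]
    shm.close()
    shm.unlink()  # free memory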

answered Oct 17 '22 by THN