
Cannot pickle 'weakref' object when using Tensorflow with Multiprocessing

I would like to train several neural networks at the same time, and I'm trying to use the multiprocessing module so that each network is trained in a separate process, but I ran into a problem. When I run the demo code below (since apply_async does not report errors on its own, I temporarily switched to apply):

import tensorflow as tf
import multiprocessing as mp


class SeqModel(tf.keras.Sequential):
    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        self.add(tf.keras.layers.Dense(hidden_sizes[0], activation="relu", input_shape=(input_size,)))
        for hidden_size in hidden_sizes[1:]: self.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
        if output_size is not None: self.add(tf.keras.layers.Dense(output_size))


class Partition:
    def __init__(self, partition_id):
        self.partition_id = partition_id
        self.model = None

    def initialization(self):
        self.model = SeqModel(10,[10,10],10)

    def test(self):
        print(f'partition {self.partition_id} testing...')


def func():
    partition_list = [Partition(i) for i in range(4)]

    for partition in partition_list: partition.initialization()

    p = mp.Pool(4)
    for partition in partition_list:
        p.apply(partition.test)
    p.close()
    p.join()


if __name__ == '__main__':
    func()

I got the following error:

Traceback (most recent call last):
  File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 43, in <module>
    func()
  File "C:/Users/Administrator/Dropbox (ASU)/Work/Traffic State Estimation/traffic state estimation/dataset/mp/mp_net.py", line 37, in func
    p.apply(partition.test)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 357, in apply
    return self.apply_async(func, args, kwds).get()
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 771, in get
    raise self._value
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\pool.py", line 537, in _handle_tasks
    put(task)
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\connection.py", line 206, in send
    self._send_bytes(_ForkingPickler.dumps(obj))
  File "C:\Users\Administrator\AppData\Local\Programs\Python\Python38\lib\multiprocessing\reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
TypeError: cannot pickle 'weakref' object

If I do not do partition initialization (no SeqModel involved in Partition instances), the code runs without problem. Does that mean I cannot use tf Models in child processes?

Jiawei Lu asked May 16 '26 02:05



1 Answer

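To use Pool, everything you send to it has to be picklable, because Pool methods use mp.SimpleQueue to pass tasks to the worker processes, and mp.SimpleQueue pickles every object it transmits. Here the task is the bound method partition.test, and pickling a bound method also pickles the instance it belongs to, including self.model.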
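The failure can be reproduced without TensorFlow or a Pool at all. The sketch below is only an illustration: FakeModel is a hypothetical stand-in that holds a weakref (as the traceback suggests the real model does), and pickle_error is a helper invented for this example.

```python
import pickle
import weakref


class FakeModel:
    """Hypothetical stand-in for a Keras model: it stores a weakref."""

    def __init__(self):
        self._ref = weakref.ref(self)


class Partition:
    def __init__(self, partition_id):
        self.partition_id = partition_id
        self.model = None

    def test(self):
        return f"partition {self.partition_id} testing..."


def pickle_error(obj):
    """Return None if obj pickles cleanly, else the TypeError message."""
    try:
        pickle.dumps(obj)
        return None
    except TypeError as exc:
        return str(exc)


without_model = Partition(0)
with_model = Partition(1)
with_model.model = FakeModel()

# Pickling the bound method pickles the whole Partition instance,
# so the attached model is what makes the second call fail.
print(pickle_error(without_model.test))
print(pickle_error(with_model.test))
```

This mirrors the observation in the question: the same method pickles fine as long as no model is attached to the instance.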

TensorFlow models, however, are not picklable by default, so you can't easily use Pool with them. There is an open TensorFlow issue about making Model picklable.

However, you can try to make Model picklable using the workaround suggested in that discussion: https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-627193883
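A different route, not the workaround from the linked comment, is to avoid pickling the model altogether: send only plain data to the workers and build the model inside each child process. The following is a minimal sketch under that assumption; build_keras_model, run_partition and run_all are hypothetical names, and the training step is left as a placeholder.

```python
import multiprocessing as mp


def build_keras_model(input_size, hidden_sizes, output_size):
    """Build a network like the one in the question, inside the worker."""
    import tensorflow as tf  # imported lazily, so the parent holds no TF objects

    model = tf.keras.Sequential()
    model.add(tf.keras.Input(shape=(input_size,)))
    for hidden_size in hidden_sizes:
        model.add(tf.keras.layers.Dense(hidden_size, activation="relu"))
    if output_size is not None:
        model.add(tf.keras.layers.Dense(output_size))
    return model


def run_partition(task, model_factory=build_keras_model):
    """Worker entry point: receives picklable data and builds the model locally."""
    partition_id, model_args = task
    model = model_factory(*model_args)
    # ... train `model` here (placeholder) ...
    return f"partition {partition_id} built {type(model).__name__}"


def run_all(n_partitions=4, model_args=(10, [10, 10], 10)):
    """Send one (id, sizes) tuple per partition; only these tuples get pickled."""
    tasks = [(i, model_args) for i in range(n_partitions)]
    with mp.Pool(n_partitions) as pool:
        return pool.map(run_partition, tasks)
```

Call run_all() from under an if __name__ == '__main__': guard, as in the question, which matters on Windows where child processes re-import the main module. If a trained model is needed back in the parent, one option is to return something picklable from the worker, such as the list of NumPy arrays from model.get_weights().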

PermanentPon answered May 18 '26 15:05



