How to deserialize a saved PyTorch model with private methods inside a class?

I used PyTorch's saving method (torch.save) to serialize a bunch of essential objects. Among them was an instance of a class that references one of its own private methods inside __init__. Now I can't deserialize (unpickle) the files, apparently because the private method is not accessible outside the class. Any idea how to solve or work around this? I need to recover the data saved in that class's attributes.

  File ".conda/envs/py37/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-1-a5666d77c70f>", line 1, in <module>
    torch.load("snapshots/model.pth", map_location='cpu')
  File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 529, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File ".conda/envs/py37/lib/python3.7/site-packages/torch/serialization.py", line 702, in _legacy_load
    result = unpickler.load()
AttributeError: 'Trainer' object has no attribute '__iterator'
EDIT-1:

Here is a piece of code that reproduces the problem I'm facing.

import torch

class Test:
    def __init__(self):
        self.a = min
        self.b = max
        self.c = self.__private  # buggy

    def __private(self):
        return None

test = Test()

torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")

However, if you drop the double-underscore prefix from the method name (so the method is no longer private), you won't get any error.

import torch

class Test:
    def __init__(self):
        self.a = min
        self.b = max
        self.c = self.private  # not buggy

    def private(self):
        return None

test = Test()

torch.save({"test": test}, "file.pkl")
torch.load("file.pkl")
user3108967 asked Dec 17 '25 09:12



1 Answer

This question is similar to "Python multiprocessing - mapping private method", but it cannot be marked as a duplicate because of the open bounty.

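The issue comes from an open issue on the Python bug tracker, "Objects referencing private-mangled names do not roundtrip properly under pickling", and is related to the way pickle handles name mangling. A bound method is pickled as a call to getattr(obj, name), where name is the method's unmangled __name__ (here __private). Inside the class, however, the method is actually stored under the mangled name _Test__private, so the getattr lookup fails when unpickling. More details in this answer: https://stackoverflow.com/a/57497698/6352677.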
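The name mismatch can be observed directly. The sketch below uses a trimmed-down version of the Test class from the question (plain Python, no PyTorch needed). It does not pickle anything; it only inspects the names involved, assuming pickle records `__func__.__name__` for bound methods as it does on the affected Python versions:

```python
class Test:
    def __init__(self):
        # Inside the class body, self.__private is compiled as self._Test__private
        self.c = self.__private

    def __private(self):
        return None


test = Test()

# The name the function object carries (the one pickle records): "__private"
print(test.c.__func__.__name__)

# The name the method is actually stored under in the class namespace: True
print("_Test__private" in vars(Test))

# What a plain getattr(test, "__private") finds when unpickling: False
print(hasattr(test, "__private"))
```

Name mangling only rewrites identifiers written inside the class body at compile time, so a lookup by the string "__private" at load time never reaches `_Test__private`.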

At this point, the only workaround is to avoid storing references to private (double-underscore) methods in __init__.
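One way to follow that workaround while keeping the method private is to not store the bound method on the instance at all and resolve it on demand, for example through a property. This is a minimal sketch, assuming you can change the class definition before saving (it does not repair files already written with the old class). It uses plain pickle, which is what torch.save relies on by default:

```python
import pickle


class Test:
    def __init__(self):
        self.a = min
        self.b = max
        # No bound private method is stored on the instance here

    @property
    def c(self):
        # Resolved at access time, so it is never part of the pickled state
        return self.__private

    def __private(self):
        return None


restored = pickle.loads(pickle.dumps(Test()))
print(restored.a is min, restored.c())  # True None
```

Only `a` and `b` end up in the instance's `__dict__`, so the pickled state contains no reference to a mangled name.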

Keldorn answered Dec 19 '25 23:12



