
Why is this code able to use the sklearn function without import sklearn?

I just watched a tutorial in which the author didn't need to import sklearn when using the predict function of a pickled model in an Anaconda environment (with sklearn installed).

I tried to reproduce a minimal version of it in Google Colab. Given a pickled sklearn model, the code below works in Colab (with sklearn installed):

import pickle
model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
out = model.predict([[20, 0, 1, 1, 0]])
print(out)

I realized that I still need the sklearn package installed. If I uninstall sklearn, the predict function no longer works:

!pip uninstall scikit-learn
import pickle
model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
out = model.predict([[20, 0, 1, 1, 0]])
print(out)

The error:

WARNING: Skipping scikit-learn as it is not installed.

---------------------------------------------------------------------------

ModuleNotFoundError                       Traceback (most recent call last)

<ipython-input-1-dec96951ae29> in <module>()
      1 get_ipython().system('pip uninstall scikit-learn')
      2 import pickle
----> 3 model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
      4 out = model.predict([[20, 0, 1, 1, 0]])
      5 print(out)

ModuleNotFoundError: No module named 'sklearn'

So, how does it work? As far as I understand, pickle doesn't depend on scikit-learn. Does the serialized model import sklearn itself? Why can I use the predict function without importing scikit-learn in the first snippet?

malioboro asked Dec 30 '22 12:12


2 Answers

There are a few questions being asked here, so let's go through them one by one:

So, how does it work? As far as I understand, pickle doesn't depend on scikit-learn.

There is nothing particular to scikit-learn going on here. Pickle will exhibit this behaviour for any module. Here's an example with Numpy:

will@will-desktop ~ $ python
Python 3.9.6 (default, Aug 24 2021, 18:12:51) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> import sys
>>> 'numpy' in sys.modules
False
>>> import numpy
>>> 'numpy' in sys.modules
True
>>> pickle.dumps(numpy.array([1, 2, 3]))
b'\x80\x04\x95\xa0\x00\x00\x00\x00\x00\x00\x00\x8c\x15numpy.core.multiarray\x94\x8c\x0c_reconstruct\x94\x93\x94\x8c\x05numpy\x94\x8c\x07ndarray\x94\x93\x94K\x00\x85\x94C\x01b\x94\x87\x94R\x94(K\x01K\x03\x85\x94h\x03\x8c\x05dtype\x94\x93\x94\x8c\x02i8\x94\x89\x88\x87\x94R\x94(K\x03\x8c\x01<\x94NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00t\x94b\x89C\x18\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x94t\x94b.'
>>> exit()

So far, I've shown that in a fresh Python process 'numpy' is not in sys.modules (the dict of imported modules). Then we import Numpy and pickle a Numpy array.

Then, in the new Python process shown below, we see that Numpy has not been imported before we unpickle the array, but it has been imported afterwards.

will@will-desktop ~ $ python
Python 3.9.6 (default, Aug 24 2021, 18:12:51) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> import sys
>>> 'numpy' in sys.modules
False
>>> pickle.loads(b'\x80\x04\x95\xa0\x00\x00\x00\x00\x00\x00\x00\x8c\x15numpy.core.multiarray\x94\x8c\x0c_reconstruct\x94\x93\x94\x8c\x05numpy\x94\x8c\x07ndarray\x94\x93\x94K\x00\x85\x94C\x01b\x94\x87\x94R\x94(K\x01K\x03\x85\x94h\x03\x8c\x05dtype\x94\x93\x94\x8c\x02i8\x94\x89\x88\x87\x94R\x94(K\x03\x8c\x01<\x94NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00t\x94b\x89C\x18\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x94t\x94b.')
array([1, 2, 3])
>>> 'numpy' in sys.modules
True
>>> numpy
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'numpy' is not defined

Despite being imported, however, numpy is still not a defined variable name. Imported modules are shared by the whole process (they are cached in sys.modules), but an import only binds a name in the namespace of the module that actually performed it. If we want to access numpy, we still need to write import numpy, but since Numpy was already imported elsewhere in the process, this will not re-run Numpy's module initialization code. Instead, it will create a numpy variable in our module's globals dictionary that refers to the Numpy module object that already existed and could be accessed through sys.modules['numpy'].
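The two-process experiment above can be condensed into one script. This is a minimal sketch that uses the standard-library class fractions.Fraction as a stand-in for Numpy (or an sklearn model) so it runs without third-party packages; it pickles an object, then feeds the bytes to a fresh child interpreter started with -S (so the site module preloads nothing) and reports what that child sees before and after unpickling.

```python
import pickle
import subprocess
import sys
from fractions import Fraction

# Pickle an object whose class lives in the stdlib module "fractions".
payload = pickle.dumps(Fraction(1, 3))

# Code for a fresh interpreter; it reads the pickle from stdin.
child_code = """
import pickle, sys
data = sys.stdin.buffer.read()
print('fractions' in sys.modules)   # before unpickling
obj = pickle.loads(data)
print('fractions' in sys.modules)   # after unpickling: pickle imported it
print('fractions' in globals())     # but no name was bound in this namespace
import fractions                    # reuses the cached module object
print(fractions is sys.modules['fractions'], obj)
"""

# -S skips the site module, so nothing extra is imported at startup.
result = subprocess.run(
    [sys.executable, "-S", "-c", child_code],
    input=payload,
    capture_output=True,
    check=True,
)
print(result.stdout.decode())
```

The child prints False, True, False, and finally True alongside 1/3: unpickling imported the module, but only an explicit import statement binds a name to it.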

So what is Pickle doing here? It embeds, within the pickle, the name of the module that defines whatever it is pickling (more precisely, the module of each class or function needed to rebuild the object). When it unpickles something, it uses that information to import the module so that it can look up the class and reconstruct the object. If we look at the source code for the Pickle module, we can see that this is what's happening:

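In the _Pickler class, we see that the save method uses the save_global method for classes and functions. This in turn uses the whichmodule function to obtain the module name (some sklearn.* module in your case; note that this is the import name sklearn, not the distribution name scikit-learn), which is then saved in the pickle.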
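Both steps can be observed directly. This sketch again uses fractions.Fraction as a stand-in for an sklearn class. It calls pickle.whichmodule, an undocumented helper in the pure-Python pickle module, and then uses pickletools.genops to list the GLOBAL opcodes in a protocol-2 pickle; the protocol is pinned to 2 because protocols 4 and later record the same information with the STACK_GLOBAL opcode instead, which this simple filter does not decode.

```python
import pickle
import pickletools
from fractions import Fraction

# whichmodule is an undocumented helper of the pure-Python pickle module;
# for a class it essentially returns the class's __module__ attribute.
module_name = pickle.whichmodule(Fraction, "Fraction")
print(module_name)  # fractions

# Protocol 2 records class references with the GLOBAL opcode,
# whose argument is "<module> <name>".
data = pickle.dumps(Fraction(1, 3), protocol=2)
refs = [arg for opcode, arg, _pos in pickletools.genops(data) if opcode.name == "GLOBAL"]
print(refs)  # ['fractions Fraction']
```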

In the _Unpickler class, we see that the find_class method uses __import__ to import the module using the stored module name. The find_class method is used in a few of the load_* methods: load_global and load_stack_global, which resolve the class references written by save_global, and load_inst, which the oldest protocol uses to create an instance of a class:

def load_inst(self):
    module = self.readline()[:-1].decode("ascii")
    name = self.readline()[:-1].decode("ascii")
    klass = self.find_class(module, name)
    self._instantiate(klass, self.pop_mark())

The documentation for Unpickler.find_class explains:

Import module if necessary and return the object called name from it, where the module and name arguments are str objects.
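One way to watch find_class at work is to subclass pickle.Unpickler and record each request before delegating to the default behaviour. RecordingUnpickler below is a hypothetical helper written for illustration, and fractions.Fraction again stands in for an sklearn model class.

```python
import io
import pickle
from fractions import Fraction


class RecordingUnpickler(pickle.Unpickler):
    """Hypothetical helper: records every (module, name) pair pickle asks for."""

    def __init__(self, file):
        super().__init__(file)
        self.requested = []

    def find_class(self, module, name):
        # Called for every global (class or function) referenced by the pickle.
        self.requested.append((module, name))
        return super().find_class(module, name)  # imports the module if needed


data = pickle.dumps(Fraction(1, 3))
unpickler = RecordingUnpickler(io.BytesIO(data))
obj = unpickler.load()
print(obj, unpickler.requested)  # 1/3 [('fractions', 'Fraction')]
```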

The docs also explain how you can restrict this behaviour:

[You] may want to control what gets unpickled by customizing Unpickler.find_class(). Unlike its name suggests, Unpickler.find_class() is called whenever a global (i.e., a class or a function) is requested. Thus it is possible to either completely forbid globals or restrict them to a safe subset.

This is generally only relevant when unpickling untrusted data, which doesn't appear to be the case here.
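A minimal sketch of such a restriction, modelled on the allow-list idea from the docs: ALLOWED_GLOBALS, RestrictedUnpickler and restricted_loads are hypothetical names, and the allow-list here contains only fractions.Fraction for demonstration. Plain containers need no globals at all, so they load regardless, while a reference to builtins.eval is refused.

```python
import io
import pickle
from fractions import Fraction

# Hypothetical allow-list of (module, name) pairs that may be loaded.
ALLOWED_GLOBALS = {("fractions", "Fraction")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Every class/function reference in the pickle passes through here.
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")


def restricted_loads(data):
    """Like pickle.loads, but refuses globals outside ALLOWED_GLOBALS."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


print(restricted_loads(pickle.dumps({"a": [1, 2]})))   # plain containers need no globals
print(restricted_loads(pickle.dumps(Fraction(1, 3))))  # allowed class
try:
    restricted_loads(pickle.dumps(eval))                # builtins.eval is not allowed
except pickle.UnpicklingError as exc:
    print(exc)
```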


Does the serialized model do import sklearn?

The serialized model itself doesn't do anything, strictly speaking. It's all handled by the Pickle module as described above.


Why can I use the predict function without importing scikit-learn in the first snippet?

Because sklearn is imported by the Pickle module when it unpickles the data, which gives you a fully realized model object. It's just as if some other module had imported sklearn, created the model object, and then passed it into your code as an argument to a function.
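The analogy can be sketched in a few lines. Here the hypothetical function make_payload plays the role of "some other module" that imports the class (fractions.Fraction, standing in for an sklearn model) and builds the object. The calling code never binds the class or its module to a name, yet method calls work because Python looks methods up on type(obj), which the unpickler has already resolved.

```python
import pickle
import sys


def make_payload():
    # Hypothetical stand-in for "some other module" that imports the class.
    from fractions import Fraction
    return pickle.dumps(Fraction(3, 4))


obj = pickle.loads(make_payload())

# Neither the class nor its module is bound as a name in this namespace...
print("Fraction" in globals(), "fractions" in globals())  # False False
# ...yet method calls work, because they are looked up on type(obj).
print(obj.as_integer_ratio())                             # (3, 4)
print(type(obj) is sys.modules["fractions"].Fraction)     # True
```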


As a consequence of all this, to unpickle your model you'll need to have sklearn installed, ideally the same version that was used to create the pickle. In general, the Pickle module stores the fully qualified name of every class and function it needs, so all [1] of the required modules must exist, under the same fully qualified names, in both the Python process that pickles the object and the one that unpickles it.


[1] A caveat to that is that the Pickle module can automatically adjust/fix certain imports for particular modules/classes that have different fully qualified names between Python 2 and 3. From the docs:

If fix_imports is true, pickle will try to map the old Python 2 names to the new names used in Python 3.
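The mapping only applies to protocols below 3, since those are the ones Python 2 can read. This sketch pickles the built-in eval with protocol 2, with and without fix_imports, and lists the stored module names; global_refs is a hypothetical helper built on pickletools.genops.

```python
import pickle
import pickletools


def global_refs(data):
    """Hypothetical helper: list the "<module> <name>" arguments of GLOBAL opcodes."""
    return [arg for opcode, arg, _pos in pickletools.genops(data) if opcode.name == "GLOBAL"]


fixed = pickle.dumps(eval, protocol=2)                       # fix_imports=True is the default
unfixed = pickle.dumps(eval, protocol=2, fix_imports=False)

print(global_refs(fixed))           # ['__builtin__ eval']  (Python 2 module name)
print(global_refs(unfixed))         # ['builtins eval']     (Python 3 module name)
print(pickle.loads(fixed) is eval)  # True: mapped back to builtins when loading
```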

Will Da Silva answered Jan 14 '23 13:01


When the model was first pickled, you had sklearn installed. The pickle file depends on sklearn because the object it represents is an instance of an sklearn class, and pickle needs to import that class in order to reconstruct the object.

When you try to unpickle the file without sklearn installed, pickle reads from the file that the object is an instance of sklearn.x.y.z (or what have you), and unpickling then fails because the module sklearn cannot be found when pickle tries to resolve that name. Notice that the exception occurs on the unpickling line, not on the line where predict is called.
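The failure mode can be reproduced without touching a real installation. In this sketch, a throwaway module with the made-up name hypothetical_models stands in for sklearn: it is registered in sys.modules so the model can be pickled and loaded, then removed to simulate uninstalling the package. The second pickle.loads call then raises ModuleNotFoundError before any predict call is reached.

```python
import pickle
import sys
import types

# Build a throwaway module at runtime; "hypothetical_models" is a made-up
# name standing in for sklearn.
mod = types.ModuleType("hypothetical_models")
exec(
    "class Model:\n"
    "    def __init__(self, w):\n"
    "        self.w = w\n"
    "    def predict(self, xs):\n"
    "        return [self.w * x for x in xs]\n",
    mod.__dict__,
)
sys.modules["hypothetical_models"] = mod  # make it importable by name

data = pickle.dumps(mod.Model(2))
restored = pickle.loads(data)       # works: the module can be imported
print(restored.predict([1, 2, 3]))  # [2, 4, 6]

# Simulate "pip uninstall": the module can no longer be imported.
del sys.modules["hypothetical_models"]
try:
    pickle.loads(data)              # fails here, before any predict() call
except ModuleNotFoundError as exc:
    error_message = str(exc)
    print(error_message)            # No module named 'hypothetical_models'
```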

You don’t need to import sklearn in your code when it does work because, once the object is unpickled, it already references its class, and through it all of its methods, so you can just call them on the object.

Angus L'Herrou answered Jan 14 '23 15:01