
How to save / serialize a trained model in theano?

I saved the model as described in the Theano documentation on loading and saving.

# saving trained model
import cPickle
with open('models/simple_model.save', 'wb') as f:
    cPickle.dump(ca, f, protocol=cPickle.HIGHEST_PROTOCOL)

ca is a trained auto-encoder. It's an instance of the class cA. From the script in which I build and save the model, I can call ca.get_reconstructed_input(...) and ca.get_hidden_values(...) without any problem.

In a different script I try to load the trained model.

# loading the trained model
import cPickle
with open('models/simple_model.save', 'rb') as model_file:
    ca = cPickle.load(model_file)

I receive the following error.

ca = cPickle.load(model_file)

AttributeError: 'module' object has no attribute 'cA'

xagg asked Aug 10 '15 13:08


1 Answer

All the class definitions of the pickled objects need to be known by the script that does the unpickling. There is more on this in other StackOverflow questions (e.g. AttributeError: 'module' object has no attribute 'newperson').

Your code is correct as long as cA is properly imported in the script that loads the model. The error you're getting suggests that it is not. Make sure you're using from cA import cA and not just import cA.
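To see why the import matters, here is a minimal sketch of how pickle records a class: it stores the module name and class name, not the class body. The module name cA_demo and the tiny cA class below are hypothetical stand-ins, and the sketch uses Python 3's pickle rather than cPickle (both record classes by name).

```python
import pickle
import sys
import types

# Hypothetical stand-in for a file cA_demo.py that defines the class cA.
source = (
    "class cA(object):\n"
    "    def __init__(self, n_hidden):\n"
    "        self.n_hidden = n_hidden\n"
)
module = types.ModuleType("cA_demo")
exec(source, module.__dict__)
sys.modules["cA_demo"] = module

# The pickle stream refers to the class by module name and class name.
payload = pickle.dumps(module.cA(n_hidden=4), protocol=2)
names_module = b"cA_demo" in payload

# Simulate a loading script in which the class definition cannot be found.
hidden_class = module.__dict__.pop("cA")
try:
    pickle.loads(payload)
    load_failed = False
except AttributeError:
    load_failed = True

# Once the definition is visible again, the same bytes load without error.
module.cA = hidden_class
restored = pickle.loads(payload)
```

The failing load raises an AttributeError of the same kind as the one in the question, because the unpickler looks up cA by name in the recorded module and does not find it.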

Alternatively, since your model is defined by its parameters, you could pickle just the parameter values. This can be done in two ways, depending on your point of view.

  1. Save the Theano shared variables. Here we assume that ca.params is a regular Python list of Theano shared variable instances.

    cPickle.dump(ca.params, f, protocol=cPickle.HIGHEST_PROTOCOL)
    
  2. Save the numpy arrays stored inside the Theano shared variables.

    cPickle.dump([param.get_value() for param in ca.params], f, protocol=cPickle.HIGHEST_PROTOCOL)
    

When you want to load the model, you'll need to reinitialize the parameters. For example, create a new instance of the cA class, then either

ca.params = cPickle.load(f)
ca.W, ca.b, ca.b_prime = ca.params

or

ca.params = [theano.shared(param) for param in cPickle.load(f)]
ca.W, ca.b, ca.b_prime = ca.params

Note that you need to set both the params field and the individual parameter fields (W, b and b_prime).
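The save-values-then-reload pattern can be sketched end to end as follows. This is only an illustration under stated assumptions: FakeShared and TinyModel are hypothetical stand-ins (not Theano classes) that mimic the get_value/set_value interface of a shared variable, plain lists stand in for numpy arrays, and an in-memory buffer stands in for the file. It uses Python 3's pickle.

```python
import io
import pickle


class FakeShared(object):
    """Hypothetical stand-in for a Theano shared variable."""

    def __init__(self, value):
        self._value = value

    def get_value(self):
        return self._value

    def set_value(self, value):
        self._value = value


class TinyModel(object):
    """Hypothetical stand-in for cA, with the same three parameters."""

    def __init__(self):
        self.W = FakeShared([[0.0, 0.0], [0.0, 0.0]])
        self.b = FakeShared([0.0, 0.0])
        self.b_prime = FakeShared([0.0, 0.0])
        self.params = [self.W, self.b, self.b_prime]


def save_params(model, f):
    # Only the raw values are pickled, so no class definition is needed to load them.
    values = [param.get_value() for param in model.params]
    pickle.dump(values, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_params(model, f):
    # Update the existing variables in place; params, W, b and b_prime stay in sync.
    for param, value in zip(model.params, pickle.load(f)):
        param.set_value(value)


trained = TinyModel()
trained.W.set_value([[0.5, -0.25], [1.0, 2.0]])
trained.b.set_value([0.1, 0.2])

buffer = io.BytesIO()  # stands in for open('models/simple_model.save', 'wb')
save_params(trained, buffer)
buffer.seek(0)

fresh = TinyModel()
load_params(fresh, buffer)
```

Writing the loaded values into the new instance's existing variables with set_value, instead of replacing the list, is one way to avoid having to reassign W, b and b_prime separately.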

Daniel Renshaw answered Oct 09 '22 01:10