How to serialize/deserialize PyBrain networks?

PyBrain is a Python library that provides (among other things) easy-to-use Artificial Neural Networks.

I cannot properly serialize/deserialize PyBrain networks using either pickle or cPickle.

See the following example:

from pybrain.datasets            import SupervisedDataSet
from pybrain.tools.shortcuts     import buildNetwork
from pybrain.supervised.trainers import BackpropTrainer
import cPickle as pickle
import numpy as np 

#generate some data
np.random.seed(93939393)
data = SupervisedDataSet(2, 1)
for x in xrange(10):
    y = x * 3
    z = x + y + 0.2 * np.random.randn()  
    data.addSample((x, y), (z,))

#build a network and train it    

net1 = buildNetwork( data.indim, 2, data.outdim )
trainer1 = BackpropTrainer(net1, dataset=data, verbose=True)
for i in xrange(4):
    trainer1.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net1.activate((1, 4))[0])

This is the output of the above code:

Total error: 201.501998476
    value after 0 epochs: 2.79
Total error: 152.487616382
    value after 1 epochs: 5.44
Total error: 120.48092561
    value after 2 epochs: 7.56
Total error: 97.9884043452
    value after 3 epochs: 8.41

As you can see, the network's total error decreases as training progresses. You can also see that the predicted value approaches the expected value of 12.

Now let's repeat the exercise, this time with a serialization/deserialization step:

print 'creating net2'
net2 = buildNetwork(data.indim, 2, data.outdim)
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
trainer2.trainEpochs(1)
print '\tvalue after %d epochs: %.2f'%(1, net2.activate((1, 4))[0])

#So far, so good. Let's test pickle
pickle.dump(net2, open('testNetwork.dump', 'w'))
net2 = pickle.load(open('testNetwork.dump'))
trainer2 = BackpropTrainer(net2, dataset=data, verbose=True)
print 'loaded net2 using pickle, continue training'
for i in xrange(1, 4):
    trainer2.trainEpochs(1)
    print '\tvalue after %d epochs: %.2f'%(i, net2.activate((1, 4))[0])

This is the output of this block:

creating net2
Total error: 176.339378639
    value after 1 epochs: 5.45
loaded net2 using pickle, continue training
Total error: 123.392181859
    value after 1 epochs: 5.45
Total error: 94.2867637623
    value after 2 epochs: 5.45
Total error: 78.076711114
    value after 3 epochs: 5.45

As you can see, the training seems to have some effect on the network (the reported total error keeps decreasing); however, the network's output stays frozen at the value it had after the first training epoch.

Is there some caching mechanism I need to be aware of that causes this erroneous behaviour? Are there better ways to serialize/deserialize PyBrain networks?

Relevant version numbers:

  • Python 2.6.5 (r265:79096, Mar 19 2010, 21:48:26) [MSC v.1500 32 bit (Intel)]
  • Numpy 1.5.1
  • cPickle 1.71
  • pybrain 0.3

P.S. I have created a bug report on the project's site and will keep both SO and the bug tracker updated.

asked Dec 02 '10 12:12 by Boris Gorelik


1 Answer

Cause

The mechanism that causes this behavior is the way PyBrain modules handle parameters (.params) and derivatives (.derivs): all network parameters are stored in one array, and the individual Module or Connection objects have access to "their own" .params, which, however, are just views on slices of that total array. This allows both local and network-wide writes and reads on the same data structure.

Apparently this slice-view link is lost during pickling and unpickling.
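The effect can be reproduced with plain NumPy, without PyBrain. The sketch below assumes only what the paragraphs above describe: one flat parameter array plus per-component views onto slices of it. The names `flat`, `w1` and `w2` are illustrative and are not PyBrain attributes.

```python
import pickle

import numpy as np

# One flat array holding every parameter of a (toy) network.
flat = np.zeros(5)
# Each component's "own" parameters are views onto slices of that array.
w1 = flat[0:2]
w2 = flat[2:5]

# A network-wide write is visible through the component views.
flat[:] = 1.0
print(w1.tolist())  # [1.0, 1.0]

# Pickle and unpickle the whole structure in one go.
flat2, w1_2, w2_2 = pickle.loads(pickle.dumps((flat, w1, w2)))

# The values survive, but every array now owns a separate buffer ...
linked_after_load = np.shares_memory(flat2, w1_2)
print(linked_after_load)  # False

# ... so a network-wide write no longer reaches the component views.
flat2[:] = 2.0
print(w1_2.tolist())  # [1.0, 1.0]

# Re-link: rebuild the flat array from the components' current values and
# point each component back at a slice of it.
flat2 = np.concatenate([w1_2, w2_2])
w1_2, w2_2 = flat2[0:2], flat2[2:5]
flat2[:] = 3.0
print(w1_2.tolist())  # [3.0, 3.0]
```

In this toy the components' values are treated as the source of truth when re-linking, so the `2.0` written to the detached flat array is discarded; that choice is an assumption of the sketch, not a description of PyBrain's internals.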

Solution

Insert the following two lines right after loading the network from the file:

net2.sorted = False
net2.sortModules()

Re-sorting the modules recreates the shared parameter array, and training should then work as expected.
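If the fix has to be applied at every load site, it may be convenient to wrap it once. The sketch below uses two hypothetical helpers, `save_network` and `load_network` (they are not part of PyBrain), and opens the file in binary mode, which is the safe choice for pickle data, especially on Windows. It assumes, as the answer implies, that clearing `sorted` is what makes `sortModules()` redo its work. Under Python 2, `cPickle` can be imported as `pickle`, as in the question.

```python
import pickle


def save_network(net, path):
    """Pickle `net` to `path`; binary mode is required for binary protocols."""
    with open(path, "wb") as f:
        pickle.dump(net, f, pickle.HIGHEST_PROTOCOL)


def load_network(path):
    """Unpickle a network and re-link its parameter views (hypothetical helper).

    Assumes the loaded object has the `sorted` attribute and `sortModules()`
    method used in the answer; clearing `sorted` first presumably forces
    sortModules() to rebuild the shared parameter array instead of
    treating the network as already prepared.
    """
    with open(path, "rb") as f:
        net = pickle.load(f)
    net.sorted = False
    net.sortModules()
    return net
```

With these, the pickling step in the question would become `save_network(net2, 'testNetwork.dump')` followed by `net2 = load_network('testNetwork.dump')`.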

answered Nov 08 '22 08:11 by schaul