
Simple Regression Example pyBrain

I am trying to build the simplest possible regression in PyBrain, but somehow I'm failing.

The Neural Network should learn the function Y=3*X

from pybrain.supervised.trainers import BackpropTrainer
from pybrain.datasets import SupervisedDataSet
from pybrain.structure import FullConnection, FeedForwardNetwork, TanhLayer, LinearLayer, BiasUnit
import matplotlib.pyplot as plt
from numpy import *

n = FeedForwardNetwork()
n.addInputModule(LinearLayer(1, name = 'in'))
n.addInputModule(BiasUnit(name = 'bias'))
n.addModule(TanhLayer(1,name = 'tan'))
n.addOutputModule(LinearLayer(1, name = 'out'))
n.addConnection(FullConnection(n['bias'], n['tan']))
n.addConnection(FullConnection(n['in'], n['tan']))
n.addConnection(FullConnection(n['tan'], n['out']))
n.sortModules()

# initialize the backprop trainer and train
t = BackpropTrainer(n, learningrate = 0.1, momentum = 0.0, verbose = True)

#DATASET
DS = SupervisedDataSet( 1, 1 )
X = random.rand(100,1)*100
Y = X*3+random.rand(100,1)*5
for r in xrange(X.shape[0]):
    DS.appendLinked((X[r]),(Y[r]))

t.trainOnDataset(DS, 200)
plt.plot(X,Y,'.b')
X=[[i] for i in arange(0,100,0.1)]
Y=map(n.activate,X)
plt.plot(X,Y,'-g')

It doesn't learn anything. I have tried to remove the hidden layer (because in this example we don't even need that) and the network started to predict NaNs. What's going on?

EDIT: This is the code that solved my problem:

#DATASET
DS = SupervisedDataSet( 1, 1 )
X = random.rand(100,1)*100
Y = X*3+random.rand(100,1)*5
maxy = float(max(Y))
maxx = 100.0
for r in xrange(X.shape[0]):
    DS.appendLinked((X[r]/maxx),(Y[r]/maxy))

t.trainOnDataset(DS, 200)

plt.plot(X,Y,'.b')
X=[[i] for i in arange(0,100,0.1)]
Y=map(lambda x: n.activate(array(x)/maxx)*maxy,X)
plt.plot(X,Y,'-g')
João Abrantes asked Sep 30 '22 05:09



1 Answer

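The basic PyBrain squashing neurons have bounded outputs (the tanh layer used here stays between -1 and 1; a sigmoid layer stays between 0 and 1), and they saturate when fed large values. Divide your Y by 300 (roughly the maximum possible value; with the added noise it is just under 305), and you'll get better results.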
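
To illustrate why values on the original scale hurt training, here is a small sketch in plain Python (no PyBrain needed). It assumes a single tanh unit with a hypothetical weight of 0.5, and the helper name `tanh_grad` is made up for this example. It is only meant to show the saturation effect, not what PyBrain does internally.

```python
import math

def tanh_grad(w, x):
    """Derivative of tanh(w * x) with respect to the weight w."""
    a = math.tanh(w * x)
    return (1.0 - a * a) * x

# Hypothetical weight of 0.5, chosen only for illustration.
raw_grad = tanh_grad(0.5, 100.0)    # input on the original 0..100 scale
scaled_grad = tanh_grad(0.5, 1.0)   # the same input divided by 100

print(raw_grad, scaled_grad)
```

With the raw input the unit is saturated, so the gradient is effectively zero and backpropagation has nothing to work with; with the scaled input the gradient stays at a usable size.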

More generally, find the maximum Y for your dataset, and scale everything by that.
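
A minimal sketch of that scaling step, in plain Python and assuming non-negative data as in the question. The helper names `scale_by_max` and `unscale` are hypothetical, not PyBrain functions. The important part is keeping the maximum around so predictions can be converted back to the original units.

```python
def scale_by_max(values):
    """Scale non-negative values into [0, 1]; also return the maximum used."""
    max_value = float(max(values))
    return [v / max_value for v in values], max_value

def unscale(values, max_value):
    """Map scaled values back to the original units."""
    return [v * max_value for v in values]

# Example targets on the original scale (illustrative numbers only).
ys = [3.0, 150.0, 300.0]
scaled, max_y = scale_by_max(ys)
restored = unscale(scaled, max_y)

print(scaled, restored)
```

You would train on the scaled values and apply `unscale` to whatever the network outputs before plotting it against the raw data.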

rossdavidh answered Oct 06 '22 19:10
