 

AttributeError using PyBrain splitWithProportion: object type changed?

Tags:

python

pybrain

I'm testing out PyBrain by following the basic classification tutorial here, and a different take on it with some more realistic data here. However, calling trndata._convertToOneOfMany() fails with the following error:

AttributeError: 'SupervisedDataSet' object has no attribute '_convertToOneOfMany'

The dataset is created as a classification.ClassificationDataSet object, but calling splitWithProportion seems to turn it into a supervised.SupervisedDataSet object. Being fairly new to Python, I'm not too surprised by the error itself: supervised.SupervisedDataSet doesn't have that method, while classification.ClassificationDataSet does. The relevant code is below.
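A toy model may help explain the symptom. The sketch below uses two hypothetical classes (BaseDataSet and LabeledDataSet, which are not part of PyBrain) and assumes that the split method builds its two outputs from the base class rather than from the object's own class. Under that assumption, the subclass-only method disappears after splitting, just as _convertToOneOfMany does here.

```python
# Toy model (hypothetical classes, NOT PyBrain's source) of a split method
# that always builds instances of the base class, so subclass methods vanish.
import random


class BaseDataSet:
    """Stand-in for a generic supervised dataset."""

    def __init__(self):
        self.samples = []

    def addSample(self, inp, target):
        self.samples.append((inp, target))

    def splitWithProportion(self, proportion=0.5):
        # Shuffle the indices, then send the first `proportion` of them left.
        indices = list(range(len(self.samples)))
        random.shuffle(indices)
        cut = int(len(indices) * proportion)
        # The base class is hard-coded here, so a subclass is not preserved.
        left, right = BaseDataSet(), BaseDataSet()
        for i in indices[:cut]:
            left.addSample(*self.samples[i])
        for i in indices[cut:]:
            right.addSample(*self.samples[i])
        return left, right


class LabeledDataSet(BaseDataSet):
    """Stand-in for a classification dataset with one extra method."""

    def convertToOneHot(self):
        return "converted"


ds = LabeledDataSet()
for k in range(8):
    ds.addSample([k], [k % 2])

tst, trn = ds.splitWithProportion(0.25)
print(type(ds).__name__)                # LabeledDataSet
print(type(trn).__name__)               # BaseDataSet
print(hasattr(trn, "convertToOneHot"))  # False
```

In this toy version, constructing the outputs with type(self)() instead of BaseDataSet() would keep the subclass and its extra method.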

However, the exact same code is used across so many tutorials, and plenty of other people have it working, so I feel I must be missing something. I've looked at the changes to the codebase on GitHub and found nothing affecting this function, and I've also tried running under both Python 3 and 2.7 with no difference. Any pointers to get me back on the right path would be very much appreciated.

import numpy as np
from pybrain.datasets import ClassificationDataSet

# X holds 400 images of 64x64 pixels, y their class labels (40 classes)
# flatten each 64x64 image into a one-dimensional vector of 4096 values
ds = ClassificationDataSet(4096, 1, nb_classes=40)
for k in range(len(X)):  # len(X) is 400
    # add a new sample consisting of input and target
    ds.addSample(np.ravel(X[k]), y[k])

print(type(ds))
tstdata, trndata = ds.splitWithProportion(0.25)
print(type(trndata))

trndata._convertToOneOfMany()
tstdata._convertToOneOfMany()
asked Jan 11 '15 at 14:01 by iammarkhammond



1 Answer

I had the same problem. I added the following code to make it work on my machine.

from pybrain.datasets import ClassificationDataSet

# alldata is a ClassificationDataSet with 2 inputs and 3 classes
tstdata_temp, trndata_temp = alldata.splitWithProportion(0.25)

# copy each split back into a ClassificationDataSet, sample by sample
tstdata = ClassificationDataSet(2, 1, nb_classes=3)
for n in range(tstdata_temp.getLength()):
    inp, target = tstdata_temp.getSample(n)
    tstdata.addSample(inp, target)

trndata = ClassificationDataSet(2, 1, nb_classes=3)
for n in range(trndata_temp.getLength()):
    inp, target = trndata_temp.getSample(n)
    trndata.addSample(inp, target)

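This copies the samples of each split into a new ClassificationDataSet, so tstdata and trndata are ClassificationDataSet objects again and _convertToOneOfMany() can be called on them.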
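The answer's two loops do the same thing, so they could be folded into one function. This is only a sketch and not part of PyBrain: rebuild and the _ListDataSet stand-in are hypothetical names, and the helper assumes nothing beyond the getLength(), getSample() and addSample() calls the answer itself uses. The stand-in class exists only so the example runs without PyBrain installed.

```python
# Hypothetical helper generalizing the answer's copy loop. It relies only on
# getLength()/getSample() on the source and addSample() on the new dataset;
# `make_dataset` is any zero-argument factory that returns an empty dataset.
def rebuild(source, make_dataset):
    """Copy every (input, target) pair from `source` into a fresh dataset."""
    new_ds = make_dataset()
    for n in range(source.getLength()):
        inp, target = source.getSample(n)
        new_ds.addSample(inp, target)
    return new_ds


class _ListDataSet:
    """Minimal stand-in exposing the same three methods, for demonstration."""

    def __init__(self):
        self.rows = []

    def addSample(self, inp, target):
        self.rows.append((inp, target))

    def getLength(self):
        return len(self.rows)

    def getSample(self, n):
        return self.rows[n]


src = _ListDataSet()
src.addSample([0.1, 0.2], [1])
src.addSample([0.3, 0.4], [0])
copied = rebuild(src, _ListDataSet)
print(type(copied).__name__, copied.getLength())  # _ListDataSet 2
```

With PyBrain, the factory would be something like lambda: ClassificationDataSet(2, 1, nb_classes=3), matching the answer's dimensions; for the question's data it would be ClassificationDataSet(4096, 1, nb_classes=40).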

answered Sep 20 '22 at 03:09 by Muhammed Miah
