 

DataType string for attr 'TI' not in list of allowed values: uint8, int32, int64

I have been using a CNN for text classification, built with TensorFlow's `tf.contrib.learn` module.

However, when I try to execute the following code:

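```python
from sklearn import metrics
from tensorflow.contrib import learn

# cnn_model is my CNN model function (defined elsewhere, not shown);
# x_train/y_train and x_test/y_test are the preprocessed data splits.
classifier = learn.Estimator(model_fn=cnn_model)

# Train the estimator (this is the call that raises the error).
classifier.fit(x_train, y_train, steps=10000)

# Collect the predicted class for each test example.
y_predicted = [p['class'] for p in classifier.predict(x_test, as_iterable=True)]

score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy: {0:f}'.format(score))
```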

I am running into the following error, raised on the `classifier.fit` line:

ERROR: DataType string for attr 'TI' not in list of allowed values: uint8, int32, int64
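The message says the labels reached TensorFlow with dtype string, while the op receiving them only accepts uint8, int32 or int64 (`TI` is the index-type attribute of ops such as `tf.one_hot`, which text-classification model functions commonly apply to the labels). A minimal sketch for catching this before calling `fit`, assuming the labels are a NumPy array, list, or pandas Series; the helper name `label_dtype_problem` is hypothetical and it only inspects the NumPy dtype:

```python
import numpy as np

# Label dtypes accepted by the 'TI' attribute, as listed in the error message.
ALLOWED_LABEL_DTYPES = {np.dtype("uint8"), np.dtype("int32"), np.dtype("int64")}


def label_dtype_problem(y):
    """Return None if y's dtype is allowed, else a short description of the problem.

    Hypothetical helper for illustration only.
    """
    y = np.asarray(y)
    if y.dtype in ALLOWED_LABEL_DTYPES:
        return None
    # 'U' / 'S' are NumPy string dtypes; 'O' covers object arrays (e.g. from pandas).
    if y.dtype.kind in ("U", "S", "O"):
        return "labels are strings/objects (dtype %s); map them to integer ids" % y.dtype
    if y.dtype.kind == "f":
        return "labels are floats (dtype %s); cast them to int64" % y.dtype
    return "labels have dtype %s; cast them to int32 or int64" % y.dtype
```

Calling `label_dtype_problem(y_train)` right before `classifier.fit` turns the TensorFlow error into a message that names the actual label dtype.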

asked Dec 12 '25 21:12 by Raj


1 Answer

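You need to convert your labels `y_train` to one of the allowed integer types (uint8, int32 or int64). The error message reports `DataType string`, so the labels are most likely strings (for example class names) rather than integers; float labels would be rejected the same way. Note that `print(type(y_train))` only shows the container type; if `y_train` is a NumPy array or pandas Series, `print(y_train.dtype)` shows the element type. Encode `y_test` the same way so the predictions can be compared with it.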
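A minimal sketch of the conversion, assuming the labels are class-name strings and that every test label also occurs in the training set; the function name `encode_labels` is hypothetical. It builds one label-to-id mapping from the training labels and applies it to both splits so the ids agree:

```python
import numpy as np


def encode_labels(y_train, y_test):
    """Map class labels to int64 ids shared by the train and test splits.

    Classes are the sorted unique training labels; a test label that never
    appears in y_train raises KeyError.
    """
    classes = np.unique(np.asarray(y_train))
    # Mapping from each class label to its integer id (position in `classes`).
    index = {label: i for i, label in enumerate(classes.tolist())}
    train_ids = np.array([index[v] for v in np.asarray(y_train).tolist()], dtype=np.int64)
    test_ids = np.array([index[v] for v in np.asarray(y_test).tolist()], dtype=np.int64)
    return classes, train_ids, test_ids


classes, train_ids, test_ids = encode_labels(
    ["sports", "politics", "sports", "tech"], ["tech", "politics"]
)
```

With these, the question's code would call `classifier.fit(x_train, train_ids, steps=10000)` and score with `metrics.accuracy_score(test_ids, y_predicted)`, assuming `cnn_model` returns the integer class id under `'class'`; `classes[i]` maps an id back to its name. If the labels are instead whole-number floats, `y_train.astype(np.int64)` is enough.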

answered Dec 14 '25 14:12 by fabrizioM


