 

Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

I am having an issue when I try to build the confusion matrix for my CNN model. When I run the code, it returns the following error:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

Traceback (most recent call last):

  File "<ipython-input-102-82d46efe536a>", line 1, in <module>
    print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

  File "G:\anaconda_installation_file\lib\site-packages\sklearn\metrics\classification.py", line 1543, in classification_report
    "parameter".format(len(labels), len(target_names))

ValueError: Number of classes, 4, does not match size of target_names, 6. Try specifying the labels parameter

I have already searched for a solution to this problem but have not found one that works. I am totally new to this field; can anyone help me out? Thanks.

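import itertools

import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

# Per-class scores from the model, one row per test sample
Y_pred = model.predict(X_test)
print(Y_pred)
# Convert the scores to predicted class indices
y_pred = np.argmax(Y_pred, axis=1)
print(y_pred)

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)', 'class 3(paper)', 'class 4(plastic)', 'class 5(trash)']

# y_test is one-hot encoded, so convert it to class indices as well
print(classification_report(np.argmax(y_test, axis=1), y_pred, target_names=target_names))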
Shahinur Shakib asked Sep 13 '19 at 01:09



2 Answers

You should have framed your question better! I am making some assumptions!
The problem is that:

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)','class 3(paper)', 'class 4(plastic)','class 5(trash)']

has 6 classes, while only 4 distinct classes appear in your true and predicted labels. classification_report infers the class list from those labels, finds 4, and raises the error because that does not match the 6 names (the confusion matrix should be 6x6, not 4x4).
To correct this, supply the class labels as well. For example, with 3 labels (in the predictor variable), namely 1, 2, 3:

print(classification_report(y_true, y_pred, labels=[1, 2, 3]))

See the documentation here: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
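Applied to the six classes from the question, that could look roughly like the sketch below. The y_true and y_pred values are made-up toy data in which only classes 0 to 3 occur, and zero_division is an assumption about your scikit-learn version (it needs 0.22 or newer; drop it on older versions and expect warnings instead).

```python
from sklearn.metrics import classification_report

# Toy labels, invented for illustration: only classes 0-3 occur, as in the question.
y_true = [0, 1, 2, 3, 0, 1, 2, 3]
y_pred = [0, 1, 2, 3, 1, 1, 2, 0]

target_names = ['class 0(cardboard)', 'class 1(glass)', 'class 2(metal)',
                'class 3(paper)', 'class 4(plastic)', 'class 5(trash)']

# One integer label per entry in target_names, so the two lengths always agree.
labels = list(range(len(target_names)))

report = classification_report(
    y_true,
    y_pred,
    labels=labels,
    target_names=target_names,
    zero_division=0,  # classes 4 and 5 have no samples; report 0 instead of warning
)
print(report)
```

Classes 4 and 5 then show up in the report with a support of 0, which makes it obvious that they are missing from the test data and the predictions.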

PS:

  1. Your model is performing poorly.

  2. Your dataset may have a class imbalance problem.

aryan chaudhary answered Nov 08 '22 at 03:11



The problem is that you have 6 label names: 'class 0(cardboard)', 'class 1(glass)', 'class 2(metal)', 'class 3(paper)', 'class 4(plastic)', 'class 5(trash)'

but only 4 classes in your confusion matrix. If you run print(y_pred), you will see only the values 0, 1, 2, 3, and if you run print(np.argmax(y_test, axis=1)), you will likewise see only 0, 1, 2, 3. It should help to remove:

print(classification_report(np.argmax(y_test,axis=1), y_pred,target_names=target_names))

from your code. Somehow, you do not have 6 prediction/test classes.
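To check which classes are actually present, something like the following sketch may help. The y_test and Y_pred arrays here are hypothetical stand-ins built with np.eye, not the asker's data; substitute your own one-hot test labels and model output.

```python
import numpy as np

n_classes = 6

# Hypothetical stand-ins: one-hot test labels and one-hot "scores" for 6 samples.
y_test = np.eye(n_classes)[[0, 1, 2, 3, 0, 1]]
Y_pred = np.eye(n_classes)[[0, 1, 2, 3, 1, 1]]

# Convert both to class indices, as in the question.
y_true = np.argmax(y_test, axis=1)
y_pred = np.argmax(Y_pred, axis=1)

# Classes that occur in either array, and classes that never occur.
present = np.union1d(y_true, y_pred)
missing = np.setdiff1d(np.arange(n_classes), present)

print("classes present:", present)
print("classes missing:", missing)
```

If the missing list is not empty, those classes never occur in the test labels or the predictions, which is the situation that triggers the error.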

Here is also an example of how to plot a confusion matrix: "How can I plot a confusion matrix?"

PV8 answered Nov 08 '22 at 01:11
