
How to plot confusion matrix for prefetched dataset in Tensorflow

I tried to plot a confusion matrix for my image classifier with the following code, but I got this error message: 'PrefetchDataset' object has no attribute 'classes'

Y_pred = model.predict(validation_dataset)
y_pred = np.argmax(Y_pred, axis=1)

print('Confusion Matrix')
print(confusion_matrix(validation_dataset.classes, y_pred)) # ERROR message generated

AttributeError: 'PrefetchDataset' object has no attribute 'classes'

zZzZ asked Oct 31 '20 13:10




1 Answer

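This code also works with a shuffled tf.data.Dataset, because the true labels and the predictions are collected from the same batches in a single pass over the dataset: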

import numpy as np
import tensorflow as tf

y_pred = []  # store predicted labels
y_true = []  # store true labels

# iterate over the dataset once, batch by batch (here, validation_dataset from the question);
# the dataset must be finite: if it was built with repeat(), limit it first, e.g. with take()
for image_batch, label_batch in dataset:
    # append true labels
    y_true.append(label_batch)
    # compute predictions
    preds = model.predict(image_batch)
    # append predicted labels
    y_pred.append(np.argmax(preds, axis=-1))

# convert the true and predicted labels into tensors
correct_labels = tf.concat(y_true, axis=0)
predicted_labels = tf.concat(y_pred, axis=0)
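
The answer stops once the labels are collected, so here is a minimal sketch of the remaining step: turning true and predicted labels into a confusion matrix. It uses only NumPy so that it runs without TensorFlow. The function name confusion_matrix_from_batches, the toy batches, and the identity "model" are all hypothetical stand-ins, not part of the original answer. It also assumes integer class labels, not one-hot vectors.

```python
import numpy as np

def confusion_matrix_from_batches(batches, predict_fn, num_classes):
    """Accumulate a confusion matrix over (inputs, labels) batches in one pass."""
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    for inputs, labels in batches:
        probs = predict_fn(inputs)             # shape: (batch, num_classes)
        predicted = np.argmax(probs, axis=-1)  # predicted class index per sample
        # rows index the true class, columns index the predicted class
        np.add.at(matrix, (np.asarray(labels), predicted), 1)
    return matrix

# Toy stand-ins (hypothetical): each "image" is already a row of class scores,
# and the "model" simply returns its input unchanged.
batches = [
    (np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([0, 1])),
    (np.array([[0.6, 0.4], [0.3, 0.7]]), np.array([1, 1])),
]
cm = confusion_matrix_from_batches(batches, predict_fn=lambda x: x, num_classes=2)
print(cm)
```

With the tensors built in the answer, the same result should be obtainable from tf.math.confusion_matrix(correct_labels, predicted_labels), or from sklearn.metrics.confusion_matrix after converting both tensors with .numpy().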
Ankit Kumar Saini answered Sep 24 '22 01:09