 

How to change plot_confusion_matrix default figure size in sklearn.metrics package

I tried to plot a confusion matrix in a Jupyter notebook using the sklearn.metrics.plot_confusion_matrix function, but the default figure size is a little small. I added plt.figure(figsize=(20, 20)) before plotting, but the figure size did not change, and the output showed the text 'Figure size 1440x1440 with 0 Axes'. How can I change the figure size?

%matplotlib inline
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import plot_confusion_matrix
from matplotlib import pyplot as plt

plt.figure(figsize=(20, 20))
clf = GradientBoostingClassifier(random_state=42)
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test, cmap=plt.cm.Blues)
plt.title('Confusion matrix')
plt.show()


asked Apr 20 '20 at 14:04 by Romaker



People also ask

What is the size of confusion matrix?

The confusion matrix is an N x N matrix, where N is the number of classes (outputs). For 2 classes, we get a 2 x 2 confusion matrix.
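A minimal sketch of the N x N rule using sklearn.metrics.confusion_matrix; the label lists are made-up toy data. Rows are the true classes and columns are the predicted classes, ordered by sorted label value.

```python
from sklearn.metrics import confusion_matrix

# Three classes (0, 1, 2) appear, so the matrix is 3 x 3
y_true = [0, 1, 2, 2]
y_pred = [0, 2, 2, 1]
cm = confusion_matrix(y_true, y_pred)
print(cm.shape)  # (3, 3)
print(cm)        # rows = true class, columns = predicted class
# [[1 0 0]
#  [0 0 1]
#  [0 1 1]]

# Two classes give a 2 x 2 matrix
cm2 = confusion_matrix([0, 1, 1, 0], [0, 1, 0, 0])
print(cm2.shape)  # (2, 2)
```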

How do I change font size in confusion matrix?

Use sns.set to change the font size of the heatmap values. You can specify the font size of the labels and the title as a dictionary in ax.set_xlabel, ax. …


2 Answers

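I don't know why BigBen posted this as a comment rather than an answer, but I almost missed it. Here it is as an answer, so future readers don't make the same mistake I almost made! Create the axes at the size you want and pass them in with ax=; plot_confusion_matrix then draws into that figure. Without ax=, it creates its own new, default-sized figure, which is why plt.figure(figsize=...) only left behind an empty 'Figure ... with 0 Axes'.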

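import matplotlib.pyplot as plt
from sklearn.metrics import plot_confusion_matrix
fig, ax = plt.subplots(figsize=(10, 10))  # figure size in inches
plot_confusion_matrix(your_model, X_test, y_test, ax=ax)  # draw into that axes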
answered Oct 16 '22 at 16:10 by Hovanes Gasparian



I use set_figwidth and set_figheight on the figure the display has drawn into to set the figure size (in inches):

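from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Build the display directly from true and predicted labels
disp = ConfusionMatrixDisplay.from_predictions(
    [0, 1, 1, 0, 1],                 # y_true
    [0, 1, 0, 1, 0],                 # y_pred
    labels=[1, 0],                   # order of the classes on the axes
    cmap=plt.cm.Blues,
    display_labels=['Good', 'Bad'],  # tick labels for classes 1 and 0
    values_format='d',               # show counts as plain integers
)
# Resize the figure that holds the display's axes (inches)
fig = disp.ax_.get_figure()
fig.set_figwidth(3)
fig.set_figheight(3)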
answered Oct 16 '22 at 15:10 by dom free
