
R: how to visualize a confusion matrix using the caret package

I'd like to visualize the data I've put in a confusion matrix. Is there a function that I could simply pass the confusion matrix to and have it plot it?

Here is an example of what I'd like to do (Matrix$nnet is simply a table containing the results of the classification):

    Confusion$nnet <- confusionMatrix(Matrix$nnet)
    plot(Confusion$nnet)
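For a two-class problem, one function that already takes a 2 x 2 table directly is fourfoldplot() from base R's graphics package. This is a different technique from the rect()-based answer on this page, shown here only as a hedged sketch. It rebuilds the counts 42, 6, 8 and 28 from the table in this question by hand; with caret you would pass Confusion$nnet$table instead. The colours simply reuse the ones from the answer.

```r
# Counts from the table in this question (rows = prediction, columns = reference).
# With caret you would pass Confusion$nnet$table instead of building tab by hand.
tab <- as.table(matrix(c(42, 8, 6, 28), nrow = 2,
                       dimnames = list(prediction = c("1", "2"),
                                       reference = c("1", "2"))))

# fourfoldplot() draws one quarter circle per cell; its area reflects the cell
# count after the default standardisation by the margins.
# color: first entry for the smaller diagonal, second for the larger one.
# conf.level = 0 turns off the confidence rings for the odds ratio.
fourfoldplot(tab, color = c("#F7AD50", "#3F97D0"), conf.level = 0,
             main = "Confusion matrix")
```

fourfoldplot() only handles 2 x 2 (or 2 x 2 x k) tables. With more classes, a rect()-based drawing like the one in the answer is the more flexible route.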

My Confusion$nnet$table looks like this:

    prediction
         1  2
      1 42  6
      2  8 28

(I would also like to get rid of the "prediction" label; any help?)
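The "prediction" header is the name of the table's first dimension. This sketch assumes Confusion$nnet$table is an ordinary R table (it rebuilds one from the counts above, so tab is a stand-in). Setting the names of its dimnames to NULL drops that header but keeps the class labels.

```r
# Stand-in for Confusion$nnet$table, rebuilt from the counts shown above
tab <- as.table(matrix(c(42, 8, 6, 28), nrow = 2,
                       dimnames = list(prediction = c("1", "2"), c("1", "2"))))
print(tab)                    # printed with the "prediction" header

names(dimnames(tab)) <- NULL  # drop the dimension names, keep the class labels
print(tab)                    # same counts, no header
```

If the name comes from a call such as table(prediction = ...), leaving the argument unnamed when building the table avoids the header in the first place.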
asked May 27 '14 13:05 by shish





1 Answer

You can use the rect() function in base R to lay out the confusion matrix. Here we create a function that takes the cm object produced by caret's confusionMatrix() and draws the visual.

Let's start by creating an evaluation dataset as done in the caret demo:

    # construct the evaluation dataset
    set.seed(144)
    true_class <- factor(sample(paste0("Class", 1:2), size = 1000, prob = c(.2, .8), replace = TRUE))
    true_class <- sort(true_class)
    class1_probs <- rbeta(sum(true_class == "Class1"), 4, 1)
    class2_probs <- rbeta(sum(true_class == "Class2"), 1, 2.5)
    test_set <- data.frame(obs = true_class, Class1 = c(class1_probs, class2_probs))
    test_set$Class2 <- 1 - test_set$Class1
    test_set$pred <- factor(ifelse(test_set$Class1 >= .5, "Class1", "Class2"))

Now let's use caret to calculate the confusion matrix:

    library(caret)  # provides confusionMatrix()

    # calculate the confusion matrix
    cm <- confusionMatrix(data = test_set$pred, reference = test_set$obs)
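The drawing function in this answer reads Sensitivity, Specificity, Precision, Recall and F1 from cm$byClass and Accuracy and Kappa from cm$overall. To make those numbers less opaque, here is a base-R sketch that computes them by hand from a 2 x 2 table laid out like cm$table (rows = prediction, columns = reference). It treats the first level as the positive class, which is caret's default for two levels. It uses the counts from the question and applies the textbook definitions. It is not caret's own code, so for edge cases (the positive argument, empty cells) check ?confusionMatrix.

```r
# Counts from the question; rows = prediction, columns = reference.
# The first level ("1") is treated as the positive class.
tab <- matrix(c(42, 8, 6, 28), nrow = 2,
              dimnames = list(Prediction = c("1", "2"), Reference = c("1", "2")))

tp <- tab[1, 1]  # predicted 1, actually 1
fn <- tab[2, 1]  # predicted 2, actually 1
fp <- tab[1, 2]  # predicted 1, actually 2
tn <- tab[2, 2]  # predicted 2, actually 2
n  <- sum(tab)

accuracy <- (tp + tn) / n
# Cohen's kappa: observed agreement vs. agreement expected from the margins
p_e <- sum(rowSums(tab) * colSums(tab)) / n^2
kappa_stat <- (accuracy - p_e) / (1 - p_e)

precision <- tp / (tp + fp)
recall    <- tp / (tp + fn)  # identical to sensitivity

metrics <- c(Sensitivity = recall,
             Specificity = tn / (tn + fp),
             Precision   = precision,
             Recall      = recall,
             F1          = 2 * precision * recall / (precision + recall),
             Accuracy    = accuracy,
             Kappa       = kappa_stat)

# Sensitivity 0.84, Specificity 0.824, Precision 0.875, Recall 0.84,
# F1 0.857, Accuracy 0.833, Kappa 0.657
print(round(metrics, 3))
```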

Now we create a function that lays out the rectangles as needed to showcase the confusion matrix in a more visually appealing fashion:

    draw_confusion_matrix <- function(cm) {

      layout(matrix(c(1,1,2)))
      par(mar=c(2,2,2,2))
      plot(c(100, 345), c(300, 450), type = "n", xlab="", ylab="", xaxt='n', yaxt='n')
      title('CONFUSION MATRIX', cex.main=2)

      # create the matrix
      rect(150, 430, 240, 370, col='#3F97D0')
      text(195, 435, 'Class1', cex=1.2)
      rect(250, 430, 340, 370, col='#F7AD50')
      text(295, 435, 'Class2', cex=1.2)
      text(125, 370, 'Predicted', cex=1.3, srt=90, font=2)
      text(245, 450, 'Actual', cex=1.3, font=2)
      rect(150, 305, 240, 365, col='#F7AD50')
      rect(250, 305, 340, 365, col='#3F97D0')
      text(140, 400, 'Class1', cex=1.2, srt=90)
      text(140, 335, 'Class2', cex=1.2, srt=90)

      # add in the cm results
      res <- as.numeric(cm$table)
      text(195, 400, res[1], cex=1.6, font=2, col='white')
      text(195, 335, res[2], cex=1.6, font=2, col='white')
      text(295, 400, res[3], cex=1.6, font=2, col='white')
      text(295, 335, res[4], cex=1.6, font=2, col='white')

      # add in the specifics
      plot(c(100, 0), c(100, 0), type = "n", xlab="", ylab="", main = "DETAILS", xaxt='n', yaxt='n')
      text(10, 85, names(cm$byClass[1]), cex=1.2, font=2)
      text(10, 70, round(as.numeric(cm$byClass[1]), 3), cex=1.2)
      text(30, 85, names(cm$byClass[2]), cex=1.2, font=2)
      text(30, 70, round(as.numeric(cm$byClass[2]), 3), cex=1.2)
      text(50, 85, names(cm$byClass[5]), cex=1.2, font=2)
      text(50, 70, round(as.numeric(cm$byClass[5]), 3), cex=1.2)
      text(70, 85, names(cm$byClass[6]), cex=1.2, font=2)
      text(70, 70, round(as.numeric(cm$byClass[6]), 3), cex=1.2)
      text(90, 85, names(cm$byClass[7]), cex=1.2, font=2)
      text(90, 70, round(as.numeric(cm$byClass[7]), 3), cex=1.2)

      # add in the accuracy information
      text(30, 35, names(cm$overall[1]), cex=1.5, font=2)
      text(30, 20, round(as.numeric(cm$overall[1]), 3), cex=1.4)
      text(70, 35, names(cm$overall[2]), cex=1.5, font=2)
      text(70, 20, round(as.numeric(cm$overall[2]), 3), cex=1.4)
    }
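The function above hard-codes the 'Class1'/'Class2' labels and places res[1] to res[4] by position. as.numeric() flattens a table column by column, so res[2] is the (predicted Class2, actual Class1) cell, which is why it is drawn in the Class1 column, Class2 row. As a hedged variation, here is a sketch of a hypothetical helper, draw_cm_grid() (not part of caret). It uses the same rect()/text() idea but loops over rows and columns and takes the labels from the table's dimnames, so the same code also covers more than two classes. It assumes caret's layout of rows = prediction, columns = reference, and it omits the DETAILS panel.

```r
# Hypothetical helper (not part of caret): draws any k x k confusion table,
# rows = predicted class, columns = actual class, with diagonal cells highlighted.
draw_cm_grid <- function(tab, diag_col = "#3F97D0", off_col = "#F7AD50") {
  k <- nrow(tab)
  plot(c(0, k), c(0, k), type = "n", xlab = "Actual", ylab = "Predicted",
       xaxt = "n", yaxt = "n", asp = 1, main = "CONFUSION MATRIX")
  for (i in seq_len(k)) {        # i indexes rows (predicted class)
    for (j in seq_len(k)) {      # j indexes columns (actual class)
      x0 <- j - 1
      y0 <- k - i                # row 1 is drawn at the top
      rect(x0, y0, x0 + 1, y0 + 1, col = if (i == j) diag_col else off_col)
      text(x0 + 0.5, y0 + 0.5, tab[i, j], cex = 1.6, font = 2, col = "white")
    }
  }
  # class labels taken from the table itself instead of being hard-coded
  axis(1, at = seq_len(k) - 0.5, labels = colnames(tab), tick = FALSE)
  axis(2, at = rev(seq_len(k)) - 0.5, labels = rownames(tab), tick = FALSE)
  invisible(tab)
}

# Usage with the counts from the question; with caret, pass cm$table instead
tab <- as.table(matrix(c(42, 8, 6, 28), nrow = 2,
                       dimnames = list(Prediction = c("1", "2"),
                                       Reference = c("1", "2"))))
pdf(tempfile(fileext = ".pdf"))  # temporary file; drop this line to draw on screen
draw_cm_grid(tab)
dev.off()
```

Looping over [i, j] instead of flattening with as.numeric() keeps each count tied to its row and column, so the cell placement no longer depends on remembering R's column-major order.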

Finally, pass in the cm object that we calculated when using caret to create the confusion matrix:

    draw_confusion_matrix(cm)

And here are the results:

[Image: visualization of the confusion matrix from the caret package]

answered Sep 22 '22 09:09 by Cybernetic
