
Quadratic Discriminant Analysis (QDA) plot in R

Tags: r, ggplot2

I am trying to plot the results of a Quadratic Discriminant Analysis (QDA) of the iris dataset using the MASS and ggplot2 packages. The first part of the script below plots the Linear Discriminant Analysis (LDA), but I do not know how to do the same for the QDA. Objects of class "qda" are a bit different from "lda" objects. For example, I cannot find the proportion of trace (the percentage of between-group variance explained by each discriminant component), so I cannot add it to the graph axes. Any help or ideas on how to code this graph using ggplot2?

Code:

require(MASS)
require(ggplot2)
require(scales)
 

irislda <- lda(Species ~ ., iris)
prop.lda = irislda$svd^2/sum(irislda$svd^2)
plda <- predict(irislda,   iris)

datasetLDA = data.frame(species = iris[,"Species"], irislda = plda$x)
ggplot(datasetLDA) + geom_point(aes(irislda.LD1, irislda.LD2, colour = species, shape = species), size = 2.5) + 
    labs(x = paste("LD1 (", percent(prop.lda[1]), ")", sep=""),
       y = paste("LD2 (", percent(prop.lda[2]), ")", sep=""))

 
irisqda <- qda(Species ~ ., iris)
pqda <- predict(irisqda,   iris)
datasetQDA = data.frame(species = iris[,"Species"], irisqda = pqda$posterior) 
ggplot(datasetQDA) + geom_point(aes(???, ???, colour = species, shape = species), size = 2.5)
JohnSal asked Dec 18 '22 12:12


1 Answer

Following Duck's comment, if you have only 2 dimensions you can use the decisionplot function provided in the link to visualize the decision boundaries. It has to be altered slightly for more variables.

library(MASS)
model <- qda(Species ~ Sepal.Length + Sepal.Width, iris)
decisionplot(model, iris, class = "Species")

base plot

The decisionplot function is shown below.

decisionplot <- function(model, data, class = NULL, predict_type = "class",
  resolution = 100, showgrid = TRUE, ...) {

  if(!is.null(class)) cl <- data[,class] else cl <- 1
  data <- data[,1:2]
  k <- length(unique(cl))

  plot(data, col = as.integer(cl)+1L, pch = as.integer(cl)+1L, ...)

  # make grid
  r <- sapply(data, range, na.rm = TRUE)
  xs <- seq(r[1,1], r[2,1], length.out = resolution)
  ys <- seq(r[1,2], r[2,2], length.out = resolution)
  g <- cbind(rep(xs, each=resolution), rep(ys, times = resolution))
  colnames(g) <- colnames(r)
  g <- as.data.frame(g)

  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g, type = predict_type)
  if(is.list(p)) p <- p$class
  p <- as.factor(p)

  if(showgrid) points(g, col = as.integer(p)+1L, pch = ".")

  z <- matrix(as.integer(p), nrow = resolution, byrow = TRUE)
  contour(xs, ys, z, add = TRUE, drawlabels = FALSE,
    lwd = 2, levels = (1:(k-1))+.5)

  invisible(z)
}
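
The `is.list(p)` check in the function is needed because `predict()` returns different structures for different model classes. As a minimal sketch, assuming the same two-predictor model as above and a small hand-picked grid (the grid values are arbitrary and only for illustration), this is what a "qda" fit returns:

```r
library(MASS)

# Fit the same two-predictor QDA model used above.
model <- qda(Species ~ Sepal.Length + Sepal.Width, iris)

# A tiny illustrative grid; the values are arbitrary points inside the data range.
grid <- expand.grid(Sepal.Length = c(5, 6, 7), Sepal.Width = c(2.5, 3.5))

p <- predict(model, grid)
is.list(p)        # TRUE: predict() on a qda fit returns a list, not a vector
names(p)          # "class" "posterior"
dim(p$posterior)  # 6 3: one row per grid point, one column per species
```

The predicted labels are in `p$class`, which is why the function extracts that element before converting to a factor.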

If we wanted to recreate this with ggplot2, we'd simply have to change the function to use ggplot2 functions rather than base plotting. This entails converting the data to data frames and building the plot along the way.

decisionplot_ggplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE, ...) {
  
  if(!is.null(class)) cl <- data[,class] else cl <- 1
  data <- data[,1:2]
  cn <- colnames(data)
  
  k <- length(unique(cl))
  
  data$pch <- data$col <- as.integer(cl) + 1L
  gg <- ggplot(aes_string(cn[1], cn[2]), data = data) + 
    geom_point(aes_string(col = 'as.factor(col)', shape = 'as.factor(col)'), size = 3)
  
  # make grid
  r <- sapply(data[, 1:2], range, na.rm = TRUE)
  xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
  ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
  
  g <- cbind(rep(xs, each = resolution), 
             rep(ys, times = resolution))
  colnames(g) <- colnames(r)
  
  g <- as.data.frame(g)
  
  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g, type = predict_type)
  if(is.list(p)) p <- p$class
  g$col <- g$pch <- as.integer(as.factor(p)) + 1L
  
  if(showgrid) 
    gg <- gg + geom_point(aes_string(x = cn[1], y = cn[2], col = 'as.factor(col)'), data = g, shape = 20, size = 1)
  
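  # draw contours only at the class boundaries (col takes the values 2..k+1), as in the base version
  gg + geom_contour(aes_string(x = cn[1], y = cn[2], z = 'col'), data = g, inherit.aes = FALSE,
                    breaks = (2:k) + 0.5)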
}

Usage:

decisionplot_ggplot(model, iris, class = "Species")

Note that it now returns the ggplot object itself, so one can use standard ggplot2 functions to change the title, theme, etc. Also, this is simply a direct translation. Using geom_polygon with a suitable alpha would likely be more visually pleasing. Similarly, better contours could be made with an alternative choice of geom_*.

ggplot
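
As one possible take on the "alternative geom" idea, here is a minimal sketch that shades the predicted regions with a semi-transparent `geom_raster` layer (swapped in here for the `geom_polygon` suggestion, since a regular grid of predictions maps directly onto raster cells). The grid resolution of 200, the alpha of 0.3, the title text, and the theme are assumptions chosen for illustration, not part of the original function.

```r
library(MASS)
library(ggplot2)

model <- qda(Species ~ Sepal.Length + Sepal.Width, iris)

# Evenly spaced prediction grid spanning the observed range of both predictors.
grid <- expand.grid(
  Sepal.Length = seq(min(iris$Sepal.Length), max(iris$Sepal.Length), length.out = 200),
  Sepal.Width  = seq(min(iris$Sepal.Width),  max(iris$Sepal.Width),  length.out = 200)
)

# predict() on a qda fit returns a list; the labels are in $class.
grid$Species <- predict(model, grid)$class

p <- ggplot(iris, aes(Sepal.Length, Sepal.Width)) +
  # shaded decision regions, one fill colour per predicted class
  geom_raster(data = grid, aes(fill = Species), alpha = 0.3) +
  # observed data on top
  geom_point(aes(colour = Species, shape = Species), size = 2.5) +
  labs(title = "QDA decision regions (iris, two predictors)") +
  theme_minimal()
p
```

Because the result is an ordinary ggplot object, further layers, scales, or themes can be added with `+` in the usual way.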

Oliver answered Jan 13 '23 19:01