 

Creating geom / stat from scratch

I started working with R not long ago, and I am currently trying to strengthen my visualization skills. What I want to do is create boxplots with mean diamonds as a layer on top (see the picture in the link below). I did not find any existing function that does this, so I guess I have to create it myself.

Link: Boxplot and mean diamonds

What I was hoping to do was to create a geom or a stat that would allow something like this to work:

ggplot(data, aes(...)) + 
   geom_boxplot(...) +
   geom_meanDiamonds(...)

I have no idea where to start in order to build this new function. I know which values are needed for the mean diamonds (mean and confidence interval), but I do not know how to build the geom / stat that takes the data from ggplot(), calculates the mean and CI for each group, and plots a mean diamond on top of each boxplot.

I have searched for detailed descriptions of how to build this type of function from scratch, but I have not found anything that really starts from the bottom. I would really appreciate it if anyone could point me towards some useful guides.

Thank you!

asked Sep 27 '18 14:09 by MagKvis



1 Answer

I'm currently learning to write geoms myself, so this is going to be a rather long & rambling post as I go through my thought processes, untangling the Geom aspects (creating polygons & line segments) from the Stats aspects (calculating where these polygons & segments should be) of a geom.

Disclaimer: I'm not familiar with this kind of plot, and Google didn't throw up many authoritative guides. My understanding of how the confidence interval is calculated / used here may be off.

Step 0. Understand the relationship between a geom / stat and a layer function.

geom_boxplot and stat_boxplot are examples of layer functions. If you enter them into the R console, you'll see that they are (relatively) short and do not contain the actual code for calculating the box / whiskers of the boxplot. Instead, geom_boxplot contains a line that says geom = GeomBoxplot, while stat_boxplot contains a line that says stat = StatBoxplot (reproduced below).

> stat_boxplot
function (mapping = NULL, data = NULL, geom = "boxplot", position = "dodge2", 
    ..., coef = 1.5, na.rm = FALSE, show.legend = NA, inherit.aes = TRUE) 
{
    layer(data = data, mapping = mapping, stat = StatBoxplot, 
        geom = geom, position = position, show.legend = show.legend, 
        inherit.aes = inherit.aes, params = list(na.rm = na.rm, 
            coef = coef, ...))
}

GeomBoxplot and StatBoxplot are ggproto objects. They are where the magic happens.

Step 1. Recognise that ggproto()'s _inherit parameter is your friend.

Don't reinvent the wheel. Since we want to create something that overlaps nicely with a boxplot, we can build on the Geom / Stat used for the boxplot and change only what's necessary.

StatMeanDiamonds <- ggproto(
  `_class` = "StatMeanDiamonds",
  `_inherit` = StatBoxplot,
  ... # add functions here to override those defined in StatBoxplot
)

GeomMeanDiamonds <- ggproto(
  `_class` = "GeomMeanDiamonds",
  `_inherit` = GeomBoxplot,
  ... # as above
)
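If the idea behind `_inherit` is unfamiliar, the following base-R toy may help. It is only an analogy and not how ggproto is implemented internally: it mimics the lookup rule that matters here, namely that a member defined on the child wins and anything missing is fetched from the parent. The helpers `proto_new` / `proto_get` and the objects `ToyBase` / `ToyChild` are hypothetical names made up for this sketch.

```r
# Hypothetical helpers: a toy model of "override, else fall back to parent".
# This is an analogy for ggproto's `_inherit`, not ggproto's implementation.
proto_new <- function(parent = NULL, ...) {
  obj <- list2env(list(...), parent = emptyenv())
  assign(".parent", parent, envir = obj)   # remember where to fall back to
  obj
}

proto_get <- function(obj, name) {
  # Walk up the chain until some object defines `name` itself.
  while (!is.null(obj)) {
    if (exists(name, envir = obj, inherits = FALSE)) {
      return(get(name, envir = obj, inherits = FALSE))
    }
    obj <- get(".parent", envir = obj, inherits = FALSE)
  }
  stop("member not found: ", name)
}

ToyBase <- proto_new(
  setup_params  = function() "ToyBase$setup_params",
  compute_group = function() "ToyBase$compute_group"
)
ToyChild <- proto_new(
  parent        = ToyBase,
  compute_group = function() "ToyChild$compute_group"   # override
)

proto_get(ToyChild, "compute_group")()  # the child's own version
proto_get(ToyChild, "setup_params")()   # inherited from ToyBase
```

This is the behaviour we rely on below: StatMeanDiamonds only redefines compute_group, and everything else keeps coming from StatBoxplot.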

Step 2. Modify the Stat.

There are 3 functions defined within StatBoxplot: setup_data, setup_params, and compute_group. You can refer to the ggplot2 source code on GitHub for the details, or view them in the console by entering, for example, StatBoxplot$compute_group.

The compute_group function calculates the ymin / lower / middle / upper / ymax values for all the y values associated with each group (i.e. each unique x value), which are used to plot the box plot. We can override it with one that calculates the confidence interval & mean values instead:

# ci is added as a parameter, to allow the user to specify different confidence intervals
compute_group_new <- function(data, scales, width = NULL, 
                              ci = 0.95, na.rm = FALSE){
  a <- mean(data$y)
  s <- sd(data$y)
  n <- sum(!is.na(data$y))
  error <- qt(ci + (1-ci)/2, df = n-1) * s / sqrt(n)
  stats <- c("lower" = a - error, "mean" = a, "upper" = a + error)

  if(length(unique(data$x)) > 1) width <- diff(range(data$x)) * 0.9

  df <- as.data.frame(as.list(stats))

  df$x <- if(is.factor(data$x)) data$x[1] else mean(range(data$x))
  df$width <- width

  df
}
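The interval in compute_group_new is the usual two-sided t interval, mean ± qt(ci + (1 - ci)/2, n - 1) · s / √n. A quick way to sanity-check that formula outside ggplot2 is to compare it with `t.test()` on the built-in iris data. `mean_ci` is a hypothetical helper that repeats the same arithmetic with base R only.

```r
# Same arithmetic as compute_group_new, as a plain function of a numeric vector.
mean_ci <- function(y, ci = 0.95) {
  y <- y[!is.na(y)]
  n <- length(y)
  a <- mean(y)
  error <- qt(ci + (1 - ci) / 2, df = n - 1) * sd(y) / sqrt(n)
  c(lower = a - error, mean = a, upper = a + error)
}

# One column per species; rows are lower / mean / upper.
by_species <- sapply(split(iris$Sepal.Length, iris$Species), mean_ci)
print(round(by_species, 3))

# The interval should agree with a one-sample t-test for the same group.
tt <- t.test(iris$Sepal.Length[iris$Species == "setosa"], conf.level = 0.95)
all.equal(unname(by_species[c("lower", "upper"), "setosa"]), as.numeric(tt$conf.int))
```

Note that `ci + (1 - ci)/2` is just the upper quantile `1 - (1 - ci)/2`, so ci = 0.95 uses the 97.5% t quantile.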

(Optional) StatBoxplot lets the user include weight as an aesthetic mapping. We can allow for that as well by replacing:

  a <- mean(data$y)
  s <- sd(data$y)
  n <- sum(!is.na(data$y))

with:

  if(!is.null(data$weight)) {
    a <- Hmisc::wtd.mean(data$y, weights = data$weight)
    s <- sqrt(Hmisc::wtd.var(data$y, weights = data$weight))
    n <- sum(data$weight[!is.na(data$y) & !is.na(data$weight)])
  } else {
    a <- mean(data$y)
    s <- sd(data$y)
    n <- sum(!is.na(data$y))
  }
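If you would rather not depend on Hmisc, the weighted mean is available in base R as `stats::weighted.mean()`, and a frequency-weighted variance is short to write by hand. As far as I can tell this matches `Hmisc::wtd.var()` with its default `normwt = FALSE`, but treat that as an assumption and compare against Hmisc if it matters for your data. `wtd_var_base` is a hypothetical helper name.

```r
# Frequency-weighted variance: sum(w * (y - ybar)^2) / (sum(w) - 1).
wtd_var_base <- function(y, w) {
  ok <- !is.na(y) & !is.na(w)
  y <- y[ok]
  w <- w[ok]
  ybar <- sum(w * y) / sum(w)
  sum(w * (y - ybar)^2) / (sum(w) - 1)
}

y <- c(4.9, 5.1, 5.4, 5.0)
w <- c(1, 3, 2, 1)

a <- weighted.mean(y, w)
s <- sqrt(wtd_var_base(y, w))
n <- sum(w)   # effective n, as in the weighted branch above

# With integer weights, this equals the unweighted formulas on the expanded data.
expanded <- rep(y, times = w)
all.equal(c(a, s, n), c(mean(expanded), sd(expanded), length(expanded)))
```

The integer-weight check is a handy way to convince yourself that the weighted branch reduces to the unweighted one when every weight is 1.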

There's no need to change the other functions in StatBoxplot. So we can define StatMeanDiamonds as follows:

StatMeanDiamonds <- ggproto(
  `_class` = "StatMeanDiamonds",
  `_inherit` = StatBoxplot,
  compute_group = compute_group_new
)
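To see what compute_group actually receives, here is a rough base-R imitation of what the inherited compute_panel does: split the layer data by `group`, call compute_group once per piece, and row-bind the results. It is a simplification (the real method also carries over columns that are constant within each group), and `mock_compute_group` / `panel_df` are hypothetical stand-ins for compute_group_new and ggplot2's internal layer data.

```r
# Hypothetical stand-in for compute_group_new, minus the width bookkeeping.
mock_compute_group <- function(data, ci = 0.95) {
  a <- mean(data$y)
  n <- sum(!is.na(data$y))
  error <- qt(ci + (1 - ci) / 2, df = n - 1) * sd(data$y) / sqrt(n)
  data.frame(lower = a - error, mean = a, upper = a + error,
             x = data$x[1], group = data$group[1])
}

# Roughly the shape of data a stat sees: numeric x, y, and an integer group.
panel_df <- data.frame(
  x     = as.integer(iris$Species),
  y     = iris$Sepal.Length,
  group = as.integer(iris$Species)
)

pieces <- split(panel_df, panel_df$group)
stats  <- do.call(rbind, lapply(pieces, mock_compute_group, ci = 0.95))
stats  # one row per group: lower / mean / upper at each x
```

This is why compute_group_new only needs to return a single row: one row per group is exactly one diamond per box.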

Step 3. Modify the Geom.

GeomBoxplot has 3 functions: setup_data, draw_group, and draw_key. It also defines the default_aes and required_aes fields.

Since we've changed the upstream data source (the data produced by StatMeanDiamonds contain the calculated columns "lower" / "mean" / "upper", while the data produced by StatBoxplot would have contained the calculated columns "ymin" / "lower" / "middle" / "upper" / "ymax"), do check whether the downstream setup_data function is affected as well. (In this case, GeomBoxplot$setup_data makes no reference to the affected columns, so no changes are required here.)

The draw_group function takes the data produced by StatMeanDiamonds and set up by setup_data, and builds three data frames: "common" holds the aesthetics shared by the polygon and the segment, "diamond.df" holds the vertices of the diamond polygon, and "segment.df" holds the horizontal line segment at the mean. The last two are then passed to the draw_panel functions of GeomPolygon and GeomSegment respectively, which produce the actual polygon / line segment.

draw_group_new <- function(data, panel_params, coord,
                           varwidth = FALSE){
  common <- data.frame(colour = data$colour, 
                       size = data$size,
                       linetype = data$linetype, 
                       fill = alpha(data$fill, data$alpha),
                       group = data$group, 
                       stringsAsFactors = FALSE)
  diamond.df <- data.frame(x = c(data$x, data$xmax, data$x, data$xmin),
                           y = c(data$upper, data$mean, data$lower, data$mean),
                           alpha = data$alpha,
                           common,
                           stringsAsFactors = FALSE)
  segment.df <- data.frame(x = data$xmin, xend = data$xmax,
                           y = data$mean, yend = data$mean,
                           alpha = NA,
                           common,
                           stringsAsFactors = FALSE)
  ggplot2:::ggname("geom_meanDiamonds",
                   grid::grobTree(
                     GeomPolygon$draw_panel(diamond.df, panel_params, coord),
                     GeomSegment$draw_panel(segment.df, panel_params, coord)
                   ))
}
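The geometry in draw_group_new is easy to check outside ggplot2. The sketch below builds the same four vertices (top, right, bottom, left) plus the mean segment for one made-up group and wraps them in `grid::grobTree()`, much as the draw function does. Two assumptions: xmin / xmax are derived as x ± width/2 (what GeomBoxplot's setup_data does when varwidth is off), and plain `"native"` units inside a viewport stand in for the coordinate transformation that GeomPolygon / GeomSegment apply. `diamond_grob` and the numbers in `stat_row` are hypothetical.

```r
library(grid)

# One group's stat output, with made-up values.
stat_row <- data.frame(x = 1, width = 0.5, lower = 4.9, mean = 5.0, upper = 5.1)
stat_row$xmin <- stat_row$x - stat_row$width / 2
stat_row$xmax <- stat_row$x + stat_row$width / 2

diamond_grob <- function(d, fill = "grey80") {
  # Vertices in the same order as diamond.df: top, right, bottom, left.
  vx <- c(d$x, d$xmax, d$x, d$xmin)
  vy <- c(d$upper, d$mean, d$lower, d$mean)
  grobTree(
    polygonGrob(vx, vy, default.units = "native", gp = gpar(fill = fill)),
    segmentsGrob(d$xmin, d$mean, d$xmax, d$mean, default.units = "native"),
    vp = viewport(xscale = c(0, 2), yscale = c(4.8, 5.2)),
    name = "meanDiamond"
  )
}

g <- diamond_grob(stat_row)
# grid.newpage(); grid.draw(g)  # uncomment to draw on the current device
```

Drawing `g` should show a diamond whose left and right corners sit on the mean line, which is what the overlay on each boxplot looks like.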

The draw_key function is used to create the legend for this layer, should the need arise. Since GeomMeanDiamonds inherits from GeomBoxplot, the default is draw_key = draw_key_boxplot, and we don't have to change it. Leaving it unchanged will not break the code. However, I think a simpler legend such as draw_key_polygon offers a less cluttered look.

GeomBoxplot's default_aes specifications look fine. But we need to change the required_aes since the data we expect to get from StatMeanDiamonds is different ("lower" / "mean" / "upper" instead of "ymin" / "lower" / "middle" / "upper" / "ymax").

We are now ready to define GeomMeanDiamonds:

GeomMeanDiamonds <- ggproto(
  "GeomMeanDiamonds",
  GeomBoxplot,
  draw_group = draw_group_new,
  draw_key = draw_key_polygon,
  required_aes = c("x", "lower", "upper", "mean")
)

Step 4. Define the layer functions.

This is the boring part. I copied from geom_boxplot / stat_boxplot directly, removing all references to outliers in geom_meanDiamonds, changing to geom = GeomMeanDiamonds / stat = StatMeanDiamonds, and adding ci = 0.95 to stat_meanDiamonds.

geom_meanDiamonds <- function(mapping = NULL, data = NULL,
                              stat = "meanDiamonds", position = "dodge2",
                              ..., varwidth = FALSE, na.rm = FALSE, show.legend = NA,
                              inherit.aes = TRUE){
  if (is.character(position)) {
    if (varwidth == TRUE) position <- position_dodge2(preserve = "single")
  } else {
    if (identical(position$preserve, "total") & varwidth == TRUE) {
      warning("Can't preserve total widths when varwidth = TRUE.", call. = FALSE)
      position$preserve <- "single"
    }
  }
  layer(data = data, mapping = mapping, stat = stat,
        geom = GeomMeanDiamonds, position = position,
        show.legend = show.legend, inherit.aes = inherit.aes,
        params = list(varwidth = varwidth, na.rm = na.rm, ...))
}

stat_meanDiamonds <- function(mapping = NULL, data = NULL,
                         geom = "meanDiamonds", position = "dodge2",
                         ..., ci = 0.95,
                         na.rm = FALSE, show.legend = NA, inherit.aes = TRUE) {
  layer(data = data, mapping = mapping, stat = StatMeanDiamonds,
        geom = geom, position = position, show.legend = show.legend,
        inherit.aes = inherit.aes,
        params = list(na.rm = na.rm, ci = ci, ...))
}
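Note that geom_meanDiamonds passes stat = "meanDiamonds" (a string) rather than the object itself. As I understand it, ggplot2 turns such a string into an object name by camel-casing it and adding the "Stat" / "Geom" prefix, then looks that name up, which is why defining StatMeanDiamonds and GeomMeanDiamonds in your session is enough. The helper below is a hypothetical re-creation of that naming rule for illustration only; ggplot2's own helper is internal.

```r
# Hypothetical re-creation of the string -> class-name rule (not ggplot2's code).
string_to_class_name <- function(x, prefix) {
  x <- gsub("_(.)", "\\U\\1", x, perl = TRUE)             # snake_case -> camelCase
  paste0(prefix, toupper(substring(x, 1, 1)), substring(x, 2))
}

string_to_class_name("meanDiamonds", "Stat")   # "StatMeanDiamonds"
string_to_class_name("mean_diamonds", "Geom")  # "GeomMeanDiamonds"
string_to_class_name("boxplot", "Stat")        # "StatBoxplot"
```

If you rename the ggproto objects, the strings in the layer functions have to change accordingly (or pass the objects directly).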

Step 5. Check output.

# basic
ggplot(iris, 
       aes(Species, Sepal.Length)) +
  geom_boxplot() +
  geom_meanDiamonds()

# with additional parameters, to see if they break anything
ggplot(iris, 
       aes(Species, Sepal.Length)) +
  geom_boxplot(width = 0.8) +
  geom_meanDiamonds(aes(fill = Species),
                    color = "red", alpha = 0.5, size = 1, 
                    ci = 0.99, width = 0.3)


answered Sep 30 '22 11:09 by Z.Lin
