Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

stat_function and legends: create plot with two separate colour legends mapped to different variables

Tags:

r

ggplot2

I would like to combine two different types of plots in one image with ggplot2. Here's the code I use:

# Evaluate (x + 1)^(1 - param) / (1 - param), vectorised over x.
#
# x:     numeric vector of evaluation points.
# param: numeric parameter of the curve (default 4).
# Returns the numeric value(s) of the expression.
fun.bar <- function(x, param = 4) {
  exponent <- 1 - param
  ((x + 1)^exponent) / exponent
}

# Plot every column of `df` as a jittered band of points (one band and one
# colour per column) and overlay one fun.bar() curve per element of `par`.
#
# df:  data frame, expected to hold numeric columns; it is reshaped to long
#      form with melt(), so each column becomes one level of `variable`.
# par: numeric vector of fun.bar() parameters, one curve per element. Curve
#      colours are taken from the "Set1" Brewer palette.
# Returns the ggplot object.
plot.foo <- function(df, par = c(1.7, 2:8)) {
  # library() stops immediately when a package is missing; require() would
  # only return FALSE and let the function fail later with an unrelated error.
  library(ggplot2)
  library(reshape2)
  library(RColorBrewer)
  melt.df <- melt(df)
  # Vertical position of each band: the integer code of the column's level.
  melt.df$ypos <- as.numeric(melt.df$variable)
  p <- ggplot(data = melt.df, aes(x = value, y = ypos, colour = variable)) +
    geom_point(position = "jitter", alpha = 0.2, size = 2) +
    xlim(-1, 1) + ylim(-5, 5) +
    guides(colour =
      guide_legend("Type", override.aes = list(alpha = 1, size = 4)))
  pal <- brewer.pal(length(par), "Set1")
  for (i in seq_along(par)) {
    # `args` must be spelled out in full: in recent ggplot2 releases it comes
    # after `...` in stat_function()'s signature, so the abbreviation `arg`
    # is not partially matched and the parameter is dropped (every curve
    # would then be drawn with fun.bar()'s default `param`).
    p <- p + stat_function(fun = fun.bar,
      args = list(param = par[i]), colour = pal[i], size = 1.3)
  }
  p
}

# Demo input: three columns (A, B, C) of 1000 normal draws each, sd = 0.25,
# generated in column order and passed straight to the plotting helper.
df.foo <- data.frame(lapply(c(A = 1000, B = 1000, C = 1000), rnorm, sd = 0.25))
plot.foo(df.foo)

As a result, I get the following picture. my_plot However, I'd like to have another legend with colours from red to pink, displaying information about the parameters of the curves in the lower part of the plot. The problem is that the key aesthetic for both parts is colour, so manually overriding it via scale_colour_manual() destroys the existing legend.

I understand there's a "one aesthetic -- one legend" concept, but how can I bypass this restriction in this specific case?

like image 997
tonytonov Avatar asked Oct 07 '13 07:10

tonytonov


2 Answers

When looking at previous examples of stat_function and legend on SO, I got the impression that it is not very easy to make the two live happily together without some hard-coding of each curve generated by stat_function (I would be happy to find that I am wrong). See e.g. here, here, and here. In the last answer @baptiste wrote: "you'll be better off building a data.frame before plotting". That's what I try in my answer: I pre-calculate the data using the function, and then use geom_line instead of stat_function in the plot.

# load relevant packages
library(ggplot2)
library(reshape2)
library(RColorBrewer)
library(grid)
library(gridExtra)
library(gtable)
library(plyr)

# base data: three columns (A, B, C) of 1000 normal draws each, sd = 0.25
df <- data.frame(lapply(c(A = 1000, B = 1000, C = 1000), rnorm, sd = 0.25))
# long format; 'ypos' is the integer code of each column's factor level
melt.df <- melt(df)
melt.df[["ypos"]] <- as.numeric(melt.df[["variable"]])

# points-only plot; its sole purpose is to provide the colour legend for points
p1 <- ggplot(melt.df, aes(x = value, y = ypos, colour = variable)) +
  geom_point(position = "jitter", size = 2, alpha = 0.2) +
  xlim(-1, 1) +
  ylim(-5, 5) +
  guides(
    colour = guide_legend(
      title = "Type",
      override.aes = list(alpha = 1, size = 4)
    )
  )

p1

# extract the legend grob(s) of the points plot (names matching "guide-box")
legend_points <- gtable_filter(
  x = ggplot_gtable(ggplot_build(p1)),
  pattern = "guide-box"
)

# distinct point colours, in order of first appearance in the built layer data;
# reused in the final plot
point_cols <- unique(ggplot_build(p1)[["data"]][[1]][["colour"]])


# create data for lines
# define function for lines
# Curve to be drawn: (x + 1)^(1 - param) / (1 - param), vectorised over x.
# x is a numeric vector of evaluation points, param the curve parameter.
fun.bar <- function(x, param = 4) {
  one_minus_param <- 1 - param
  shifted <- x + 1
  (shifted^one_minus_param) / one_minus_param
}

# parameters for lines
pars <- c(1.7, 2:8)

# for each parameter value, evaluate the curve at every observed x
# (x = melt.df$value) and stack the results into one long data frame
# with columns 'pars', 'value' and 'ypos'.
# The callback argument is deliberately not called 'pars', so it cannot be
# confused with the global parameter vector defined above.
df2 <- ldply(.data = pars, .fun = function(par_i) {
  ypos <- fun.bar(melt.df$value, par_i)
  data.frame(pars = par_i, value = melt.df$value, ypos = ypos)
})

# colour palette for lines
line_cols <- brewer.pal(length(pars), "Set1")

# plot lines only, to get a colour legend for lines
# please note that when using ylim:
# "Observations not in this range will be dropped completely and not passed to any other layers"
# thus the warnings
p2 <- ggplot(data = df2,
             aes(x = value, y = ypos, group = pars, colour = as.factor(pars))) +
  geom_line() +
  xlim(-1, 1) + ylim(-5, 5) +
  scale_colour_manual(name = "Param", values = line_cols, labels = as.character(pars))

p2

# grab colour legend for lines
legend_lines <- gtable_filter(ggplot_gtable(ggplot_build(p2)), "guide-box")


# plot both points and lines with legend suppressed
# Both layers share a single colour scale, so the manual palette must cover
# the line levels (the parameter values) as well as the point levels (the
# column names). Naming each colour by its level makes the mapping explicit
# instead of relying on the order in which ggplot sorts the combined levels.
# NOTE: 'point_cols' is assumed to be in factor-level order, i.e. the points
# data lists the levels in order (true for the output of melt()).
p3 <- ggplot(data = melt.df, aes(x = value, y = ypos)) +
  geom_point(aes(colour = variable),
             position = "jitter", alpha = 0.2, size = 2) +
  geom_line(data = df2, aes(group = pars, colour = as.factor(pars))) +
  xlim(-1, 1) + ylim(-5, 5) +
  theme(legend.position = "none") +
  scale_colour_manual(values = c(
    setNames(line_cols, levels(as.factor(pars))),
    setNames(point_cols, levels(melt.df$variable))
  ))
  # see alternative below

p3

# arrange plot and legends for points and lines with viewports
# define plotting regions (viewports)
# some hard-coding of positions
# The functions used below (grid.newpage, viewport, upViewport, pushViewport,
# grid.draw) belong to the 'grid' package, which must be attached.
# x/y/width/height are given without units, so grid's default units apply
# (presumably npc, i.e. fractions of the page -- confirm).

# start from an empty page
grid.newpage()
# main plot region
vp_plot <- viewport(x = 0.45, y = 0.5,
                    width = 0.9, height = 1)

# region for the points legend
vp_legend_points <- viewport(x = 0.91, y = 0.7,
                      width = 0.1, height = 0.25)

# region for the lines legend
vp_legend_lines <- viewport(x = 0.93, y = 0.35,
                         width = 0.1, height = 0.75)

# add plot
print(p3, vp = vp_plot)

# add legend for points
# upViewport(0) first navigates back to the top-level viewport, so that the
# legend viewport is positioned relative to the whole page
upViewport(0)
pushViewport(vp_legend_points)
grid.draw(legend_points)

# add legend for lines
# same pattern: back to the top level, then into the legend's own viewport
upViewport(0)
pushViewport(vp_legend_lines)
grid.draw(legend_lines)

enter image description here

# A second alternative, with greater control over the colours
# Step 1: draw points and lines together with the colour legend hidden,
# leaving the choice of colours to ggplot
p3 <- ggplot(melt.df, aes(x = value, y = ypos)) +
  geom_point(
    aes(colour = variable),
    position = "jitter", size = 2, alpha = 0.2
  ) +
  geom_line(aes(group = pars, colour = as.factor(pars)), data = df2) +
  xlim(-1, 1) +
  ylim(-5, 5) +
  theme(legend.position = "none")

p3

# Step 2: build p3 for rendering; the result holds one data frame per layer
# that can be edited before drawing
pp3 <- ggplot_build(p3)

# full per-row colour vectors from the single-purpose plots:
# points from p1, lines from p2
point_cols_vec <- ggplot_build(p1)[["data"]][[1]][["colour"]]
line_cols_vec <- ggplot_build(p2)[["data"]][[1]][["colour"]]

# Step 3: overwrite the colours ggplot chose
# layer 1 of p3 holds the points, layer 2 the lines
pp3[["data"]][[1]][["colour"]] <- point_cols_vec
pp3[["data"]][[2]][["colour"]] <- line_cols_vec

# Step 4: turn the edited build into a grob for grid.draw below
grob3 <- ggplot_gtable(pp3)

# arrange plot and the two legends with viewports
# define plotting regions (viewports)
# The functions used below (viewport, grid.newpage, pushViewport, upViewport,
# grid.draw) belong to the 'grid' package, which must be attached.
# x/y/width/height are given without units, so grid's default units apply
# (presumably npc, i.e. fractions of the page -- confirm).

# main plot region
vp_plot <- viewport(x = 0.45, y = 0.5,
                    width = 0.9, height = 1)

# region for the points legend
vp_legend_points <- viewport(x = 0.91, y = 0.7,
                             width = 0.1, height = 0.25)

# region for the lines legend
vp_legend_lines <- viewport(x = 0.92, y = 0.35,
                            width = 0.1, height = 0.75)

# start from an empty page
grid.newpage()

# draw the recoloured plot grob inside the main plot region
pushViewport(vp_plot)
grid.draw(grob3)

# back to the top-level viewport, then draw the points legend in its region
upViewport(0)
pushViewport(vp_legend_points)
grid.draw(legend_points)

# back to the top-level viewport, then draw the lines legend in its region
upViewport(0)
pushViewport(vp_legend_lines)
grid.draw(legend_lines)
like image 156
Henrik Avatar answered Nov 17 '22 04:11

Henrik


I'd like to share a quick hack I used while waiting for an answer to this question.

# Curve family used for the overlay: (x + 1)^(1 - param) / (1 - param).
# x is a numeric vector of evaluation points; param selects the curve.
fun.bar <- function(x, param = 4) {
  power <- 1 - param
  numerator <- (x + 1)^power
  numerator / power
}

# Quick hack: show the curve parameters in the colour legend and label the
# point bands with text annotations instead of a second legend.
#
# df:  data frame, expected to hold numeric columns; it is reshaped to long
#      form with melt(), so each column becomes one band of points.
# par: numeric vector of fun.bar() parameters, one curve per element. Curve
#      colours are taken from the "Set1" Brewer palette.
# Returns the ggplot object.
plot.foo <- function(df, par = c(1.7, 2:8)) {
  # library() stops immediately when a package is missing; require() would
  # only return FALSE and let the function fail later with an unrelated error.
  library(ggplot2)
  library(reshape2)
  library(RColorBrewer)
  melt.df <- melt(df)
  melt.df$ypos <- as.numeric(melt.df$variable)
  # keep the original column names; they label the bands further down
  type_names <- levels(melt.df$variable)
  # the trick is to override factor levels with "1", "2", ..., so that the
  # point groups share their scale values with the first entries of the
  # manual scale defined below
  levels(melt.df$variable) <- seq_len(nlevels(melt.df$variable))
  p <- ggplot(data = melt.df, aes(x = value, y = ypos, colour = variable)) +
    geom_point(position = "jitter", alpha = 0.2, size = 2) +
    xlim(-1, 1) + ylim(-5, 5) +
    guides(colour =
      guide_legend("Type", override.aes = list(alpha = 1, size = 4)))
  pal <- brewer.pal(length(par), "Set1")
  for (i in seq_along(par)) {
    # `args` must be spelled out in full: in recent ggplot2 releases it comes
    # after `...` in stat_function()'s signature, so the abbreviation `arg`
    # is not partially matched and the parameter is dropped.
    p <- p + stat_function(fun = fun.bar,
      args = list(param = par[i]), colour = pal[i], size = 1.3)
  }
  # points are displayed by supplying values for manual scale; the limits are
  # given as character because the scale is discrete and the overridden
  # factor levels are the strings "1", "2", ...
  p +
    scale_colour_manual(values = pal, limits = as.character(seq_along(par)),
      labels = par) +
    # one text label per band, placed at the band's height (1, 2, ...)
    annotate("text", x = 0.8, y = seq_along(type_names),
      label = type_names, size = 8)
}

# Demo input: three columns (A, B, C) of 1000 normal draws each, sd = 0.25,
# generated in column order and passed straight to the plotting helper.
df.foo <- data.frame(lapply(c(A = 1000, B = 1000, C = 1000), rnorm, sd = 0.25))
plot.foo(df.foo)

enter image description here This workaround is not even close to being as awesome as the answer provided by @Henrik, but it suited my one-time needs.

like image 3
tonytonov Avatar answered Nov 17 '22 03:11

tonytonov