
`data` must be a data frame, or other object coercible by `fortify()`, not an S3 object with class ranger

I am working with R. Using a tutorial, I was able to create a statistical model and produce visual plots for some of the outputs:

#load libraries
library(survival)

library(dplyr)

library(ranger)

library(data.table)

library(ggplot2)

#use the built in "lung" data set
#remove missing values (dataset is called "a")

a <- na.omit(lung)

#create id variable

a$ID <- seq_along(a[,1])

#create test set with only the first 3 rows

new <- a[1:3,]

#create a training set by removing first three rows

a <- a[-c(1:3),]



#fit survival model (random survival forest)

r_fit <- ranger(Surv(time,status) ~ age + sex + ph.ecog + ph.karno + pat.karno + meal.cal + wt.loss, data = a, mtry = 4, importance = "permutation", splitrule = "extratrees", verbose = TRUE)

#create new intermediate variables required for the survival curves

death_times <- r_fit$unique.death.times

surv_prob <- data.frame(r_fit$survival)

avg_prob <- sapply(surv_prob, mean)

#use survival model to produce estimated survival curves for the first three observations

pred <- predict(r_fit, new, type = 'response')$survival

pred <- data.table(pred)

colnames(pred) <- as.character(r_fit$unique.death.times)

#plot the results for these 3 patients

plot(r_fit$unique.death.times, pred[1,], type = "l", col = "red")

lines(r_fit$unique.death.times, pred[2,], type = "l", col = "green")

lines(r_fit$unique.death.times, pred[3,], type = "l", col = "blue")

[Image: base R plot of the predicted survival curves for the three patients]

Now, I am trying to convert the above plot into ggplot format (and add 95% confidence intervals):

ggplot(r_fit) + geom_line(aes(x = r_fit$unique.death.times, y = pred[1,], group = 1), color = red)  +  geom_ribbon(aes(ymin = 0.95 * pred[1,], ymax = - 0.95 * pred[1,]), fill = "red") + geom_line(aes(x = r_fit$unique.death.times, y = pred[2,], group = 1), color = blue) + geom_ribbon(aes(ymin = 0.95 * pred[2,], ymax = - 0.95 * pred[2,]), fill = "blue") + geom_line(aes(x = r_fit$unique.death.times, y = pred[3,], group = 1), color = green) + geom_ribbon(aes(ymin = 0.95 * pred[3,], ymax = - 0.95 * pred[3,]), fill = "green") + theme(axis.text.x = element_text(angle = 90)) + ggtitle("sample graph")

But this produces the following error:

Error: `data` must be a data frame, or other object coercible by `fortify()`, not an S3 object with class ranger
Run `rlang::last_error()` to see where the error occurred.

What is the reason for this error? Can someone please show me how to fix this problem?

Thanks

asked May 16 '21 16:05 by stats_noob



1 Answer

The error occurs because the first argument of ggplot() is `data`, and r_fit is a ranger model object, not a data frame. As per the ggplot2 documentation, you need to provide a data.frame or an object that can be converted (coerced) to one by fortify(). In this case, if you want to reproduce the plot above in ggplot2, you will need to set up the data frame yourself.
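As a rough illustration of the difference, here is a minimal base R sketch. The object `fake_fit` is a made-up stand-in for a fitted model (it is not a real ranger object, and its values are invented); it only shows that a classed list is not a data frame until you pull its pieces into one.

```r
# Stand-in for a fitted model: a list with an S3 class, similar in shape to
# what a model-fitting function returns. Names and values are made up.
fake_fit <- structure(
  list(
    unique.death.times = c(5, 11, 12),
    survival = matrix(c(1.00, 0.98, 0.95,
                        0.99, 0.97, 0.90),
                      nrow = 2, byrow = TRUE)
  ),
  class = "ranger_like"
)

is.data.frame(fake_fit)   # FALSE: not usable as `data` in ggplot()

# Pull the pieces you want to plot into a data frame first
plot_data <- data.frame(
  time  = fake_fit$unique.death.times,
  pred1 = fake_fit$survival[1, ],
  pred2 = fake_fit$survival[2, ]
)

is.data.frame(plot_data)  # TRUE: this can be passed to ggplot()
```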

Below is an example of how you could set up the data to display the plot in ggplot2.

Data Frame

First we create a data.frame() with the variables that we want to plot. The easiest way to do this is to put each of them in a separate column. Note that I have used the as.numeric() function to coerce the predicted values to a vector first: each of them was a one-row data.table, and without the conversion it would stay a one-row table instead of becoming a single column.

ggplot_data <- data.frame(unique.death.times = r_fit$unique.death.times,
                      pred1 = as.numeric(pred[1,]),
                      pred2 = as.numeric(pred[2,]),
                      pred3 = as.numeric(pred[3,]))
head(ggplot_data)
## unique.death.times     pred1     pred2     pred3
## 1                  5 0.9986676 1.0000000 0.9973369
## 2                 11 0.9984678 1.0000000 0.9824642
## 3                 12 0.9984678 0.9998182 0.9764154
## 4                 13 0.9984678 0.9998182 0.9627118
## 5                 15 0.9731656 0.9959416 0.9527424
## 6                 26 0.9731656 0.9959416 0.9093876

Pivot the data

This format is still not ideal, because in order to plot the data and colour by the correct column (variable), we need to 'pivot' the data. We need to load the tidyr package for this.

library(tidyr)
ggplot_data <- ggplot_data %>% 
  pivot_longer(cols = !unique.death.times, 
  names_to = "category", values_to = "predicted.value")
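To see what the pivot does, here is a small sketch on toy data. The data frame `wide` and its values are made up for illustration; the column names mirror the ones used above.

```r
library(tidyr)

# Toy wide data: one row per time point, one column per patient (made-up values)
wide <- data.frame(
  unique.death.times = c(5, 11),
  pred1 = c(0.99, 0.98),
  pred2 = c(1.00, 0.97)
)

# Every column except unique.death.times is stacked into name/value pairs
long <- pivot_longer(
  wide,
  cols = !unique.death.times,
  names_to = "category",
  values_to = "predicted.value"
)

long
```

The result has one row per combination of time point and patient, which is the shape ggplot2 needs in order to map `category` to colour.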

Plotting

Now the data is in a form that makes it really easy to plot in ggplot2.

plot <- ggplot(ggplot_data, aes(x = unique.death.times, y = predicted.value, colour = category)) +
      geom_line()
plot

[Image: ggplot]

If you really want to match the look of the base plot, you can add theme_classic():

plot + theme_classic()

[Image: ggplot with theme_classic]

Additional notes

Note that this doesn't include 95% confidence intervals, so they would have to be calculated separately. Be aware though, that a 95% confidence interval is not just 95% of the y value at a given x value. There are calculations that will give you the correct values of the confidence interval, including functions built into R.
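Once you do have interval bounds, drawing them is a matter of giving geom_ribbon() an x value plus a lower and an upper column. The sketch below only shows that mechanism: the data frame `band` and its `lower` and `upper` columns are made-up placeholder values, not confidence limits computed from the model.

```r
library(ggplot2)

# Toy data; `lower` and `upper` are placeholder bounds, NOT real confidence limits
band <- data.frame(
  time  = c(5, 11, 12, 13),
  pred  = c(0.99, 0.97, 0.95, 0.90),
  lower = c(0.95, 0.92, 0.89, 0.82),
  upper = c(1.00, 1.00, 0.99, 0.96)
)

# The ribbon inherits x from the top-level aes() and adds ymin/ymax
p <- ggplot(band, aes(x = time, y = pred)) +
  geom_ribbon(aes(ymin = lower, ymax = upper), fill = "red", alpha = 0.2) +
  geom_line(colour = "red")
p
```

Note that the ribbon needs ymin below ymax at every x; the attempt in the question used a negative ymax, which would put the "upper" bound below zero.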

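For a quick view of a trend line with a shaded band, you can use the geom_smooth() function in ggplot2. Be aware that with this amount of data it fits a loess curve by default, and the band it draws is the confidence band of that smoothed curve, not a confidence interval for the survival predictions themselves.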

plot + theme_classic() + geom_smooth()

[Image: ggplot with smooth trend]

answered Nov 04 '22 13:11 by Sam Rogers