I am trying to look up values from other columns in my data frame/tibble, where the column to read from depends on the value in column var. I can achieve this by hardcoding the columns in case_when():
library(tidyverse)
set.seed(1)
ds <- tibble(var = paste0("x", sample(1:3, 10, replace = TRUE)),
             x1 = 0:9,
             x2 = 100:109,
             x3 = 1000:1009)
ds %>%
  mutate(result = case_when(var == "x1" ~ x1,
                            var == "x2" ~ x2,
                            var == "x3" ~ x3))
#> # A tibble: 10 x 5
#> var x1 x2 x3 result
#> <chr> <int> <int> <int> <int>
#> 1 x1 0 100 1000 0
#> 2 x3 1 101 1001 1001
#> 3 x1 2 102 1002 2
#> 4 x2 3 103 1003 103
#> 5 x1 4 104 1004 4
#> 6 x3 5 105 1005 1005
#> 7 x3 6 106 1006 1006
#> 8 x2 7 107 1007 107
#> 9 x2 8 108 1008 108
#> 10 x3 9 109 1009 1009
However, what if I have not just 3 columns but many columns xn?
I found that the following works for an external variable/object:
y <- "x2"
ds %>%
  mutate(result = !!sym(y))
#> # A tibble: 10 x 5
#> var x1 x2 x3 result
#> <chr> <int> <int> <int> <int>
#> 1 x1 0 100 1000 100
#> 2 x3 1 101 1001 101
#> 3 x1 2 102 1002 102
#> 4 x2 3 103 1003 103
#> 5 x1 4 104 1004 104
#> 6 x3 5 105 1005 105
#> 7 x3 6 106 1006 106
#> 8 x2 7 107 1007 107
#> 9 x2 8 108 1008 108
#> 10 x3 9 109 1009 109
But it doesn't work for an internal variable/column in a tibble:
ds %>%
  mutate(result = !!sym(var))
#> Error: Only strings can be converted to symbols
Created on 2021-05-24 by the reprex package (v2.0.0)
Any ideas of how to get this to work within a data frame/tibble column are greatly appreciated.
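As background for the answers below, a likely reason for the error: !! is evaluated when mutate() captures its arguments, before any column is visible, so var is looked up in the calling environment. With no object of that name defined there, it most likely resolves to the stats::var() function, which sym() cannot convert. A minimal sketch, assuming a fresh session in which the user has not defined var (var_is_function and sym_result are names introduced here for illustration):

```r
library(rlang)

# Outside of mutate(), `var` is not the column: it resolves to stats::var().
var_is_function <- is.function(var)

# sym() only accepts a single string, so passing a function raises an error.
sym_result <- tryCatch(sym(var), error = function(e) "error")
```

Even with the column in scope, sym() expects a single string, so it could not choose a different column for each row.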
There are two solutions I can think of. The first is syntactically a bit cleaner and uses rowwise() along with get():
ds %>%
  rowwise() %>%
  mutate(result = get(var)) %>%
  ungroup()
#> # A tibble: 10 x 5
#> var x1 x2 x3 result
#> <chr> <int> <int> <int> <int>
#> 1 x1 0 100 1000 0
#> 2 x3 1 101 1001 1001
#> 3 x1 2 102 1002 2
#> 4 x2 3 103 1003 103
#> 5 x1 4 104 1004 4
#> 6 x3 5 105 1005 1005
#> 7 x3 6 106 1006 1006
#> 8 x2 7 107 1007 107
#> 9 x2 8 108 1008 108
#> 10 x3 9 109 1009 1009
The second uses purrr::pmap_int(), so it could be considered a bit more advanced. However, it has the advantages of being somewhat faster and more concise:
ds %>%
  mutate(result = pmap_int(., function(var, ...) c(...)[var]))
#> # A tibble: 10 x 5
#> var x1 x2 x3 result
#> <chr> <int> <int> <int> <int>
#> 1 x1 0 100 1000 0
#> 2 x3 1 101 1001 1001
#> 3 x1 2 102 1002 2
#> 4 x2 3 103 1003 103
#> 5 x1 4 104 1004 4
#> 6 x3 5 105 1005 1005
#> 7 x3 6 106 1006 1006
#> 8 x2 7 107 1007 107
#> 9 x2 8 108 1008 108
#> 10 x3 9 109 1009 1009
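To see what the anonymous function does for a single row, it can be called by hand. This is only an illustrative sketch: row_lookup is a hypothetical name for the same function, and the argument values mimic row 4 of ds.

```r
# Same body as the function passed to pmap_int() above.
row_lookup <- function(var, ...) c(...)[var]

# pmap passes each row's columns as named arguments; `var` is matched by name
# and the remaining columns are collected into a named vector via c(...).
picked <- row_lookup(var = "x2", x1 = 3L, x2 = 103L, x3 = 1003L)
print(picked)
```

Because c(...) coerces its arguments to a common type, this presumably only behaves as intended when all the xn columns share a type.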
One other option I just thought of is to programmatically construct a call to case_when(). This might look something like the following:
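# Define a function to construct a `case_when()` call:
switch_cols <- function(var) {
  # One branch per distinct value found in the column
  vals <- unique(var)
  # Name of the column as written by the caller, e.g. "var"
  name <- deparse(substitute(var))
  # Build strings such as "var == 'x1' ~ x1" and convert them to formulas
  formulae <- lapply(
    sprintf("%s == '%s' ~ %s", name, vals, vals),
    as.formula,
    env = parent.frame()
  )
  case_when(!!!formulae)
}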
ds %>%
  mutate(result = switch_cols(var))
#> # A tibble: 10 x 5
#> var x1 x2 x3 result
#> <chr> <int> <int> <int> <int>
#> 1 x1 0 100 1000 0
#> 2 x3 1 101 1001 1001
#> 3 x1 2 102 1002 2
#> 4 x2 3 103 1003 103
#> 5 x1 4 104 1004 4
#> 6 x3 5 105 1005 1005
#> 7 x3 6 106 1006 1006
#> 8 x2 7 107 1007 107
#> 9 x2 8 108 1008 108
#> 10 x3 9 109 1009 1009
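To make the construction step more concrete, here is a small sketch of the strings that sprintf() builds before they are converted with as.formula(). The values are those of unique(ds$var) for the example data; formula_strings is a name introduced here for illustration.

```r
vals <- c("x1", "x3", "x2")  # unique(ds$var) for the example data

# One string per distinct value, each describing a single case_when() branch.
formula_strings <- sprintf("%s == '%s' ~ %s", "var", vals, vals)
print(formula_strings)
```

Each string reproduces one branch of the hardcoded case_when() from the question.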
We can test the performance using microbenchmark(). I've also included @akrun's base R solution for completeness:
microbenchmark::microbenchmark(
  rowwise = ds %>%
    rowwise() %>%
    mutate(result = get(var)) %>%
    ungroup(),
  purrr = ds %>%
    mutate(result = purrr::pmap_int(., function(var, ...) c(...)[var])),
  functional = ds %>%
    mutate(result = switch_cols(var)),
  base1 = ds %>%
    mutate(result = as.data.frame(.[-1])[cbind(dplyr::row_number(),
                                               match(var, names(.)[-1]))]),
  base2 = ds$result <- as.data.frame(ds[-1])[cbind(seq_len(nrow(ds)),
                                                   match(ds$var, names(ds)[-1]))]
)
#> Unit: microseconds
#> expr min lq mean median uq max neval
#> rowwise 5385.9 6347.3 10692.3 8127.9 12756.3 32893 100
#> purrr 2957.2 3698.2 5837.4 4533.2 7566.6 12317 100
#> functional 3098.4 3956.6 5625.8 4536.0 7124.5 12665 100
#> base1 3028.9 3867.3 5839.6 4525.5 7610.0 16408 100
#> base2 275.9 386.6 584.5 488.6 676.9 3996 100
Unsurprisingly, the 'pure' base R approach is by far the fastest option. The others are fairly comparable, except for rowwise(), which is a lot slower.
It would be much faster with a row/column indexing approach in base R:
ds$result <- as.data.frame(ds[-1])[cbind(seq_len(nrow(ds)),
                                         match(ds$var, names(ds)[-1]))]
ds$result
#[1] 0 1001 2 103 4 1005 1006 107 108 1009
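The indexing works because a data frame subset by a two-column matrix returns one element per (row, column) pair. A small self-contained sketch with made-up data (df, key, idx and picked are illustrative names, not taken from the question):

```r
df  <- data.frame(x1 = 0:2, x2 = 100:102, x3 = 1000:1002)
key <- c("x1", "x3", "x2")   # column to read from, one entry per row

# Column 1 of idx is the row number, column 2 the matched column position.
idx <- cbind(seq_len(nrow(df)), match(key, names(df)))

# One value per row of idx: df[1, "x1"], df[2, "x3"], df[3, "x2"].
picked <- df[idx]
print(picked)
```

If an entry of key matched no column name, match() would return NA for it and the corresponding result would be NA as well.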
Or the same in a dplyr construct:
ds %>%
  mutate(result = as.data.frame(.[-1])[cbind(row_number(),
                                             match(var, names(.)[-1]))])
# A tibble: 10 x 5
# var x1 x2 x3 result
# <chr> <int> <int> <int> <int>
# 1 x1 0 100 1000 0
# 2 x3 1 101 1001 1001
# 3 x1 2 102 1002 2
# 4 x2 3 103 1003 103
# 5 x1 4 104 1004 4
# 6 x3 5 105 1005 1005
# 7 x3 6 106 1006 1006
# 8 x2 7 107 1007 107
# 9 x2 8 108 1008 108
#10 x3 9 109 1009 1009