 

Collapsing one-hot encoded columns based on a condition in R dplyr

Tags:

r

dplyr

I have a table containing one-hot encoded variables, and I want to collapse each set of them into one column. For example, any group of columns marked "high", "med" and "low" should become a single column with the numerical encoding high = 0, med = 1, low = 2. How can I do this with dplyr in R? I suspect pivoting will help, but I'm not sure where to start. The resulting column name should be the shared name of the three columns without the high/med/low designation. For example, columns d-high_cm1, d-med_cm1 and d-low_cm1 should become d-cm1, holding the numerical codes.

input:

sex age    cost_cm d-high_cm1 d-med_cm1 d-low_cm1 c-high_cm1 c-med_cm1 c-low_cm1
f   old    1       1          0         0         1          0         0
m   young  0       1          0         0         1          0         0
m   old    0       0          1         0         0          1         0
f   young  0       1          0         0         0          0         1
m   old    1       0          0         1         0          0         1

expected output:

sex age    cost_cm  d-cm1 c-cm1 
f   old    1        0     0
m   young  0        0     0
m   old    0        1     1
f   young  0        0     2
m   old    1        2     2
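For a single group, the collapse is just a weighted sum: multiply each one-hot column by its code and add across the row. A minimal base R sketch, assuming exactly one 1 per row and the column order high, med, low (the names `d` and `codes` are illustrative only):

```r
# One-hot columns of the "d" group from the question, in the order high, med, low
d <- cbind(high = c(1L, 1L, 0L, 1L, 0L),
           med  = c(0L, 0L, 1L, 0L, 0L),
           low  = c(0L, 0L, 0L, 0L, 1L))

# Weight each column by its code (high = 0, med = 1, low = 2) and sum per row;
# with exactly one 1 per row, the sum equals that row's code
codes <- drop(d %*% c(0, 1, 2))
print(codes)
# [1] 0 0 1 0 2
```

This reproduces the expected d-cm1 column; the answers below apply the same idea to every group at once.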

Eisen asked Dec 22 '21 at 18:12



3 Answers

We may weight each one-hot column by its code, reshape to long format, and take the maximum per row:

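library(stringr)
library(dplyr)
library(tidyr)
df1 %>% 
   # multiply each one-hot column by its code: high = 0, med = 1, low = 2
   mutate(across(contains("-"), ~ case_when(str_detect(cur_column(),
      'low') ~ . * 2, str_detect(cur_column(), 'med')  ~ . * 1,
        TRUE ~ .* 0))) %>%
   # move the level to the end of the name, e.g. "d-high_cm1" -> "d-cm1_high"
   rename_with(~ str_replace(., "-(\\w+)_(\\w+)", "-\\2_\\1"), contains('-')) %>% 
   # stack the three level columns of each group under one column ("d-cm1", "c-cm1")
   pivot_longer(cols = contains('-'), names_to = c(".value"), 
      names_pattern = "^([^_]+)_.*")%>% 
   # only one level per group is non-zero, so max() recovers the code;
   # note that rows sharing the same sex, age and cost_cm would be merged here
   group_by(sex, age, cost_cm) %>% 
   summarise(across(everything(), max), .groups = 'drop')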

-output

# A tibble: 5 × 5
  sex   age   cost_cm `d-cm1` `c-cm1`
  <chr> <chr>   <int>   <dbl>   <dbl>
1 f     old         1       0       0
2 f     young       0       0       2
3 m     old         0       1       1
4 m     old         1       2       2
5 m     young       0       0       0
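Because `group_by(sex, age, cost_cm)` merges any rows that share those three values (and reorders the result), carrying a row index through the reshape may be safer on real data. A sketch, assuming every one-hot column is named `<prefix>-<level>_<suffix>` with level one of high, med, low; `row_id`, `prefix`, `level`, `suffix`, `code` and `new_col` are illustrative names, not part of the answer above:

```r
library(dplyr)
library(tidyr)

# The question's data (check.names = FALSE keeps the "-" in the names)
df1 <- data.frame(
  sex = c("f", "m", "m", "f", "m"),
  age = c("old", "young", "old", "young", "old"),
  cost_cm = c(1L, 0L, 0L, 0L, 1L),
  `d-high_cm1` = c(1L, 1L, 0L, 1L, 0L),
  `d-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `d-low_cm1` = c(0L, 0L, 0L, 0L, 1L),
  `c-high_cm1` = c(1L, 1L, 0L, 0L, 0L),
  `c-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `c-low_cm1` = c(0L, 0L, 0L, 1L, 1L),
  check.names = FALSE
)

res <- df1 %>%
  # unique key per original row, so identical sex/age/cost_cm rows stay separate
  mutate(row_id = row_number()) %>%
  # split "d-high_cm1" into prefix "d", level "high", suffix "cm1"
  pivot_longer(
    cols = matches("-(high|med|low)_"),
    names_to = c("prefix", "level", "suffix"),
    names_pattern = "^(.*)-(high|med|low)_(.*)$"
  ) %>%
  # keep only the active level of each group
  filter(value == 1) %>%
  mutate(
    code = match(level, c("high", "med", "low")) - 1L,  # high = 0, med = 1, low = 2
    new_col = paste0(prefix, "-", suffix)               # e.g. "d-cm1"
  ) %>%
  select(-prefix, -level, -suffix, -value) %>%
  # one column per group; a row with no 1 in a group would get NA
  pivot_wider(names_from = new_col, values_from = code) %>%
  select(-row_id)
res
```

This keeps the original row order and the column order d-cm1, c-cm1 from the expected output.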

Or using base R

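# split the one-hot columns by their name without the level ("d-high_cm1" -> "d_cm1"),
# weight the columns of each group by 0:2 (assumes the order high, med, low)
# and take the row-wise maximum
lst1 <- lapply(split.default(df1[-c(1:3)], sub("-[^_]+", "", 
    names(df1)[-(1:3)])), function(x) do.call(pmax, x *  (0:2)[col(x)]))
cbind(df1[1:3], lst1)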

-output

   sex   age cost_cm c_cm1 d_cm1
1   f   old       1     0     0
2   m young       0     0     0
3   m   old       0     1     1
4   f young       0     2     0
5   m   old       1     2     2
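If the level columns of a group might not appear in high, med, low order, a base R variant can read the level from each column name instead of relying on position. A sketch, assuming the same `<prefix>-<level>_<suffix>` naming; `onehot`, `key`, `codes` and `out` are illustrative names, and rows that are not valid one-hot within a group (no 1, or several) are set to NA rather than guessed:

```r
# The question's data (check.names = FALSE keeps the "-" in the names)
df1 <- data.frame(
  sex = c("f", "m", "m", "f", "m"),
  age = c("old", "young", "old", "young", "old"),
  cost_cm = c(1L, 0L, 0L, 0L, 1L),
  `d-high_cm1` = c(1L, 1L, 0L, 1L, 0L),
  `d-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `d-low_cm1` = c(0L, 0L, 0L, 0L, 1L),
  `c-high_cm1` = c(1L, 1L, 0L, 0L, 0L),
  `c-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `c-low_cm1` = c(0L, 0L, 0L, 1L, 1L),
  check.names = FALSE
)

onehot <- df1[-(1:3)]
# group key with the level removed, e.g. "d-high_cm1" -> "d-cm1"
key <- sub("-(high|med|low)_", "-", names(onehot))

codes <- lapply(split.default(onehot, key), function(x) {
  # level encoded in each column name: "high", "med" or "low"
  lvl <- sub("^.*-(high|med|low)_.*$", "\\1", names(x))
  # index of the column holding the 1 in each row
  hit <- max.col(as.matrix(x), ties.method = "first")
  code <- match(lvl[hit], c("high", "med", "low")) - 1L
  # rows without exactly one 1 get NA
  code[rowSums(x) != 1] <- NA_integer_
  code
})

out <- df1[1:3]
# add the collapsed columns in their original order of appearance
out[unique(key)] <- codes[unique(key)]
out
```

Assigning with `[<-` keeps names such as `d-cm1` unchanged, whereas building a new data frame with the default `check.names = TRUE` would turn them into `d.cm1`.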

data

df1 <- structure(list(sex = c("f", "m", "m", "f", "m"), age = c("old", 
"young", "old", "young", "old"), cost_cm = c(1L, 0L, 0L, 0L, 
1L), `d-high_cm1` = c(1L, 1L, 0L, 1L, 0L), `d-med_cm1` = c(0L, 
0L, 1L, 0L, 0L), `d-low_cm1` = c(0L, 0L, 0L, 0L, 1L), `c-high_cm1` = c(1L, 
1L, 0L, 0L, 0L), `c-med_cm1` = c(0L, 0L, 1L, 0L, 0L), `c-low_cm1` = c(0L, 
0L, 0L, 1L, 1L)), class = "data.frame", row.names = c(NA, -5L
))
akrun answered Oct 18 '22 at 07:10



Here's an alternative that joins a lookup table of codes:

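library(dplyr)
library(tidyr) # pivot_*, unite
# lookup table of codes; its definition is not shown in the answer, so this one is assumed
fctrs <- data.frame(fctr = c("high", "med", "low"), fctrval = c(0, 1, 2))
# df1 is the question's data, as constructed in the data block above
df1 %>%
  pivot_longer(
    -c(sex, age, cost_cm),
    names_pattern = "([^.]+)-([a-z]+)_(.*)",
    names_to = c("ltr", "fctr", "key")
  ) %>%
  left_join(fctrs, by = "fctr") %>%
  mutate(value = value * fctrval) %>%
  unite("key", ltr, key) %>% 
  group_by(sex, age, cost_cm, key) %>%
  summarize(value = max(value)) %>%
  ungroup() %>%
  pivot_wider(id_cols = c(sex, age, cost_cm), names_from = "key", values_from = "value")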
# # A tibble: 5 x 5
#   sex   age   cost_cm c_cm1 d_cm1
#   <chr> <chr>   <int> <dbl> <dbl>
# 1 f     old         1     0     0
# 2 f     young       0     2     0
# 3 m     old         0     1     1
# 4 m     old         1     2     2
# 5 m     young       0     0     0
r2evans answered Oct 18 '22 at 06:10



Another possible solution:

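library(tidyverse)

# df1 is the question's data, as constructed in the first answer
df1 %>% 
  # high contributes 0 and low contributes 2; med columns are already 0/1
  mutate(across(contains("high"), ~ 0),
         across(contains("low"), ~ ifelse(.x == 1, 2, 0))) %>% 
  # sum each group's three columns (assumes d-* are columns 4:6 and c-* are 7:9)
  mutate(`d-cm1` = rowSums(.[,4:6]), `c-cm1` = rowSums(.[,7:9])) %>% 
  select(-(4:9))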

#>   sex   age cost_cm d-cm1 c-cm1
#> 1   f   old       1     0     0
#> 2   m young       0     0     0
#> 3   m   old       0     1     1
#> 4   f young       0     0     2
#> 5   m   old       1     2     2
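The positional indices `4:6` and `7:9` break if the column order changes. A sketch that selects the columns by name instead, assuming dplyr 1.0 or later (for `across()`) and the question's `d-`/`c-` prefixes:

```r
library(dplyr)

# The question's data (check.names = FALSE keeps the "-" in the names)
df1 <- data.frame(
  sex = c("f", "m", "m", "f", "m"),
  age = c("old", "young", "old", "young", "old"),
  cost_cm = c(1L, 0L, 0L, 0L, 1L),
  `d-high_cm1` = c(1L, 1L, 0L, 1L, 0L),
  `d-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `d-low_cm1` = c(0L, 0L, 0L, 0L, 1L),
  `c-high_cm1` = c(1L, 1L, 0L, 0L, 0L),
  `c-med_cm1` = c(0L, 0L, 1L, 0L, 0L),
  `c-low_cm1` = c(0L, 0L, 0L, 1L, 1L),
  check.names = FALSE
)

res <- df1 %>%
  mutate(
    # high contributes 0, med keeps its 0/1, low contributes 2
    across(contains("-high_"), ~ .x * 0L),
    across(contains("-low_"), ~ .x * 2L),
    # sum each group's columns, selected by prefix rather than position
    `d-cm1` = rowSums(across(starts_with("d-"))),
    `c-cm1` = rowSums(across(starts_with("c-")))
  ) %>%
  # drop the original one-hot columns by name
  select(-matches("-(high|med|low)_"))
res
```

Because `mutate()` evaluates its arguments in order, the `rowSums()` calls see the already re-weighted columns.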
PaulS answered Oct 18 '22 at 05:10
