R: classifying a text variable

I have a data frame with a string variable representing diagnoses of diseases. I want to classify the diagnoses according to some rules:

rules <- list(
  group1 = c('A01', 'A02', 'A03'),
  group2 = c('B01', 'B02')
)

If the diagnosis field contains 'A01', 'A02' or 'A03' (including longer codes that contain these patterns, e.g. 'A0199'), then the case should be put in group 1, and so on.

My data looks like this:

dat <- data.frame(
  ID = seq_len(10),
  diagnosis = c('A012', 'A02', 'B23', 'C43', 'B023', 'A99', 'A023', 'B012', 'B04', 'A07')
)

The method I've used looks like this:

# modify the rules so they work with grep
rules <- lapply(rules, paste, collapse = '|')

# create a function that classifies an individual diagnosis
group <- function(y) {
  a <- sapply(rules, grepl, x = y)
  a <- names(a)[a]
  return(if (length(a) == 0) NA else a)
}

# apply the function across the data frame
dat$group <- sapply(dat$diagnosis, group)

This seems to work, but my dataset is big and there are lots of rules, and it is extremely slow!

Are there faster ways I could do this?

asked Apr 29 '26 17:04 by Dan Lewer


1 Answer

If the number of rules isn't too big (OP says it's only 40), we could just loop over the rules and perform a fixed (literal) substring match using stringi::stri_detect_fixed, which is much faster than regex matching.
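Base R's grepl() also has a fixed = TRUE argument that performs the same literal substring test without the regex engine. A small check, using a few codes in the question's format (including its 'A0199' example), that the two calls agree:

```r
library(stringi)

codes <- c("A012", "A02", "B23", "A0199")

# Literal (non-regex) substring tests: TRUE wherever "A01" occurs in the code
via_stringi <- stri_detect_fixed(codes, "A01")
via_base    <- grepl("A01", codes, fixed = TRUE)

print(via_stringi)
# [1]  TRUE FALSE FALSE  TRUE
identical(via_stringi, via_base)
# [1] TRUE
```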

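First, flatten the rules into one vector of patterns plus a parallel vector of group numbers. This needs the original rules list, not the version the question pastes together with '|', because a fixed match would treat '|' as a literal character.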

rules_dt <- list(rules = unlist(rules, use.names = FALSE), 
                 grp = rep(seq_len(length(rules)), lengths(rules))) 
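To make the flattened structure concrete, a small check assuming the question's rules as described in its prose (group1 = A01, A02, A03; group2 = B01, B02). It uses seq_along(rules), which is equivalent to seq_len(length(rules)):

```r
rules <- list(
  group1 = c('A01', 'A02', 'A03'),
  group2 = c('B01', 'B02')
)

# One pattern per element, plus the index of the group each pattern came from
rules_dt <- list(rules = unlist(rules, use.names = FALSE),
                 grp = rep(seq_along(rules), lengths(rules)))

str(rules_dt)
```

Each pattern is paired with its group: rules_dt$rules is c("A01", "A02", "A03", "B01", "B02") and rules_dt$grp is c(1, 1, 1, 2, 2).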

Then, define a function that applies the x-th rule to dat:

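library(stringi)
# For rule x: find the rows whose diagnosis contains the pattern (literal match)
# and write that rule's group number into the global dat via <<-
f <- function(x) {
  dat[stri_detect_fixed(dat$diagnosis, rules_dt$rules[x]), "group"] <<- rules_dt$grp[x]
}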

Then, run it once per rule. If a diagnosis matches rules from several groups, the last matching rule wins:

invisible(lapply(seq_len(length(rules_dt[[1]])), f))
dat
#    ID diagnosis group
# 1   1      A012     1
# 2   2       A02     1
# 3   3       B23    NA
# 4   4       C43    NA
# 5   5      B023     2
# 6   6       A99    NA
# 7   7      A023     1
# 8   8      B012     2
# 9   9       B04    NA
# 10 10       A07    NA
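For comparison, a minimal sketch of the same per-rule loop written as a function that returns the classification instead of modifying the global dat with <<-. Assumptions: the helper name classify and its fixed argument are hypothetical; it swaps in base R's grepl(fixed = TRUE) for stringi::stri_detect_fixed so it needs no package, returns group names rather than numbers, and with fixed = FALSE treats each rule as a regular expression. The rules follow the question's prose (A01, A02, A03).

```r
# Hypothetical helper (not from the answer): returns one group name per
# element of x instead of assigning into a global data frame with <<-.
classify <- function(x, rules, fixed = TRUE) {
  x <- as.character(x)                         # works for factor columns too
  patterns <- unlist(rules, use.names = FALSE) # one pattern per element
  groups <- rep(names(rules), lengths(rules))  # group name for each pattern
  out <- rep(NA_character_, length(x))
  # One vectorised grepl() call over the whole column per pattern.
  # As with f() above, a later matching rule overwrites an earlier one.
  for (i in seq_along(patterns)) {
    out[grepl(patterns[i], x, fixed = fixed)] <- groups[i]
  }
  out
}

rules <- list(group1 = c('A01', 'A02', 'A03'), group2 = c('B01', 'B02'))
dat <- data.frame(
  ID = seq_len(10),
  diagnosis = c('A012', 'A02', 'B23', 'C43', 'B023', 'A99', 'A023', 'B012', 'B04', 'A07'),
  stringsAsFactors = FALSE
)
dat$group <- classify(dat$diagnosis, rules)
```

Looping over patterns rather than rows keeps the number of grepl() calls equal to the number of rules, whereas the question's sapply() over dat$diagnosis calls grepl() once per row and group.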

Benchmark: on 0.5 million rows and 10 groups of 10 rules each, it runs in about 4 seconds on my laptop:

library(stringi)
n <- 10
N <- 5e5

set.seed(123)
rules <- setNames(replicate(n, 
                  stri_rand_strings(n = n, length = 4), simplify = FALSE), 
                  paste0("group", 1:n))

dat <- data.frame(
  ID = 1:N,
  diagnosis = stri_rand_strings(N, 4),
  stringsAsFactors = FALSE
)

# f() is the function defined above; it writes the group column into the global dat
system.time({
  rules_dt <- list(rules = unlist(rules, use.names = FALSE), 
                   grp = rep(seq_len(length(rules)), lengths(rules))) 
  invisible(lapply(seq_len(length(rules_dt[[1]])), f))
})

# user  system elapsed 
# 3.27    0.43    3.70
answered May 02 '26 07:05 by David Arenburg