
Tidymodels: Creating an rsplit object from training and testing data

Tags: r, tidymodels

I’m trying to make the jump from Scikit-Learn to Tidymodels, and most of the time it has been relatively painless thanks to the tutorials from Julia Silge and Andrew Couch. However, now I’m stuck. Normally I would use initial_split(df, strata = x) to get a split object to work with. But this time I’ve been provided with the test and train sets by a different department, and I’m afraid this might become the norm. Without a split object, functions like last_fit() and collect_predictions() don’t work.

How can I reverse engineer the provided datasets so that they become rsplit objects? Or alternatively, is it possible to bind the datasets together first and then tell initial_split() exactly what rows should go to train and test?

I see that someone asked the same question at https://community.rstudio.com/t/tidymodels-creating-a-split-object-from-testing-and-training-data-perform-last-fit/69885. Max Kuhn said you could reverse engineer an rsplit object but I didn’t understand how. Thanks!

# Example data
library(tibble)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))
Economist asked Oct 14 '25 22:10



1 Answer

Reverse engineering the split object most likely just means looking at how an rsplit object is constructed. Depending on how the package implements it, this can be as simple as building an object with the same fields as the one returned by initial_split(). That is most likely the case here, so we would only need to recreate the object and make sure all of its fields are present.
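To see which fields such an object carries, one option is to build a reference split and inspect it. The sketch below uses a made-up data frame (`toy` is hypothetical and not part of the question) and only prints the structure, since the exact classes and fields can differ between rsample versions.

```r
library(rsample)
library(tibble)

set.seed(123)

# Hypothetical toy data, used only to obtain a split object to inspect
toy <- tibble(x = 1:10, y = rnorm(10))

ref_split <- initial_split(toy, prop = 0.8)

class(ref_split)       # the class vector includes "rsplit"
names(ref_split)       # list fields, among them "data" and "in_id"
str(ref_split$in_id)   # row indices that belong to the analysis (training) set
```

The `in_id` field holds the row positions of the training rows inside `data`, which is the information a hand-built split has to supply.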

The simplest approach, however, is probably to combine the two data frames and pass the row indices to make_splits() to recreate the original train/test split:

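library(rsample)
library(dplyr)

# Stack the supplied sets so that the training rows come first
combined <- bind_rows(train, test)
# Row positions of each set within the combined data
ind <- list(
  analysis = seq_len(nrow(train)),
  assessment = nrow(train) + seq_len(nrow(test))
)
splits <- make_splits(ind, combined)
splits
#> <Analysis/Assess/Total>
#> <8/4/12>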
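As a sanity check, training() and testing() can be used to confirm that the rebuilt split returns the rows that were supplied. The sketch below repeats the example data so that it runs on its own. It also shows an alternative that, as far as I know, is available in recent rsample releases: a data-frame method of make_splits() that takes the two sets directly. Treat that second call as an assumption about the installed version.

```r
library(rsample)
library(dplyr)
library(tibble)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

# Index-based construction, as in the answer above
combined <- bind_rows(train, test)
ind <- list(
  analysis = seq_len(nrow(train)),
  assessment = nrow(train) + seq_len(nrow(test))
)
splits <- make_splits(ind, combined)

# The split should hand back the rows that were supplied
identical(training(splits)$feature_1, train$feature_1)
identical(testing(splits)$feature_1, test$feature_1)

# Alternative (assumes an rsample version with the data-frame method):
# pass the training set as `x` and the test set as `assessment`
splits_df <- make_splits(train, assessment = test)
nrow(training(splits_df))
nrow(testing(splits_df))
```

Both objects should describe the same 8/4 partition, so either can be used wherever an rsplit is expected.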
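Since the question is about last_fit() and collect_predictions(), here is a minimal sketch of how a hand-built split might be passed on to them. The model choice (a logistic regression with parsnip's default engine), the formula, and the conversion of `predictor` to a factor are my assumptions for illustration, not something given in the question.

```r
library(tidymodels)

train <- tibble(predictor = c(0, 1, 1, 1, 0, 1, 0, 0),
                feature_1 = c(12, 18, 15, 5, 20, 2, 6, 10),
                feature_2 = c(120, 98, 111, 67, 335, 123, 22, 69))

test <- tibble(predictor = c(0, 1, 0, 1),
               feature_1 = c(5, 13, 8, 9),
               feature_2 = c(132, 105, 99, 112))

# Assumption: `predictor` is a binary class label, so it is stored as a factor
combined <- bind_rows(train, test) %>%
  mutate(predictor = factor(predictor))

ind <- list(
  analysis = seq_len(nrow(train)),
  assessment = nrow(train) + seq_len(nrow(test))
)
splits <- make_splits(ind, combined)

# Hypothetical workflow: logistic regression on both features
wf <- workflow() %>%
  add_formula(predictor ~ feature_1 + feature_2) %>%
  add_model(logistic_reg())

# Fit on the analysis rows and evaluate on the assessment rows
final_fit <- last_fit(wf, splits)

collect_metrics(final_fit)
collect_predictions(final_fit)
```

With only eight training rows the fitted model is not meaningful; the point is only that the rebuilt split slots into the usual workflow.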
Oliver answered Oct 17 '25 12:10
