Scikit-learn train_test_split with indices

How do I get the original indices of the data when using train_test_split()?

What I have is the following:

from sklearn.model_selection import train_test_split
import numpy as np

data = np.reshape(np.random.randn(20), (10, 2))  # 10 training examples
labels = np.random.randint(2, size=10)  # 10 labels
x1, x2, y1, y2 = train_test_split(data, labels, test_size=0.2)

But this does not give the indices of the original data. One workaround is to attach the indices to the data (e.g. data = [(i, d) for i, d in enumerate(data)]), pass the result to train_test_split, and then unpack it again afterwards. Is there a cleaner solution?

CentAu asked Jul 20 '15 16:07


1 Answer

You can use pandas DataFrames or Series, as Julien said, but if you want to restrict yourself to NumPy you can pass an additional array of indices, which train_test_split shuffles and splits in the same way as the data and labels:

from sklearn.model_selection import train_test_split
import numpy as np

n_samples, n_features, n_classes = 10, 2, 2
data = np.random.randn(n_samples, n_features)  # 10 training examples
labels = np.random.randint(n_classes, size=n_samples)  # 10 labels
indices = np.arange(n_samples)
(
    data_train,
    data_test,
    labels_train,
    labels_test,
    indices_train,
    indices_test,
) = train_test_split(data, labels, indices, test_size=0.2)
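
For the pandas route mentioned above, a minimal sketch might look like the following. It relies on train_test_split keeping the index of a DataFrame or Series on each returned piece, so the original row labels can be read back from .index. The column names f0 and f1, the series name label, and the fixed seeds are illustrative assumptions, not part of the question.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

n_samples, n_features, n_classes = 10, 2, 2
rng = np.random.default_rng(0)  # fixed seed, only for reproducibility

# Hypothetical column names; the default RangeIndex 0..9 acts as the
# original row index.
data = pd.DataFrame(
    rng.standard_normal((n_samples, n_features)), columns=["f0", "f1"]
)
labels = pd.Series(rng.integers(n_classes, size=n_samples), name="label")

data_train, data_test, labels_train, labels_test = train_test_split(
    data, labels, test_size=0.2, random_state=0
)

# Each split carries the index labels of the rows it was taken from.
indices_train = data_train.index.to_numpy()
indices_test = data_test.index.to_numpy()

# Looking those labels up in the original objects reproduces the split.
assert data.loc[indices_test].equals(data_test)
assert labels.loc[indices_train].equals(labels_train)
```

The same check works for the NumPy version: data[indices_test] should match data_test row for row, since all arrays passed in one call are split with the same permutation.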
ogrisel answered Sep 25 '22 17:09