 

Partition training data by class in NumPy

I have a 50000 x 784 data matrix (50000 samples and 784 features) and the corresponding 50000 x 1 class vector (classes are integers 0-9). I'm looking for an efficient way to group the data matrix into 10 data matrices and class vectors that each have only the data for a particular class 0-9.

I can't seem to find an elegant way to do this, aside from just looping through the data matrix and constructing the 10 other matrices that way.

Does anyone know if there is a clean way to do this with something in scipy, numpy, or sklearn?
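For reference, a common baseline that avoids looping over the rows in Python is one boolean mask per class. This sketch is not from the question or the answer. The helper name split_by_mask is hypothetical, and it assumes the labels are the integers 0 to num_classes - 1.

```python
import numpy as np

def split_by_mask(data, classes, num_classes=10):
    """Hypothetical helper: one boolean mask per class, labels assumed to be 0..num_classes-1."""
    classes = np.asarray(classes).ravel()  # accept a (n,) or (n, 1) label vector
    data_groups = []
    class_groups = []
    for k in range(num_classes):
        mask = classes == k  # True for the rows that belong to class k
        data_groups.append(data[mask])
        class_groups.append(classes[mask])
    return data_groups, class_groups

# Small demo: 100 samples, 4 features, labels 0-9.
rng = np.random.default_rng(0)
data = rng.random((100, 4))
classes = rng.integers(10, size=(100, 1))  # shaped like the question's n x 1 class vector
data_groups, class_groups = split_by_mask(data, classes)
```

This loops over the 10 classes rather than the 50000 rows, but every mask scans the whole label vector, so the cost grows with the number of classes. Unlike a sort-and-split approach, it returns an empty array for a class that has no samples.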

Ryan asked Mar 14 '23 01:03


1 Answer

Probably the cleanest way to do this in NumPy, especially if you have many classes, is by sorting:

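import numpy as np

SAMPLES = 50000
FEATURES = 784
CLASSES = 10
data = np.random.rand(SAMPLES, FEATURES)
classes = np.random.randint(CLASSES, size=SAMPLES)  # 1-D labels; flatten a (SAMPLES, 1) array with .ravel()

# Sort the row indices by class so rows of the same class become contiguous.
sorter = np.argsort(classes)
classes_sorted = classes[sorter]
# Positions where the sorted label changes mark the boundaries between classes.
splitter, = np.where(classes_sorted[:-1] != classes_sorted[1:])
# Split the reordered data just after each boundary: one array per class present.
data_splitted = np.split(data[sorter], splitter + 1)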
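The same sort-and-split idea can also return the matching class vectors the question asks for, plus the label of each group. This is a sketch, not the answerer's code: the helper name split_by_class is hypothetical, and it swaps the np.where boundary search for np.unique(..., return_counts=True) followed by np.cumsum, which yields the same split positions.

```python
import numpy as np

def split_by_class(data, classes):
    """Hypothetical helper: group the rows of data by integer label via sort-and-split."""
    classes = np.asarray(classes).ravel()        # accept a (n,) or (n, 1) label vector
    sorter = np.argsort(classes, kind="stable")  # stable: rows keep their order within a class
    classes_sorted = classes[sorter]
    # Labels present (ascending) and the number of rows for each.
    labels, counts = np.unique(classes_sorted, return_counts=True)
    # Running totals of the counts, minus the last, are the split positions.
    boundaries = np.cumsum(counts)[:-1]
    data_groups = np.split(data[sorter], boundaries)
    class_groups = np.split(classes_sorted, boundaries)
    return labels, data_groups, class_groups

# Small demo (sizes reduced from 50000 x 784 to keep it quick).
rng = np.random.default_rng(1)
data = rng.random((1000, 20))
classes = rng.integers(10, size=1000)
labels, data_groups, class_groups = split_by_class(data, classes)
```

Because labels is returned alongside the groups, a class with no samples is easy to detect: it simply does not appear in labels.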

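data_splitted will be a list of arrays, one for each class present in classes, in ascending class order (a class with no samples gets no array, so the list can be shorter than CLASSES). Running the above code with SAMPLES = 10, FEATURES = 2 and CLASSES = 3, I get: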

>>> data
array([[ 0.45813694,  0.47942962],
       [ 0.96587082,  0.73260743],
       [ 0.70539842,  0.76376921],
       [ 0.01031978,  0.93660231],
       [ 0.45434223,  0.03778273],
       [ 0.01985781,  0.04272293],
       [ 0.93026735,  0.40216376],
       [ 0.39089845,  0.01891637],
       [ 0.70937483,  0.16077439],
       [ 0.45383099,  0.82074859]])

>>> classes
array([1, 1, 2, 1, 1, 2, 0, 2, 0, 1])

>>> data_splitted 
[array([[ 0.93026735,  0.40216376],
        [ 0.70937483,  0.16077439]]),
 array([[ 0.45813694,  0.47942962],
        [ 0.96587082,  0.73260743],
        [ 0.01031978,  0.93660231],
        [ 0.45434223,  0.03778273],
        [ 0.45383099,  0.82074859]]),
 array([[ 0.70539842,  0.76376921],
        [ 0.01985781,  0.04272293],
        [ 0.39089845,  0.01891637]])]
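The output above can be replayed on the printed 10 x 2 data and labels. One assumption: kind="stable" is passed to np.argsort so the row order inside each group is guaranteed to match the printed result; with the default sort kind, that order is not guaranteed.

```python
import numpy as np

# The 10 x 2 data and labels printed above (SAMPLES = 10, FEATURES = 2, CLASSES = 3).
data = np.array([[0.45813694, 0.47942962],
                 [0.96587082, 0.73260743],
                 [0.70539842, 0.76376921],
                 [0.01031978, 0.93660231],
                 [0.45434223, 0.03778273],
                 [0.01985781, 0.04272293],
                 [0.93026735, 0.40216376],
                 [0.39089845, 0.01891637],
                 [0.70937483, 0.16077439],
                 [0.45383099, 0.82074859]])
classes = np.array([1, 1, 2, 1, 1, 2, 0, 2, 0, 1])

sorter = np.argsort(classes, kind="stable")
classes_sorted = classes[sorter]
splitter, = np.where(classes_sorted[:-1] != classes_sorted[1:])
data_splitted = np.split(data[sorter], splitter + 1)

print([group.shape[0] for group in data_splitted])  # [2, 5, 3]
```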

If you want to make sure the sort is stable, i.e. that data points in the same class keep their original relative order after sorting, specify sorter = np.argsort(classes, kind='mergesort') (or kind='stable' in NumPy 1.15 and later). The default kind='quicksort' is not guaranteed to be stable.
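A quick way to see what stability means here: after a stable argsort, the original row indices inside each run of equal labels are strictly increasing. This sketch uses random labels of its own, and kind="stable" requires NumPy 1.15 or later.

```python
import numpy as np

rng = np.random.default_rng(42)
classes = rng.integers(3, size=20)

stable = np.argsort(classes, kind="stable")
merge = np.argsort(classes, kind="mergesort")

# A stable sort orders ties by original index, so the permutation is unique
# and both kinds must agree.
print(np.array_equal(stable, merge))  # True

# Within each run of equal labels, the original row indices increase.
classes_sorted = classes[stable]
same_run = classes_sorted[:-1] == classes_sorted[1:]
print(np.all(stable[:-1][same_run] < stable[1:][same_run]))  # True
```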

Jaime answered Mar 20 '23 01:03