
How do I select only a specific digit from the MNIST dataset provided by Keras?

I'm currently training a Feedforward Neural Network on the MNIST data set using Keras. I'm loading the data set using the format

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()

but then I only want to train my model on the digits 0 and 4, not all of them. How do I select only those two digits? I am fairly new to Python and can't figure out how to filter the MNIST dataset.

sherry asked Jul 06 '18 02:07



1 Answer

Y_train and Y_test hold the labels of the images, so you can use them with numpy.where to find the positions of the samples labelled 0 or 4. All your variables are numpy arrays, so you can simply do:

import numpy as np

train_filter = np.where((Y_train == 0) | (Y_train == 4))
test_filter = np.where((Y_test == 0) | (Y_test == 4))

You can then use these filters to index the arrays and keep only the matching samples:

X_train, Y_train = X_train[train_filter], Y_train[train_filter]
X_test, Y_test = X_test[test_filter], Y_test[test_filter]
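
To see what these filters actually contain, here is a small sketch on synthetic stand-in arrays (not the real MNIST data; the names `labels` and `images` are made up for illustration). With a single condition, np.where returns a tuple holding one array of positions, and indexing with that tuple selects along the first axis.

```python
import numpy as np

# Synthetic stand-ins for Y_train and X_train (hypothetical data, not MNIST).
labels = np.array([0, 1, 4, 7, 4, 0, 9])
images = np.arange(7 * 2 * 2).reshape(7, 2, 2)  # 7 tiny 2x2 "images"

# np.where with one condition returns a tuple containing one index array.
keep = np.where((labels == 0) | (labels == 4))
print(keep[0])  # positions of the samples labelled 0 or 4

# Indexing with the tuple selects those positions along the first axis.
subset_images = images[keep]
subset_labels = labels[keep]
print(subset_images.shape, subset_labels)
```

The image and label arrays stay aligned because both are indexed with the same filter.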

If you are interested in more than two labels, chaining conditions with | inside np.where quickly gets hard to read. In that case you can use numpy.isin to create boolean masks instead.

train_mask = np.isin(Y_train, [0, 4])
test_mask = np.isin(Y_test, [0, 4])

You can index the arrays with these masks directly (boolean indexing), in the same way as with the filters above.
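
Spelled out, applying the masks could look like the sketch below. It uses small random arrays in place of the real MNIST data so that it runs without downloading anything, and `filter_digits` is a hypothetical helper name, not part of NumPy or Keras.

```python
import numpy as np

def filter_digits(images, labels, digits):
    """Keep only the samples whose label is in `digits` (hypothetical helper)."""
    mask = np.isin(labels, digits)  # boolean array, True where the label matches
    return images[mask], labels[mask]

# Synthetic stand-ins shaped like MNIST (28x28 uint8 images, labels 0-9).
rng = np.random.default_rng(0)
X = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)
Y = rng.integers(0, 10, size=100)

X_sub, Y_sub = filter_digits(X, Y, [0, 4])
print(X_sub.shape, np.unique(Y_sub))
```

With the real data, the equivalent call would be `X_train, Y_train = filter_digits(X_train, Y_train, [0, 4])`, and likewise for the test set.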

umutto answered Nov 01 '22 13:11
