 

Getting x_test, y_test from generator in Keras?

In some cases the validation data can't be a generator. For example, the TensorBoard callback refuses one when histograms are enabled:

If printing histograms, validation_data must be provided, and cannot be a generator.

My current code looks like:

image_data_generator = ImageDataGenerator()

training_seq   = image_data_generator.flow_from_directory(training_dir)
validation_seq = image_data_generator.flow_from_directory(validation_dir)
testing_seq    = image_data_generator.flow_from_directory(testing_dir)

model = Sequential(..)
# ..
model.compile(..)
model.fit_generator(training_seq, validation_data=validation_seq, ..)

How do I turn validation_seq into arrays so I can pass validation_data=(x_test, y_test) instead?

asked Jan 03 '23 06:01 by A T


1 Answer

A solution that works on both Python 2.7 and Python 3:

from platform import python_version_tuple

if python_version_tuple()[0] == '3':
    xrange = range
    izip = zip
    imap = map
else:
    from itertools import izip, imap

import numpy as np

# ..
# other code as in question
# ..

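# Index the Sequence once per batch (len() gives the batch count, each item is an (x, y) tuple)
x, y = izip(*(validation_seq[i] for i in xrange(len(validation_seq))))
# Stack the per-batch arrays into single arrays along the sample axis
x_val, y_val = np.vstack(x), np.vstack(y)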
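To see the indexing pattern in isolation, here is a minimal sketch with no Keras dependency. ToySequence and drain_sequence are hypothetical names: ToySequence only mimics the two parts of the Keras Sequence contract the answer relies on (len() returns the number of batches, seq[i] returns one (x, y) batch, the last one possibly shorter). Indexing by position matters because calling next() on a Keras iterator cycles forever, so a plain loop over it would never stop. np.concatenate is used here; for batches with two or more dimensions it gives the same result as np.vstack.

```python
import numpy as np


class ToySequence:
    """Hypothetical stand-in for a Keras Sequence such as the iterator
    returned by flow_from_directory: len() gives the number of batches
    and seq[i] returns one (x_batch, y_batch) tuple."""

    def __init__(self, x, y, batch_size):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, i):
        start = i * self.batch_size
        stop = start + self.batch_size
        return self.x[start:stop], self.y[start:stop]


def drain_sequence(seq):
    """Visit every batch exactly once by index and join them along axis 0."""
    xs, ys = zip(*(seq[i] for i in range(len(seq))))
    return np.concatenate(xs), np.concatenate(ys)


# 10 fake 4x4 RGB images with one-hot labels for 3 classes, batch size 4
images = np.zeros((10, 4, 4, 3), dtype="float32")
labels = np.eye(3)[np.arange(10) % 3]
x_val, y_val = drain_sequence(ToySequence(images, labels, batch_size=4))
print(x_val.shape, y_val.shape)  # (10, 4, 4, 3) (10, 3)
```

The three batches (4, 4 and 2 samples) come back as one array each, with images and labels still aligned because every batch returns them together.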

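Or, to also support class_mode='binary' (where each label batch is 1-D, so np.vstack cannot stack the batches correctly), join the batches with np.concatenate, which works for every class_mode:

x_val = np.concatenate(x)
# shape (n,) for 'binary', (n, num_classes) for 'categorical'
y_val = np.concatenate(y)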
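A small numpy check of how 1-D label batches behave, assuming they look like what class_mode='binary' yields (0/1 floats, last batch shorter). The batch values are made up for illustration, and np.eye(2)[labels] stands in for a one-hot encoding such as to_categorical produces. It shows that np.concatenate joins the batches end to end, that np.vstack fails on unequal batch sizes, and that column 0 of a one-hot encoding marks class 0, so it equals 1 - label rather than the label itself.

```python
import numpy as np

# Two label batches shaped like class_mode='binary' output: 1-D, last one shorter
binary_batches = (np.array([0.0, 1.0, 1.0, 0.0]), np.array([1.0, 0.0]))

# concatenate joins 1-D batches end to end
y_binary = np.concatenate(binary_batches)
print(y_binary)  # [0. 1. 1. 0. 1. 0.]

# vstack treats each 1-D batch as a row, so unequal batch sizes fail
try:
    np.vstack(binary_batches)
except ValueError as err:
    print("vstack failed:", err)

# One-hot encode: column 0 marks class 0, column 1 marks class 1
one_hot = np.eye(2)[y_binary.astype(int)]
print(one_hot[:, 0])  # [1. 0. 0. 1. 0. 1.]
```

So if labels are one-hot encoded first, column 1 (not column 0) recovers the original binary label.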

Full runnable code: https://gist.github.com/AlecTaylor/7f6cc03ed6c3dd84548a039e2e0fd006

answered Jan 13 '23 12:01 by A T