Split a data directory into training and test directories with the subdirectory structure preserved

I am interested in using ImageDataGenerator in Keras for data augmentation. However, it requires that the training and validation directories, each containing one subdirectory per class, be fed in separately, as in the example below from the Keras documentation. I have a single directory with two subdirectories for the two classes (Data/Class1 and Data/Class2). How do I randomly split this into training and validation directories?

    train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

    validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

    model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

I am interested in re-running my algorithm multiple times with random training and validation data splits.

asked Oct 12 '17 by Sharanya Arcot Desai



1 Answer

Thank you guys! I was able to write my own function to create training and test data sets. Here's the code for anyone who's looking.

import os
import shutil
import numpy as np

source1 = "/source_dir"
dest11 = "/dest_dir"

# Move roughly 20% of the files, chosen at random, into the test directory
files = os.listdir(source1)
for f in files:
    if np.random.rand(1) < 0.2:
        shutil.move(source1 + '/' + f, dest11 + '/' + f)
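
The snippet above moves about 20% of the files out of a single flat directory. For the layout in the question (Data/Class1, Data/Class2), a minimal sketch of the same idea that preserves the class subdirectories and can be re-run for fresh splits (the function name, paths, split fraction, copy-instead-of-move choice, and seed handling are illustrative assumptions, not part of the accepted code):

    import os
    import random
    import shutil

    def split_dataset(data_dir, train_dir, val_dir, val_fraction=0.2, seed=None):
        """Randomly copy files from data_dir/<class>/ into train_dir/<class>/ and
        val_dir/<class>/, preserving the class subdirectory structure."""
        rng = random.Random(seed)  # fixed seed -> same split; None -> new split each run
        for class_name in os.listdir(data_dir):
            class_path = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_path):
                continue
            os.makedirs(os.path.join(train_dir, class_name), exist_ok=True)
            os.makedirs(os.path.join(val_dir, class_name), exist_ok=True)
            for fname in os.listdir(class_path):
                target = val_dir if rng.random() < val_fraction else train_dir
                shutil.copy2(os.path.join(class_path, fname),
                             os.path.join(target, class_name, fname))

    # For example:
    # split_dataset('Data', 'data/train', 'data/validation', val_fraction=0.2, seed=42)

Copying rather than moving leaves the original Data directory intact, so a new random split can be generated for each run of the algorithm by changing (or omitting) the seed.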
answered Oct 02 '22 by Sharanya Arcot Desai