Split a dataset created by the TensorFlow Dataset API into Train and Test?

People also ask

How do you split data into training and testing in Tensorflow?

You could also run scikit-learn's train_test_split twice: first split the data into (train + validation) and test, then split train + validation into two separate sets.
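A minimal sketch of the two-step split with scikit-learn, assuming NumPy arrays X and y (the toy data below is made up for illustration). The second call uses test_size=0.15 / 0.85 so that the validation set ends up at roughly 15% of the original data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (made up): 100 samples with 2 features each, one label per sample
X = np.arange(200).reshape(100, 2)
y = np.arange(100)

# Step 1: hold out 15% of the data as the test set
X_trval, X_test, y_trval, y_test = train_test_split(X, y, test_size=0.15, random_state=0)

# Step 2: split the remaining 85% into train and validation;
# 0.15 / 0.85 of the remainder is about 15% of the original data
X_train, X_val, y_train, y_val = train_test_split(
    X_trval, y_trval, test_size=0.15 / 0.85, random_state=0
)

print(len(X_train), len(X_val), len(X_test))
```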

How do you split a dataset in Tfds?

Splitting is possible by passing the split parameter to tfds.load, using slicing syntax such as split="test[:70%]". In the original example (a 70/30 split), training_set has 2569 entries, while validation_set has 1101.

How do you split data into training and testing in keras?

train_test_split() is a function in scikit-learn (sklearn.model_selection) that splits data into training and testing sets. It randomly splits the inputs X and y into train and test subsets; the test_size parameter controls the size of the test split (for example, test_size=0.2 gives an 80/20 split, while the default holds out 25%). Alternatively, you can set train_size instead.
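A small sketch comparing the default split with test_size=0.2, on made-up toy arrays:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (made up): 100 samples with 2 features each
X = np.arange(200).reshape(100, 2)
y = np.arange(100)

# Default: with test_size and train_size both unset, 25% goes to the test split
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

# test_size=0.2 gives an 80/20 split
X_tr2, X_te2, y_tr2, y_te2 = train_test_split(X, y, test_size=0.2, random_state=0)

print(len(X_tr), len(X_te))    # 75 25
print(len(X_tr2), len(X_te2))  # 80 20
```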


Assuming you have all_dataset variable of tf.data.Dataset type:

test_dataset = all_dataset.take(1000) 
train_dataset = all_dataset.skip(1000)

The test dataset now has the first 1000 elements, and the rest go to training.
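A quick way to check this behaviour, using tf.data.Dataset.range(5000) as a stand-in for all_dataset (an assumption; any dataset behaves the same way). take() and skip() follow the dataset's current order, so the split is only random if the data was shuffled beforehand:

```python
import tensorflow as tf

# Stand-in for all_dataset: the integers 0..4999
all_dataset = tf.data.Dataset.range(5000)

test_dataset = all_dataset.take(1000)   # first 1000 elements: 0..999
train_dataset = all_dataset.skip(1000)  # everything after them: 1000..4999

test_values = [int(v) for v in test_dataset.as_numpy_iterator()]
train_values = [int(v) for v in train_dataset.as_numpy_iterator()]
print(len(test_values), len(train_values))  # 1000 4000
```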


You may use Dataset.take() and Dataset.skip():

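import tensorflow as tf

# DATASET_SIZE and FLAGS.input_file come from your own setup
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size; disabling reshuffling (with a fixed seed) keeps
# the order identical on every iteration, so the take()/skip() splits stay disjoint
full_dataset = full_dataset.shuffle(DATASET_SIZE, seed=42, reshuffle_each_iteration=False)
train_dataset = full_dataset.take(train_size)
test_dataset = full_dataset.skip(train_size)
# Assumes val_size == test_size: val skips the first val_size, test takes the first test_size
val_dataset = test_dataset.skip(val_size)
test_dataset = test_dataset.take(test_size)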

For generality, I gave an example using a 70/15/15 train/val/test split, but if you don't need a test or a val set, just ignore the last two lines.
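A sketch of why the shuffle settings matter, on a made-up 100-element dataset (DATASET_SIZE is assumed known here) and using the take()/skip()-on-the-remainder pattern. With reshuffle_each_iteration=False and a fixed seed, every split iterates over the same permutation, so the three splits are disjoint and together cover the whole dataset; with the default reshuffling, each split may see a different order and the splits may overlap:

```python
import tensorflow as tf

DATASET_SIZE = 100  # assumed known for this toy example
full_dataset = tf.data.Dataset.range(DATASET_SIZE)

train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)

# Fixed seed + no reshuffling: every iterator over full_dataset sees the same order
full_dataset = full_dataset.shuffle(DATASET_SIZE, seed=42, reshuffle_each_iteration=False)

train_dataset = full_dataset.take(train_size)
remaining = full_dataset.skip(train_size)
val_dataset = remaining.take(val_size)
test_dataset = remaining.skip(val_size)

train = {int(v) for v in train_dataset.as_numpy_iterator()}
val = {int(v) for v in val_dataset.as_numpy_iterator()}
test = {int(v) for v in test_dataset.as_numpy_iterator()}

# Check: pairwise disjoint, and together they cover every element
print(train.isdisjoint(val), train.isdisjoint(test), val.isdisjoint(test))
print(len(train | val | test))
```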

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
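shard() on its own gives you only one slice, so here is a sketch of an interleaved split built around it, on a toy dataset. shard() has no "everything except this shard" option, so the complement is built with enumerate() and filter(); that part is a substitute technique, not something shard() provides:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
num_shards = 4

# shard() keeps the elements whose position % num_shards == index
val_dataset = dataset.shard(num_shards=num_shards, index=0)

# Complement of that shard: attach positions, drop shard 0, strip the positions again
train_dataset = (
    dataset.enumerate()
    .filter(lambda i, x: i % num_shards != 0)
    .map(lambda i, x: x)
)

val_values = [int(v) for v in val_dataset.as_numpy_iterator()]
train_values = [int(v) for v in train_dataset.as_numpy_iterator()]
print(val_values)    # [0, 4, 8]
print(train_values)  # [1, 2, 3, 5, 6, 7, 9]
```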


Disclaimer: I stumbled upon this question after answering this one, so I thought I'd spread the love.


Most of the answers here use take() and skip(), which require knowing the size of your dataset beforehand. This isn't always possible, or the size is difficult or expensive to determine.
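To illustrate the problem, a small sketch on a toy dataset: tf.data reports a dataset's size through cardinality() only when it can infer it statically, and transformations such as filter() make it unknown. The general fallback of counting with reduce() needs a full pass over the data:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
print(int(dataset.cardinality()))  # 10

# After filter(), tf.data can no longer infer the size without iterating
evens = dataset.filter(lambda x: x % 2 == 0)
print(bool(evens.cardinality() == tf.data.UNKNOWN_CARDINALITY))  # True

# Counting requires reading every element once
count = evens.reduce(tf.constant(0, tf.int64), lambda acc, _: acc + 1)
print(int(count))  # 5
```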

Instead, you can essentially slice the dataset up so that 1 of every N records becomes a validation record.

To accomplish this, let's start with a simple dataset of 0-9:

dataset = tf.data.Dataset.range(10)
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

For our example, we're going to slice it into a 3/1 train/validation split: 3 records go to training, then 1 record to validation, and then the pattern repeats.

split = 3
dataset_train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
# [0, 1, 2, 4, 5, 6, 8, 9]
dataset_validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
# [3, 7]

So the first dataset.window(split, split + 1) says to grab split (3) elements, then advance split + 1 (4) elements from the start of that window, and repeat. That + 1 effectively skips the 1 element we're going to use in our validation dataset.
The flat_map(lambda ds: ds) is needed because window() returns a dataset of windows, each of which is itself a Dataset, rather than the individual elements. So we flatten it back out.
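To see what window() produces before flattening, this sketch (same toy dataset and split = 3) converts each window to a Python list, then flattens:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
split = 3

# Each element yielded by window() is itself a tf.data.Dataset holding one window
windows = dataset.window(split, split + 1)
as_lists = [[int(v) for v in w.as_numpy_iterator()] for w in windows]
print(as_lists)  # [[0, 1, 2], [4, 5, 6], [8, 9]]

# flat_map(lambda ds: ds) concatenates those inner datasets back into one stream
flat = [int(v) for v in windows.flat_map(lambda ds: ds).as_numpy_iterator()]
print(flat)  # [0, 1, 2, 4, 5, 6, 8, 9]
```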

Then for the validation data we first skip(split), which skips over the split (3) elements grabbed by the first training window, so our iteration starts on the 4th element. The window(1, split + 1) then grabs 1 element, advances split + 1 (4) elements, and repeats.
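The two calls can be wrapped in a helper; interleaved_split is a hypothetical name, not a TensorFlow API. This sketch applies it to a filtered dataset whose size tf.data cannot infer, which is exactly the case where take() and skip() are awkward:

```python
import tensorflow as tf


def interleaved_split(dataset, split):
    """Hypothetical helper: `split` elements to train, then 1 to validation, repeating.

    Does not need the dataset size. Expects single-tensor (non-tuple) elements.
    """
    train = dataset.window(split, split + 1).flat_map(lambda ds: ds)
    validation = dataset.skip(split).window(1, split + 1).flat_map(lambda ds: ds)
    return train, validation


# Even numbers 0..18; after filter() tf.data cannot infer the size
evens = tf.data.Dataset.range(20).filter(lambda x: x % 2 == 0)

train, validation = interleaved_split(evens, split=3)
train_values = [int(v) for v in train.as_numpy_iterator()]
validation_values = [int(v) for v in validation.as_numpy_iterator()]
print(train_values)       # [0, 2, 4, 8, 10, 12, 16, 18]
print(validation_values)  # [6, 14]
```

With split = 3, roughly 1 in 4 elements lands in validation; a larger split gives a smaller validation share.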

 

Note on nested datasets:
The above example works well for simple datasets, but flat_map() will raise an error if the dataset elements are nested, e.g. (features, label) tuples, because window() then yields a tuple of datasets for each window. To address this, you can swap out flat_map(lambda ds: ds) for a version that handles both simple and nested datasets:

.flat_map(lambda *ds: ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds))
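A sketch of the nested version on (feature, label) pairs built from two toy range datasets (the data is made up). The flatten function is the same lambda as above, written as a named def:

```python
import tensorflow as tf

features = tf.data.Dataset.range(8)
labels = tf.data.Dataset.range(100, 108)
dataset = tf.data.Dataset.zip((features, labels))  # elements are (x, y) pairs

split = 3


def flatten(*ds):
    # One component arrives as a single Dataset; tuple components arrive as several
    return ds[0] if len(ds) == 1 else tf.data.Dataset.zip(ds)


train = dataset.window(split, split + 1).flat_map(flatten)
validation = dataset.skip(split).window(1, split + 1).flat_map(flatten)

train_pairs = [(int(x), int(y)) for x, y in train.as_numpy_iterator()]
validation_pairs = [(int(x), int(y)) for x, y in validation.as_numpy_iterator()]
print(train_pairs)       # [(0, 100), (1, 101), (2, 102), (4, 104), (5, 105), (6, 106)]
print(validation_pairs)  # [(3, 103), (7, 107)]
```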

@ted's answer can cause some overlap between the splits. Try this instead, which takes each split from what remains after the previous one:

# full_ds is a tf.data.Dataset with full_ds_size elements; this gives a 64/16/20 split
train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)
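If full_ds_size is not already known, cardinality() can supply it when tf.data is able to infer the size. In this sketch tf.data.Dataset.range(100) stands in for full_ds (an assumption), and it raises an error when the size is unknown, since this take()/skip() approach needs the size up front:

```python
import tensorflow as tf

full_ds = tf.data.Dataset.range(100)  # stand-in for your real dataset (assumption)

# cardinality() is a scalar tensor; negative values mean infinite (-1) or unknown (-2)
full_ds_size = int(full_ds.cardinality())
if full_ds_size < 0:
    raise ValueError("Dataset size is not known statically; count it first or use a size-free split")

train_ds_size = int(0.64 * full_ds_size)
valid_ds_size = int(0.16 * full_ds_size)

train_ds = full_ds.take(train_ds_size)
remaining = full_ds.skip(train_ds_size)
valid_ds = remaining.take(valid_ds_size)
test_ds = remaining.skip(valid_ds_size)

sizes = [len(list(ds.as_numpy_iterator())) for ds in (train_ds, valid_ds, test_ds)]
print(sizes)  # [64, 16, 20]
```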

Use the code below to test it:

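import tensorflow as tf

# Only needed in TF 1.x: tf.enable_eager_execution()
# Eager execution is already the default in TF 2.x

dataset = tf.data.Dataset.range(100)

train_size = 20
valid_size = 30
test_size = 50  # whatever remains after train and valid

train = dataset.take(train_size)
remaining = dataset.skip(train_size)
valid = remaining.take(valid_size)
test = remaining.skip(valid_size)

for i in train:
    print(i.numpy())  # 0 .. 19

for i in valid:
    print(i.numpy())  # 20 .. 49

for i in test:
    print(i.numpy())  # 50 .. 99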