How do I split Tensorflow datasets?

I have a tensorflow dataset based on one .tfrecord file. How do I split the dataset into test and train datasets? E.g. 70% Train and 30% test?

Edit:

My TensorFlow version is 1.8. I've checked, and there is no "split_v" function, which the possible duplicate suggests. Also, I am working with a tfrecord file.

Lukas Hestermeyer asked Jul 01 '18 17:07


2 Answers

You may use Dataset.take() and Dataset.skip():

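```python
import tensorflow as tf

# DATASET_SIZE (total number of records) and FLAGS.input_file must be defined elsewhere.
train_size = int(0.7 * DATASET_SIZE)
val_size = int(0.15 * DATASET_SIZE)
test_size = int(0.15 * DATASET_SIZE)

full_dataset = tf.data.TFRecordDataset(FLAGS.input_file)
# shuffle() requires a buffer size (DATASET_SIZE gives a full shuffle, at a memory cost).
# reshuffle_each_iteration=False keeps one fixed order, so take()/skip() never overlap.
full_dataset = full_dataset.shuffle(DATASET_SIZE, reshuffle_each_iteration=False)

train_dataset = full_dataset.take(train_size)  # first 70%
test_dataset = full_dataset.skip(train_size)   # remaining 30%
val_dataset = test_dataset.skip(test_size)     # last 15%
test_dataset = test_dataset.take(test_size)    # 15% right after the training part
```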

For generality, I used a 70/15/15 train/val/test split. If you only need a train/test split, just ignore the last 2 lines: test_dataset then holds the remaining 30%.
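For the plain 70/30 split asked about in the question, this reduces to one take() and one skip(). Here is a minimal TensorFlow 2 (eager) sketch. It assumes the dataset size is known in advance and uses a small in-memory range as a stand-in for the TFRecord file; both the toy data and the seed are illustrative choices.

```python
import tensorflow as tf

DATASET_SIZE = 10  # assumed known in advance (e.g. counted once)

# Toy stand-in for tf.data.TFRecordDataset(...): the integers 0..9.
full_dataset = tf.data.Dataset.range(DATASET_SIZE)
# Shuffle once with a fixed order (seed for reproducibility), so that take()
# and skip() below see the same permutation and the splits cannot overlap.
full_dataset = full_dataset.shuffle(DATASET_SIZE, seed=42,
                                    reshuffle_each_iteration=False)

train_size = DATASET_SIZE * 7 // 10            # 70% -> 7 elements
train_dataset = full_dataset.take(train_size)  # first 70%
test_dataset = full_dataset.skip(train_size)   # remaining 30%

train_items = [int(x) for x in train_dataset]
test_items = [int(x) for x in test_dataset]
print(len(train_items), len(test_items))  # 7 3
```

Integer arithmetic for train_size sidesteps any float rounding when computing the cut point.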

Take:

Creates a Dataset with at most count elements from this dataset.

Skip:

Creates a Dataset that skips count elements from this dataset.
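The "at most count" wording matters at the edges. Here is a quick TensorFlow 2 (eager) check on a toy five-element dataset, which is purely illustrative:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)  # elements 0, 1, 2, 3, 4

first_three = [int(x) for x in ds.take(3)]  # [0, 1, 2]
after_three = [int(x) for x in ds.skip(3)]  # [3, 4]
# Asking take() for more elements than exist just returns all of them.
too_many = [int(x) for x in ds.take(100)]   # [0, 1, 2, 3, 4]
# Skipping past the end yields an empty dataset rather than an error.
nothing = [int(x) for x in ds.skip(100)]    # []
```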

You may also want to look into Dataset.shard():

Creates a Dataset that includes only 1/num_shards of this dataset.
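shard(num_shards, index) returns only one slice. To get the rest, you can concatenate the other shards. Here is a sketch of a 75/25 split in TensorFlow 2 (eager) on toy, unshuffled data:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)  # toy data: 0, 1, ..., 9

# shard(num_shards, index) keeps the elements whose position p satisfies
# p % num_shards == index, i.e. 1/num_shards of the data.
test_dataset = ds.shard(num_shards=4, index=0)  # 0, 4, 8

# The remaining 3/4 is the union of the other shards.
train_dataset = ds.shard(num_shards=4, index=1)
for index in range(2, 4):
    train_dataset = train_dataset.concatenate(ds.shard(num_shards=4, index=index))

test_items = [int(x) for x in test_dataset]
train_items = [int(x) for x in train_dataset]
print(test_items)   # [0, 4, 8]
print(train_items)  # [1, 5, 9, 2, 6, 3, 7]
```

Each shard() call reads the input again, so the training set here costs three passes. The elements also come out grouped by shard. If the input is shuffled, it needs reshuffle_each_iteration=False for the shards to stay disjoint.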

ted answered Sep 22 '22 20:09


This question is similar to this one and this one, and I am afraid we have not had a satisfactory answer yet.

  • Using take() and skip() requires knowing the dataset size. What if I don't know that, or don't want to find out?

  • Using shard() only gives 1 / num_shards of the dataset. What if I want the rest?
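If the size is unknown, TensorFlow often cannot report it statically. A TFRecordDataset typically reports unknown cardinality, as does any pipeline containing filter(). One fallback is to count the elements with an extra full pass over the data. Here is a TensorFlow 2 (eager) sketch on a toy filtered range:

```python
import tensorflow as tf

# Toy pipeline whose length is not known statically: filter() drops an
# unknown number of elements before the data is actually read.
ds = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)

cardinality = int(tf.data.experimental.cardinality(ds))
print(cardinality == tf.data.experimental.UNKNOWN_CARDINALITY)  # True

# Fall back to one full pass over the data, adding 1 per element.
size = int(ds.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1))
print(size)  # 5
```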

Below I present what I think is a better solution, tested on TensorFlow 2 only. Assuming you already have a shuffled dataset, you can use filter() to split it into two:

```python
import tensorflow as tf

all = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, reshuffle_each_iteration=False)

test_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 == 0) \
                    .map(lambda x,y: y)

train_dataset = all.enumerate() \
                    .filter(lambda x,y: x % 4 != 0) \
                    .map(lambda x,y: y)

for i in test_dataset:
    print(i)

print()

for i in train_dataset:
    print(i)
```

The parameter reshuffle_each_iteration=False is important. It makes sure the original dataset is shuffled once and no more. Otherwise, the two resulting sets may overlap.
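Here is a quick TensorFlow 2 (eager) check of what reshuffle_each_iteration=False provides. The toy data and seed are illustrative choices. Two separate passes over the shuffled dataset return the same permutation, which is what lets two filter() pipelines built on it partition it consistently. With the default reshuffling, each pass draws a fresh permutation instead.

```python
import tensorflow as tf

data = tf.data.Dataset.range(20)  # toy data: 0..19

# Shuffle once; the seed only makes the permutation reproducible from run to run.
fixed = data.shuffle(20, seed=1, reshuffle_each_iteration=False)

# Each list comprehension creates a new iterator, i.e. a new pass over `fixed`.
first_pass = [int(x) for x in fixed]
second_pass = [int(x) for x in fixed]
print(first_pass == second_pass)  # True
```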

Use enumerate() to add an index.

Use filter(lambda x,y: x % 4 == 0) to take 1 sample out of 4. Likewise, x % 4 != 0 takes 3 out of 4.

Use map(lambda x,y: y) to strip the index and recover the original sample.

This example achieves a 75/25 split.

x % 5 == 0 and x % 5 != 0 give an 80/20 split.

If you really want a 70/30 split, x % 10 < 3 and x % 10 >= 3 should do.
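Putting the 70/30 variant together, here is a minimal TensorFlow 2 sketch on the same toy 1..20 data. The dataset variable is renamed all_data to avoid shadowing Python's built-in all(), and the seed is an illustrative choice.

```python
import tensorflow as tf

all_data = tf.data.Dataset.from_tensor_slices(list(range(1, 21))) \
        .shuffle(10, seed=7, reshuffle_each_iteration=False)

def is_test(i, sample):
    # Positions 0, 1, 2 out of every block of 10 go to the test set (30%).
    return i % 10 < 3

def is_train(i, sample):
    return i % 10 >= 3

def recover(i, sample):
    # Drop the index added by enumerate() and keep the original sample.
    return sample

test_dataset = all_data.enumerate().filter(is_test).map(recover)
train_dataset = all_data.enumerate().filter(is_train).map(recover)

test_items = [int(x) for x in test_dataset]
train_items = [int(x) for x in train_dataset]
print(len(test_items), len(train_items))  # 6 14
```

The split is exact only when the dataset size is a multiple of 10. Otherwise it is approximately 70/30.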

UPDATE:

As of TensorFlow 2.0.0, the code above may produce some warnings due to AutoGraph limitations. To eliminate those warnings, declare the functions separately instead of writing the lambdas inline:

```python
def is_test(x, y):
    return x % 4 == 0

def is_train(x, y):
    return not is_test(x, y)

recover = lambda x,y: y

test_dataset = all.enumerate() \
                    .filter(is_test) \
                    .map(recover)

train_dataset = all.enumerate() \
                    .filter(is_train) \
                    .map(recover)
```

This gives no warnings on my machine. Defining is_train() as not is_test() is also good practice: the two filters are then guaranteed to be complementary, so every sample lands in exactly one of the two sets.
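The complement property can be checked directly. Here is a TensorFlow 2 (eager) sketch on toy data. It writes the complement with tf.logical_not rather than Python's not, an illustrative choice that makes the TensorFlow op explicit instead of relying on AutoGraph to convert not.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(20).shuffle(20, seed=3, reshuffle_each_iteration=False)

def is_test(i, sample):
    return i % 4 == 0

def is_train(i, sample):
    # Defined as the complement of is_test, so every index lands in exactly one split.
    return tf.logical_not(is_test(i, sample))

def recover(i, sample):
    return sample

test_items = [int(x) for x in ds.enumerate().filter(is_test).map(recover)]
train_items = [int(x) for x in ds.enumerate().filter(is_train).map(recover)]

# Sanity checks: no sample in both splits, and no sample lost.
assert set(test_items).isdisjoint(train_items)
assert sorted(test_items + train_items) == list(range(20))
print(len(test_items), len(train_items))  # 5 15
```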

Nick Lee answered Sep 23 '22 20:09