
Dataset API, Iterators and tf.contrib.data.rejection_resample

[Edit #1 after @mrry comment] I am using the (great & amazing) Dataset API along with tf.contrib.data.rejection_resample to set a specific target distribution on the training input pipeline.

Before adding tf.contrib.data.rejection_resample to the input_fn I used a one-shot iterator. Once the resampler was in place, I switched to dataset.make_initializable_iterator() - this is because the resampler introduces stateful variables into the pipeline, and the iterator must be initialized AFTER all variables in the input pipeline are initialized, as @mrry wrote here.

I am passing the input_fn to an Estimator, wrapped by an Experiment.

Problem is - where to hook the init of the iterator? If I try:

dataset = dataset.batch(batch_size)
if self.balance:
    dataset = tf.contrib.data.rejection_resample(dataset, self.class_mapping_function, self.dist_target)
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
    iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
print(image_batch)

and the mapping function:

def class_mapping_function(self, feature, label):
    """
    A function to be used with dataset.map() to return a numeric class ID.
    It maps a nested structure of tensors (having shapes and types defined by
    dataset.output_shapes and dataset.output_types) to a scalar tf.int32
    tensor. Values should be in [0, num_classes).
    """
    # For simplicity, return the label itself, as I assume it is numeric...
    return tf.cast(label, tf.int32)  # <-- I guess this is the bug

the iterator does not receive the tensor shape it gets with the one-shot iterator.

For example, with the one-shot iterator, the tensor gets the correct shape:

Tensor("train_input_fn/IteratorGetNext:0", shape=(?, 100, 100, 3), dtype=float32, device=/device:CPU:0)

But when using the initializable iterator, the tensor shape info is missing:

Tensor("train_input_fn/IteratorGetNext:0", shape=(?,), dtype=int32, device=/device:CPU:0)

Any help would be much appreciated!

[Edit #2 - following @mrry's comment that it seems like another dataset] Perhaps the real issue here is not the init sequence of the iterator, but the mapping function used by tf.contrib.data.rejection_resample, which returns tf.int32. But then how should the mapping function be defined? To keep the dataset shape as (?,100,100,3), for example...

[Edit #3]: From the implementation of rejection_resample

class_values_ds = dataset.map(class_func)

So it makes sense that class_func is mapped over the dataset's elements, producing a dataset of tf.int32 class values.
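
To make that concrete, here is a minimal sketch (my reading of the TF 1.3 implementation, assuming a dataset of (image, label) pairs like mine):

# class_func is applied per element by dataset.map(), so for a dataset of
# (image, label) pairs it receives a single pair and must return a scalar
# tf.int32 class ID in [0, num_classes)
def class_func(image, label):
    return tf.cast(label, tf.int32)

class_values_ds = dataset.map(class_func)  # a Dataset of scalar tf.int32 values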

asked Dec 13 '22 by Shahar Karny

1 Answer

Following @mrry's response, I came up with a solution for using the Dataset API with tf.contrib.data.rejection_resample (using TF 1.3).

The goal

Given a feature/label dataset with some distribution, have the input pipeline reshape that distribution into a specific target distribution.

Numerical example

Let's assume we are building a network to classify some feature into one of 10 classes, and that we only have 100 features with some random distribution of labels:
30 features labeled as class 1, 5 features labeled as class 2, and so forth. During training we do not want to prefer class 1 over class 2, so we would like each mini-batch to hold a uniform distribution over all classes.
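
For example, the uniform target for this 10-class case is just a length-10 vector of 0.1 (the variable name is mine):

import numpy as np

# Uniform target over 10 classes: each class should form 10% of every mini-batch
target_distribution = np.full(10, 0.1, dtype=np.float32)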

The solution

Using tf.contrib.data.rejection_resample allows us to set a specific distribution for our input pipeline.

According to the documentation, tf.contrib.data.rejection_resample takes:

(1) dataset - the dataset you want to balance

(2) class_func - a function that generates a new dataset of numeric class labels from the original dataset

(3) target_dist - a vector the size of the number of classes, specifying the required new distribution

(4) some more optional values - skipped for now

and, as the documentation says, it returns a Dataset.

It turns out that the element structure of the returned Dataset differs from that of the input Dataset: each element is prefixed with its class value. As a consequence, the returned Dataset (as implemented in TF 1.3) should be mapped by the user back to the original structure, like this:

    balanced_dataset = tf.contrib.data.rejection_resample(input_dataset,
                                                          self.class_mapping_function,
                                                          self.target_distribution)

    # Return to the same Dataset structure as the original input
    balanced_dataset = balanced_dataset.map(lambda _, data: (data))
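
Putting the pieces together, here is a minimal sketch of the resampling portion of an input_fn under the same assumptions (a dataset of (image, label) elements built with the TF 1.3 tf.contrib.data API; images, labels, batch_size and target_distribution stand in for your own data):

    # images: float32 array of shape [N, 100, 100, 3]; labels: int array of shape [N]
    dataset = tf.contrib.data.Dataset.from_tensor_slices((images, labels))

    balanced_dataset = tf.contrib.data.rejection_resample(
        dataset,
        class_func=lambda image, label: tf.cast(label, tf.int32),
        target_dist=target_distribution)

    # rejection_resample prepends the class value to each element;
    # drop it to restore the original (image, label) structure
    balanced_dataset = balanced_dataset.map(lambda _, data: data)
    balanced_dataset = balanced_dataset.batch(batch_size)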

One note on the iterator kind. As @mrry explained here, when using stateful objects within the pipeline one should use the initializable iterator and not the one-shot. Note that when using the initializable iterator you should add its init op to the TABLE_INITIALIZERS collection, or you will receive this error: "GetNext() failed because the iterator has not been initialized."

Code example:

# Create the iterator that allows access to elements of the dataset
if self.use_balancing:
    # The balancing function uses stateful variables, in the sense that they hold the
    # current dataset distribution and calculate the next distribution according to
    # incoming examples. For dataset pipelines that have state, a one-shot iterator
    # will not work, and we are forced to use an initializable iterator.
    # This should be relaxed in the future.
    # https://stackoverflow.com/questions/44374083/tensorflow-cannot-capture-a-stateful-node-by-value-in-tf-contrib-data-api
    iterator = dataset.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
else:
    iterator = dataset.make_one_shot_iterator()

image_batch, label_batch = iterator.get_next()
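
To sanity-check this outside of the Estimator/Experiment machinery, a minimal sketch in a plain session (my own test code, not part of the pipeline above) could be:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # tf.tables_initializer() groups every op in the TABLE_INITIALIZERS collection,
    # which now includes the iterator.initializer we added above
    sess.run(tf.tables_initializer())
    images, labels = sess.run([image_batch, label_batch])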

Does it work?

Yes. Here are two images from TensorBoard after collecting a histogram of the input pipeline labels. The original input labels were uniformly distributed.

Scenario A: Trying to achieve the following 10-class distribution: [0.1,0.4,0.05,0.05,0.05,0.05,0.05,0.05,0.1,0.1]

And the result:

[TensorBoard histogram of the resampled label distribution, Scenario A]

Scenario B: Trying to achieve the following 10-class distribution: [0.1,0.1,0.05,0.05,0.05,0.05,0.05,0.05,0.4,0.1]

And the result:

[TensorBoard histogram of the resampled label distribution, Scenario B]

answered May 17 '23 by Shahar Karny