Using rejection_resample() with the Dataset API

I am having a hard time producing class-balanced batches with rejection_resample() and the Dataset API. My inputs are images and integer labels, as you can see in the code, but rejection_resample() does not seem to work as expected.

Note: I am using TensorFlow v1.3.

Here I define the dataset, its initial class distribution, and the target distribution I want.

import tensorflow as tf
from collections import Counter

target_dist = [0.1, 0.0, 0.0, 0.0, 0.9]
initial_dist = [0.1061, 0.3213, 0.4238, 0.1203, 0.0282]

training_filenames = training_records  # TFRecord file name(s), defined elsewhere
training_dataset = tf.contrib.data.TFRecordDataset(training_filenames)
training_dataset = training_dataset.map(tf_record_parser)  # Parse the record into tensors.
training_dataset = training_dataset.repeat()  # Repeat indefinitely (no epoch count given)
training_dataset = training_dataset.shuffle(buffer_size=1000)

training_dataset = tf.contrib.data.rejection_resample(training_dataset,
                                                      class_func=lambda _, c: c,
                                                      target_dist=target_dist,
                                                      initial_dist=initial_dist)

# Drop the class key so elements are (image, label) again, as in the original input
training_dataset = training_dataset.map(lambda _, data: data)

training_dataset = training_dataset.batch(64)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.contrib.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
batch_images, batch_labels = iterator.get_next()
training_iterator = training_dataset.make_initializable_iterator()
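To see what the call above should do with these two distributions, here is a minimal pure-Python sketch of the acceptance rule that rejection resampling is built on: each class is kept with probability proportional to target/initial, rescaled so the largest ratio becomes 1. I am assuming this matches TF 1.3's internal computation; the helper name `acceptance_probs` is my own and not a TensorFlow API.

```python
def acceptance_probs(initial_dist, target_dist):
    """Per-class keep probability for rejection resampling.

    A class is kept with probability proportional to target / initial,
    rescaled so the class with the largest ratio is always kept.
    Classes with zero initial probability get a ratio of 0.
    """
    ratios = [t / i if i > 0 else 0.0
              for t, i in zip(target_dist, initial_dist)]
    max_ratio = max(ratios)
    return [r / max_ratio for r in ratios]


target_dist = [0.1, 0.0, 0.0, 0.0, 0.9]
initial_dist = [0.1061, 0.3213, 0.4238, 0.1203, 0.0282]

probs = acceptance_probs(initial_dist, target_dist)
for cls, p in enumerate(probs):
    print("class %d: keep with probability %.4f" % (cls, p))
# Classes 1-3 get 0.0 and class 4 gets 1.0, so only 0 and 4 can survive.
```

Class 0 is kept only about 3% of the time, because its target/initial ratio is far smaller than class 4's; that is what pulls the output toward the 10% / 90% split.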

When I run my pipeline, I should only get samples from classes 0 and 4, but I get samples from all classes, as though the resampling had no effect.

with tf.Session() as sess:
    training_handle = sess.run(training_iterator.string_handle())
    sess.run(training_iterator.initializer)
    batch_faces_np, batch_label_np = sess.run(
        [batch_images, batch_labels], feed_dict={handle: training_handle})

    ctr = Counter(batch_label_np)

Counter({2: 31, 3: 22, 4: 6, 1: 5})
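For comparison, here is a pure-Python simulation (no TensorFlow) of what one batch of 64 should look like if resampling works: labels are drawn from initial_dist, each is kept with probability (target/initial) divided by the maximum ratio, and the first 64 survivors form the batch. The simulated label stream and the helper names `simulated_labels` and `rejection_resample` here are assumptions for illustration, not the TensorFlow implementation.

```python
import itertools
import random
from collections import Counter

target_dist = [0.1, 0.0, 0.0, 0.0, 0.9]
initial_dist = [0.1061, 0.3213, 0.4238, 0.1203, 0.0282]


def simulated_labels(dist, rng):
    """Endless stream of labels drawn from `dist` (stand-in for the TFRecords)."""
    classes = list(range(len(dist)))
    while True:
        yield rng.choices(classes, weights=dist)[0]


def rejection_resample(labels, initial, target, rng):
    """Keep each label with probability (target / initial) / max ratio."""
    ratios = [t / i if i > 0 else 0.0 for t, i in zip(target, initial)]
    max_ratio = max(ratios)
    accept = [r / max_ratio for r in ratios]
    for label in labels:
        # rng.random() is in [0, 1), so classes with accept 0.0 never pass
        if rng.random() < accept[label]:
            yield label


rng = random.Random(0)
stream = rejection_resample(simulated_labels(initial_dist, rng),
                            initial_dist, target_dist, rng)
batch = list(itertools.islice(stream, 64))
print(Counter(batch))  # only keys 0 and 4 can appear
```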

I also tested an example based on the post "Dataset API, Iterators and tf.contrib.data.rejection_resample" and on the original test code from the TensorFlow repo, and that works:

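import numpy as np
import tensorflow as tf
from collections import Counter
# Internal TF 1.3 modules, as imported by the original unit test
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables

initial_known = True
classes = np.random.randint(5, size=(20000,))  # Uniformly sampled
target_dist = [0.5, 0.0, 0.0, 0.0, 0.4]
initial_dist = [0.2] * 5 if initial_known else None

iterator = dataset_ops.Iterator.from_dataset(
    dataset_ops.rejection_resample(
        (dataset_ops.Dataset.from_tensor_slices(classes)
         .shuffle(200, seed=21)
         .map(lambda c: (c, string_ops.as_string(c)))),
        target_dist=target_dist,
        initial_dist=initial_dist,
        class_func=lambda c, _: c,
        seed=27))
init_op = iterator.initializer
get_next = iterator.get_next()
variable_init_op = variables.global_variables_initializer()

with tf.Session() as sess:
    sess.run(variable_init_op)
    sess.run(init_op)
    returned = []
    # The dataset is finite (no repeat()), so read until it is exhausted.
    try:
        while True:
            returned.append(sess.run(get_next))
    except tf.errors.OutOfRangeError:
        pass

print(Counter(returned))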

Counter({(0, (0, b'0')): 3873, (4, (4, b'4')): 3286})
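The counts in this working example line up with the same acceptance rule: with a uniform initial_dist of 0.2, class 0 has ratio 0.5/0.2 = 2.5 and class 4 has 0.4/0.2 = 2.0, so class 0 is always kept and class 4 is kept 80% of the time. A quick check of the expected counts out of 20000 uniformly drawn labels (this arithmetic is a sketch of my own, not part of the original test):

```python
n = 20000
initial_dist = [0.2] * 5
target_dist = [0.5, 0.0, 0.0, 0.0, 0.4]

# Keep probability per class: (target / initial), rescaled by the max ratio
ratios = [t / i for t, i in zip(target_dist, initial_dist)]
accept = [r / max(ratios) for r in ratios]

# Expected number of survivors per class out of n uniformly drawn labels
expected = [n * i * a for i, a in zip(initial_dist, accept)]
for cls, count in enumerate(expected):
    print("class %d: ~%.0f samples" % (cls, count))
```

That gives about 4000 samples of class 0 and 3200 of class 4, roughly matching the observed 3873 and 3286.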

Can anyone help me with this? Thanks.

Thalles asked Nov 08 '17 11:11



1 Answer

Try setting a seed for shuffle(). It worked for me with a seed value.
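For reference, the seed here is the optional `seed` argument of `shuffle()`, as already used in the working example above (`.shuffle(200, seed=21)`); in the question's pipeline that would look like `training_dataset.shuffle(buffer_size=1000, seed=21)`. The pure-Python sketch below mimics a buffered shuffle (fill a buffer of `buffer_size` elements, emit a random one, refill from the input) to show that a fixed seed makes the order reproducible. The function `buffered_shuffle` is hypothetical, and this only illustrates what the seed controls; it does not explain why seeding would change the rejection_resample result.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in shuffled order using a fixed-size buffer.

    The buffer is filled with the first `buffer_size` items; each step
    emits a random buffered element and replaces it with the next input
    item. Once the input is exhausted, the remaining buffer is drained.
    """
    rng = random.Random(seed)
    it = iter(items)
    buffer = []
    for item in it:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            break
    for item in it:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)
    yield from buffer


data = list(range(20))
a = list(buffered_shuffle(data, buffer_size=5, seed=21))
b = list(buffered_shuffle(data, buffer_size=5, seed=21))
print(a == b)             # True: same seed, same order
print(sorted(a) == data)  # True: every element appears exactly once
```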

venkatesh-sg answered Oct 21 '22 21:10