How to change batch size dynamically in Tensorflow 2.0 Dataset?

In TensorFlow 1.x you could change the batch size dynamically using a placeholder, e.g.:

dataset.batch(batch_size=tf.placeholder())
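For reference, the TF1.x pattern looked roughly like this (a minimal reconstruction on my part, not the original linked example): the batch size is a placeholder that is fed each time the iterator is initialized.

import tensorflow as tf  # TF 1.x

# Sketch only: feed the batch size when (re)initializing the iterator.
batch_size = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(100).batch(batch_size)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={batch_size: 5})
    print(sess.run(next_batch))   # batch of 5 elements
    sess.run(iterator.initializer, feed_dict={batch_size: 10})
    print(sess.run(next_batch))   # dataset restarts and now yields 10 elements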

How do you do it in TensorFlow 2.0?

I have tried the following, but it doesn't work:

import numpy as np
import tensorflow as tf


def new_gen_function():
    for i in range(100):
        yield np.ones(2).astype(np.float32)


batch_size = tf.Variable(5, trainable=False, dtype=tf.int64)
train_ds = tf.data.Dataset.from_generator(new_gen_function, output_types=(tf.float32)).batch(
    batch_size=batch_size)

for data in train_ds:
    print(data.shape[0])   # actual batch size produced by the dataset
    batch_size.assign(10)  # attempt to change the batch size mid-iteration
    print(batch_size)

Output

5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
<tf.Variable 'Variable:0' shape=() dtype=int64, numpy=10>
5
...
...

I am training a model with a custom training loop using tf.GradientTape. How can I achieve this?

asked Dec 17 '19 by Himaprasoon



2 Answers

I don't think you can do it the way you used to in TF1.

A workaround could be to build each batch yourself by stacking individual samples:

import tensorflow as tf

ds = tf.data.Dataset.range(10).repeat()
iterator = iter(ds)
for batch_size in range(1, 10):
  batch = tf.stack([iterator.next() for _ in range(batch_size)], axis=0)
  print(batch)

# tf.Tensor([0], shape=(1,), dtype=int64)
# tf.Tensor([1 2], shape=(2,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7 8 9], shape=(4,), dtype=int64)
# tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
# tf.Tensor([5 6 7 8 9 0], shape=(6,), dtype=int64)
# tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int64)
# tf.Tensor([8 9 0 1 2 3 4 5], shape=(8,), dtype=int64)
# tf.Tensor([6 7 8 9 0 1 2 3 4], shape=(9,), dtype=int64)
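For a custom training loop with GradientTape (the asker's setting), the same stacking idea could look roughly like this; the model, loss, and batch-size schedule below are hypothetical placeholders, not part of the original answer.

import tensorflow as tf

# Hypothetical setup: a tiny model and optimizer, purely for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.SGD(0.01)

ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([100, 2]), tf.random.normal([100, 1]))
).repeat()
iterator = iter(ds)

def next_batch(batch_size):
    # Stack `batch_size` individual samples into one batch, as in the answer above.
    xs, ys = zip(*[next(iterator) for _ in range(batch_size)])
    return tf.stack(xs, axis=0), tf.stack(ys, axis=0)

batch_size = 5
for step in range(10):
    x, y = next_batch(batch_size)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x) - y))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if step == 4:
        batch_size *= 2   # hypothetical schedule: double the batch size mid-training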
answered Sep 29 '22 by P-Gn

As far as I know, you have to instantiate a new dataset iterator for your change to take effect. This requires a small tweak to skip the samples that have already been seen.

Here is my simplest solution:

import numpy as np
import tensorflow as tf

def get_dataset(batch_size, num_samples_seen):
    # Rebuild the pipeline: skip the samples already consumed,
    # then batch with the current batch size.
    return tf.data.Dataset.range(
        100
    ).skip(
        num_samples_seen
    ).batch(
        batch_size=batch_size
    )

def main():
    batch_size = 1
    num_samples_seen = 0

    train_ds = get_dataset(batch_size, num_samples_seen)

    ds_iterator = iter(train_ds)
    while True:
        try:
            data = next(ds_iterator)
        except StopIteration:
            print("End of iteration")
            break

        print(data)
        batch_size *= 2
        num_samples_seen += data.shape[0]
        ds_iterator = iter(get_dataset(batch_size, num_samples_seen))
        print("New batch size:", batch_size)

if __name__ == "__main__":
    main()

As you can see here, you have to instantiate a new dataset (through a call to get_dataset) and update the iterator.

I don't know the performance impact of such a solution. Maybe there is another solution that only requires re-instantiating the batching step instead of the whole dataset, as sketched below.
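A hedged sketch of that idea, purely as an assumption on my part: keep a single base dataset object and only re-apply the skip/batch stages when the batch size changes. Whether this saves anything in practice is untested.

import tensorflow as tf

base_ds = tf.data.Dataset.range(100)  # built once and reused

def rebatch(num_samples_seen, batch_size):
    # Only the skip/batch stages are rebuilt; base_ds itself is reused.
    return iter(base_ds.skip(num_samples_seen).batch(batch_size))

it = rebatch(0, 4)
print(next(it))      # tf.Tensor([0 1 2 3], shape=(4,), dtype=int64)
it = rebatch(4, 8)   # switch to a batch size of 8 after 4 samples
print(next(it))      # tf.Tensor([ 4  5  6  7  8  9 10 11], shape=(8,), dtype=int64)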

answered Sep 29 '22 by AlexisBRENON