Chunk tensorflow dataset records into multiple records

I have an unbatched tensorflow dataset that looks like this:

ds = ...
for record in ds.take(3):
    print('data shape={}'.format(record['data'].shape))

-> data shape=(512, 512, 87)
-> data shape=(512, 512, 277)
-> data shape=(512, 512, 133)

I want to feed the data to my network in chunks of depth 5. In the example above, the tensor of shape (512, 512, 87) would be divided into 17 tensors of shape (512, 512, 5). The final 2 slices along the last axis (tensor[:, :, 85:87]) should be discarded.

For example:

chunked_ds = ...
for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record['data'].shape))

-> chunked data shape=(512, 512, 5)

How can I get from ds to chunked_ds? tf.data.Dataset.window() looks like what I need, but I cannot get it working.

Ollie asked Oct 14 '22 21:10

1 Answer

This can actually be done using only tf.data.Dataset operations:

import tensorflow as tf

# Stand-in for the real dataset: 10 records of shape (512, 512, 87)
data = tf.random.normal(shape=[10, 512, 512, 87])
ds = tf.data.Dataset.from_tensor_slices(data)

chunk_size = 5
chunked_ds = ds.flat_map(
    lambda x: tf.data.Dataset.from_tensor_slices(tf.transpose(x, perm=[2, 0, 1]))
                             .batch(chunk_size, drop_remainder=True)
).map(lambda rec: tf.transpose(rec, perm=[1, 2, 0]))
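
A quick sanity check that this reproduces the shape from the question (note that the elements here are plain tensors rather than dicts, so we read record.shape directly):

for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record.shape))

-> chunked data shape=(512, 512, 5)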

What is going on here:

First, we treat each record as a separate Dataset and permute it so that the last dimension becomes the batch dimension (flat_map will flatten the inner datasets back into a single stream of Tensors):

.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(tf.transpose(x, perm=[2, 0, 1])
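
To see what this inner step produces on its own, here is a small sketch (the shapes follow the example record from the question):

x = tf.zeros([512, 512, 87])
inner = tf.data.Dataset.from_tensor_slices(tf.transpose(x, perm=[2, 0, 1]))
print(inner.element_spec)
# -> TensorSpec(shape=(512, 512), dtype=tf.float32, name=None)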

Then we batch it in groups of 5, discarding the remainder:

.batch(chunk_size, drop_remainder=True))
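
Continuing the sketch above: batching the 87 slices by 5 with drop_remainder=True yields 17 elements of shape (5, 512, 512), and the last 2 slices are dropped:

batched = inner.batch(5, drop_remainder=True)
print(batched.element_spec)
# -> TensorSpec(shape=(5, 512, 512), dtype=tf.float32, name=None)
print(batched.cardinality().numpy())
# -> 17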

Finally, we permute each tensor back so that the 512x512 dimensions come first:

.map(lambda rec: tf.transpose(rec, perm=[1, 2, 0]))
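
As an aside, since the question mentions tf.data.Dataset.window(): a sketch of an equivalent pipeline using window instead of batch is below. Each window is itself a small dataset, so it needs an extra flat_map/batch step to turn it back into a single tensor, which is likely the part that makes window harder to get working:

chunk_size = 5
chunked_ds = ds.flat_map(
    lambda x: tf.data.Dataset.from_tensor_slices(tf.transpose(x, perm=[2, 0, 1]))
                             .window(chunk_size, drop_remainder=True)
                             .flat_map(lambda w: w.batch(chunk_size))
).map(lambda rec: tf.transpose(rec, perm=[1, 2, 0]))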
Proko answered Oct 20 '22 16:10