Group a TensorFlow dataset by key and batch by key

I'm working on a problem in TensorFlow where I need to produce batches in which every element has the same key value. If possible, I'd like to use the Dataset API. Is this possible?

filter and map operate on individual elements, whereas I need a way to group elements by key. I've come across tf.data.experimental.group_by_window and tf.data.experimental.group_by_reducer, which seem promising, but I haven't been able to work out a solution.

It might be best to give an example:

dataset:

feature,label
1,word1
2,word2
3,word3
1,word1
3,word3
1,word1
1,word1
2,word2
3,word3
1,word1
3,word3
1,word1
1,word1

Grouping by the "feature" column as the key, with a maximum batch size of 3, should give these batches:

batch1
[[1,word1],
 [1,word1],
 [1,word1]]
batch2
[[1,word1],
 [1,word1],
 [1,word1]]
batch3
[[1,word1]]
batch4
[[2,word2],
 [2,word2]]
batch5
[[3,word3],
 [3,word3],
 [3,word3]]
batch6
[[3,word3]]

EDIT: the order of the batches is not important, despite the order shown in the example.

asked Mar 19 '19 at 15:03 by Trevor Stewart


1 Answer

I think this does the transformation you want:

import random

import tensorflow as tf  # written for TensorFlow 2.x (eager execution)

random.seed(100)
# Input data
label = list(range(15))
# Shuffle data
random.shuffle(label)
# Make feature from label data (labels 0-4 -> 0, 5-9 -> 1, 10-14 -> 2)
feature = [lbl // 5 for lbl in label]
batch_size = 3

print('Data:')
print(*zip(feature, label), sep='\n')

# Make dataset from data arrays
ds = tf.data.Dataset.from_tensor_slices({'feature': feature, 'label': label})
# Group by window
ds = ds.apply(tf.data.experimental.group_by_window(
    # Use feature as key (key_func must return a scalar tf.int64)
    key_func=lambda elem: tf.cast(elem['feature'], tf.int64),
    # Convert each window (all elements share one key) to a batch
    reduce_func=lambda _, window: window.batch(batch_size),
    # Use batch size as window size
    window_size=batch_size))
# Show dataset contents; the loop ends when the dataset is exhausted
print('Result:')
for batch in ds.as_numpy_iterator():
    print(batch)

Output:

Data:
(2, 11)
(1, 8)
(2, 12)
(0, 3)
(1, 9)
(0, 0)
(0, 4)
(0, 1)
(2, 10)
(1, 5)
(1, 6)
(2, 14)
(2, 13)
(1, 7)
(0, 2)
Result:
{'feature': array([0, 0, 0]), 'label': array([3, 0, 4])}
{'feature': array([2, 2, 2]), 'label': array([11, 12, 10])}
{'feature': array([1, 1, 1]), 'label': array([8, 9, 5])}
{'feature': array([0, 0]), 'label': array([1, 2])}
{'feature': array([1, 1]), 'label': array([6, 7])}
{'feature': array([2, 2]), 'label': array([14, 13])}
answered Oct 03 '22 at 11:10 by jdehesa
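To check the grouping semantics without TensorFlow, the behaviour seen in the output above can be emulated in plain Python. This is a rough sketch, not TensorFlow's implementation: group_by_window_sketch is a hypothetical helper, and flushing the leftover partial windows in ascending key order is an assumption chosen only because it matches the printed output, not documented behaviour. It keeps one buffer per key, emits a buffer as a batch as soon as it holds window_size elements, and flushes whatever is left once the input is exhausted.

```python
def group_by_window_sketch(elements, key_func, window_size):
    """Hypothetical plain-Python emulation of grouping elements by key into windows.

    Yields (key, window) pairs: a window is emitted as soon as it holds
    `window_size` elements; partial windows are flushed at end of input
    in ascending key order (an assumption matching the output above).
    """
    buffers = {}
    for elem in elements:
        key = key_func(elem)
        buffer = buffers.setdefault(key, [])
        buffer.append(elem)
        if len(buffer) == window_size:
            yield key, buffer
            buffers[key] = []  # start a fresh window for this key
    # End of input: flush the leftover partial windows, skipping empty ones.
    for key in sorted(buffers):
        if buffers[key]:
            yield key, buffers[key]


# (feature, label) pairs copied from the "Data:" printout above.
data = [(2, 11), (1, 8), (2, 12), (0, 3), (1, 9), (0, 0), (0, 4), (0, 1),
        (2, 10), (1, 5), (1, 6), (2, 14), (2, 13), (1, 7), (0, 2)]
answer_batches = [
    [label for _, label in window]
    for _, window in group_by_window_sketch(data, lambda e: e[0], window_size=3)
]
print(answer_batches)
# [[3, 0, 4], [11, 12, 10], [8, 9, 5], [1, 2], [6, 7], [14, 13]]

# The question's example rows, keyed on the first column.
question_rows = [(1, 'word1'), (2, 'word2'), (3, 'word3'), (1, 'word1'),
                 (3, 'word3'), (1, 'word1'), (1, 'word1'), (2, 'word2'),
                 (3, 'word3'), (1, 'word1'), (3, 'word3'), (1, 'word1'),
                 (1, 'word1')]
question_sizes = [
    (key, len(window))
    for key, window in group_by_window_sketch(question_rows, lambda r: r[0], window_size=3)
]
print(question_sizes)
# [(1, 3), (3, 3), (1, 3), (1, 1), (2, 2), (3, 1)]
```

Applied to the question's example rows, the same logic yields batches of 3, 3 and 1 rows for key 1, one batch of 2 for key 2, and batches of 3 and 1 for key 3: the grouping the question asks for, in a different batch order, which the question says is acceptable.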