
How can I filter tf.data.Dataset by specific values?

I create a dataset by reading TFRecords and map the values. I then want to filter the dataset for specific values, but because each element is a dict of tensors, I can't get the actual value of a tensor or check it with tf.cond() / tf.equal(). How can I do that?

import tensorflow as tf

def mapping_func(serialized_example):
    feature = { 'label': tf.FixedLenFeature([1], tf.string) }
    features = tf.parse_single_example(serialized_example, features=feature)
    return features

def filter_func(features):
    # this doesn't work
    #result = features['label'] == 'some_label_value'
    # neither does this
    result = tf.reshape(tf.equal(features['label'], 'some_label_value'), [])
    return result

def main():
    file_names = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
    dataset = tf.contrib.data.TFRecordDataset(file_names)
    dataset = dataset.map(mapping_func)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.filter(filter_func)
    dataset = dataset.repeat()
    iterator = dataset.make_one_shot_iterator()
    sample = iterator.get_next()
tsveti_iko asked Feb 16 '18 11:02



2 Answers

I am answering my own question. I found the issue!

What I needed to do is tf.unstack() the label like this:

label = tf.unstack(features['label'])
label = label[0]

before I give it to tf.equal():

result = tf.reshape(tf.equal(label, 'some_label_value'), [])

I suppose the problem was that the label is declared with tf.FixedLenFeature([1], tf.string), i.e. as an array with a single string element. To get that one element I had to unstack it (which produces a list) and then take the element at index 0. Correct me if I'm wrong.
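To make the shape difference concrete, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (not the 1.x setup from the question), and label_vec is a hypothetical stand-in for features['label'] as parsed with a [1]-shaped feature. Indexing with [0] is shown as what I believe is an equivalent alternative to tf.unstack().

```python
import tensorflow as tf

# Hypothetical stand-in for features['label'] parsed with
# FixedLenFeature([1], tf.string): a string tensor of shape (1,).
label_vec = tf.constant([b"some_label_value"])

# Comparing the whole vector gives a boolean tensor of shape (1,).
vector_result = tf.equal(label_vec, "some_label_value")

# Approach from the answer: unstack into a list of scalars, take element 0.
label = tf.unstack(label_vec)[0]
scalar_result = tf.equal(label, "some_label_value")

# Alternative (assumption, not from the answer): index the tensor directly.
indexed_result = tf.equal(label_vec[0], "some_label_value")

print(vector_result.shape, scalar_result.shape, indexed_result.shape)
```

Dataset.filter() expects the predicate to return a scalar boolean, which is what the last two comparisons produce.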

tsveti_iko answered Sep 17 '22 14:09


I think you don't need to make label a 1-dimensional array in the first place.

With:

feature = {'label': tf.FixedLenFeature((), tf.string)}

the label is parsed as a scalar, so you won't need to unstack it in your filter_func.
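For reference, a small end-to-end sketch of this suggestion. It assumes TensorFlow 2.x, so it uses the tf.io.* names instead of the 1.x names in the question, and it builds a few records in memory instead of reading .tfrecord files. make_example is a hypothetical helper introduced only for this illustration.

```python
import tensorflow as tf

def make_example(label):
    # Hypothetical helper: serialize one tf.train.Example with a string label.
    feature = {
        "label": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[label.encode("utf-8")])
        )
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

def mapping_func(serialized_example):
    # Shape () makes 'label' a scalar string tensor.
    feature = {"label": tf.io.FixedLenFeature((), tf.string)}
    return tf.io.parse_single_example(serialized_example, features=feature)

def filter_func(features):
    # Scalar in, scalar boolean out: no unstack or reshape needed.
    return tf.equal(features["label"], "some_label_value")

# In-memory records stand in for the TFRecord files from the question.
serialized = [make_example(name) for name in
              ["some_label_value", "other_label", "some_label_value"]]

dataset = tf.data.Dataset.from_tensor_slices(serialized)
dataset = dataset.map(mapping_func).filter(filter_func)

kept = [element["label"].numpy().decode("utf-8") for element in dataset]
print(kept)
```

With real files, the from_tensor_slices call would presumably be replaced by tf.data.TFRecordDataset(file_names), leaving the map and filter steps unchanged.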

Vincent Setiawan answered Sep 20 '22 14:09