 

How can I remove or omit data using map method for tf.data.Dataset objects?

I am using TensorFlow 2.3.0.

I have a Python data generator:

import tensorflow as tf
import numpy as np

vocab = [1,2,3,4,5]

def create_generator():
    'generates a random number from 0 to len(vocab)-1'
    count = 0
    while count < 4:
        x = np.random.randint(0, len(vocab))
        yield x
        count +=1

I wrap it in a tf.data.Dataset object:

gen = tf.data.Dataset.from_generator(create_generator,
                                     args=[],
                                     output_types=tf.int32,
                                     output_shapes=())
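To check what this dataset actually yields before filtering anything, you can iterate it eagerly. This is a minimal sketch assuming TF 2.x eager execution; `as_numpy_iterator()` returns plain NumPy values instead of tensors, and the values are converted to Python `int` only to make them easy to print.

```python
import numpy as np
import tensorflow as tf

vocab = [1, 2, 3, 4, 5]

def create_generator():
    """Yield 4 random integers from 0 to len(vocab) - 1, as in the question."""
    count = 0
    while count < 4:
        yield np.random.randint(0, len(vocab))
        count += 1

gen = tf.data.Dataset.from_generator(
    create_generator, output_types=tf.int32, output_shapes=()
)

# Iterate eagerly; each element is a scalar int32 converted to a Python int.
values = [int(v) for v in gen.as_numpy_iterator()]
print(values)  # four random integers, each between 0 and 4
```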

Now I want to subsample the items using the map method, so that the dataset never outputs an even number.

def subsample(x):
    'remove the item if it is an even number, e.g. 2 or 4'

    '''
    #TODO
    '''
    return x

gen = gen.map(subsample)

How can I achieve this using map method?

n0obcoder asked Oct 30 '25 11:10



1 Answer

In short, no: you cannot filter data using map. map applies a transformation to every element of the dataset and produces exactly one output element per input element. What you want is to test every element against a predicate and keep only the elements that satisfy it.

That is exactly what filter() does.
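The difference is easy to see on a small deterministic dataset. In this sketch `tf.data.Dataset.range(1, 6)` stands in for the generator (it is illustrative input, not from the question): map keeps all five elements and only changes their values, while filter drops the ones that fail the predicate.

```python
import tensorflow as tf

def to_list(ds):
    """Collect a dataset of scalars into a list of Python ints."""
    return [int(v) for v in ds.as_numpy_iterator()]

ds = tf.data.Dataset.range(1, 6)  # 1, 2, 3, 4, 5 (int64)

# map: one output element per input element, so nothing can be dropped.
mapped_values = to_list(ds.map(lambda x: x * 10))
print(mapped_values)  # [10, 20, 30, 40, 50]

# filter: keeps only the elements for which the predicate returns True.
filtered_values = to_list(ds.filter(lambda x: x % 2 != 0))
print(filtered_values)  # [1, 3, 5]
```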

So you can do:

gen = gen.filter(lambda x: x % 2 != 0)
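Applied to the generator from the question, the whole pipeline looks roughly like this. It is a sketch assuming TF 2.x eager execution; `output_types`/`output_shapes` are the arguments TF 2.3 uses, while newer releases prefer `output_signature`. Because the generator is random, the filtered dataset may yield anywhere from zero to four elements.

```python
import numpy as np
import tensorflow as tf

vocab = [1, 2, 3, 4, 5]

def create_generator():
    """Yield 4 random integers from 0 to len(vocab) - 1, as in the question."""
    count = 0
    while count < 4:
        yield np.random.randint(0, len(vocab))
        count += 1

gen = tf.data.Dataset.from_generator(
    create_generator, output_types=tf.int32, output_shapes=()
)

# Keep only odd values; the even values 0, 2 and 4 are dropped.
gen = gen.filter(lambda x: x % 2 != 0)

odd_values = [int(v) for v in gen.as_numpy_iterator()]
print(odd_values)  # up to four values, each either 1 or 3
```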

Update:

If you want to use a custom function instead of a lambda, you can do something like:

def filter_func(x):
    # Return a scalar boolean tensor: True keeps the element, False drops it.
    return x**2 < 500

gen = gen.filter(filter_func)

Passed to filter, this function keeps only the numbers whose square is less than 500.
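To see exactly which elements survive, you can run the same predicate on a deterministic dataset. This sketch uses `tf.data.Dataset.range(30)` as illustrative input (not from the original answer) and writes the predicate as a single comparison, which already produces the boolean tensor that filter() expects.

```python
import tensorflow as tf

def filter_func(x):
    # True keeps the element, False drops it.
    return x**2 < 500

ds = tf.data.Dataset.range(30).filter(filter_func)
kept = [int(v) for v in ds.as_numpy_iterator()]
print(kept)  # the integers 0 through 22
```

22 is the last value kept, since 22² = 484 < 500 while 23² = 529.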

Muslimbek Abduganiev answered Nov 02 '25 07:11



