 

equivalent of enumerate in tensorflow to use index in tf.map_fn

I have a tensor that I process row by row using tf.map_fn. Now I want to pass each row's index as an argument to the function I give to tf.map_fn. In NumPy I could use enumerate to get that index and pass it to my lambda function. Here's an example in NumPy where I add 0 to the first row, 1 to the second row, and so on:

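import numpy as np

a = np.array([[2, 1], [4, 2], [-1, 2]])

def some_function(x, i):
    return x + i

# enumerate yields (index, row) pairs; add each row's index to that row
res = list(map(lambda pair: some_function(pair[1], pair[0]), enumerate(a)))
print(res)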

> [array([2, 1]), array([5, 3]), array([1, 4])]

I haven't been able to find an equivalent of enumerate in TensorFlow, though, and I don't know how to achieve the same result there. Does anyone know what to use to make this work in TensorFlow? Here's some sample code where I add 1 to each row of a:

import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])

with tf.Session() as sess:
    # TF1 graph mode: map_fn applies the lambda to each row of a
    res = tf.map_fn(lambda row: some_function(row, 1), a)
    print(res.eval())

> [[3 2]
   [5 3]
   [0 3]]

Thanks to anyone who can help me with this problem.

kluu asked Oct 25 '25 14:10


1 Answer

tf.map_fn() accepts multiple inputs and outputs, so you can use tf.range() to build a tensor of row indices and pass it alongside your data:

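import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
# Row indices [0, 1, 2] as a (3, 1) column, so each slice is a 1-element vector
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)

# map_fn slices a and a_rows together along axis 0, so x = (row, index);
# the index is also returned as a second output and then discarded
res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]),
                   (a, a_rows), dtype=(tf.int32, tf.int32))

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]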
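On TensorFlow 2.x, where tf.Session is gone and eager execution is the default, the same idea might look like the following sketch. It assumes a 2.x release recent enough to support the fn_output_signature argument of tf.map_fn (which supersedes dtype). It also passes a flat 1-D index tensor and returns a single output, so nothing has to be discarded.

```python
import tensorflow as tf

def some_function(x, i):
    return x + i

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
# One scalar index per row: [0, 1, 2], same dtype as a
row_ids = tf.range(tf.shape(a)[0], dtype=a.dtype)

# map_fn slices both tensors of the tuple along axis 0 in lockstep, so each
# call receives (row, index); fn_output_signature declares the output dtype
res = tf.map_fn(lambda x: some_function(x[0], x[1]), (a, row_ids),
                fn_output_signature=tf.int32)
print(res.numpy())
```

Because the function now returns one tensor instead of a tuple, the output signature is a single dtype rather than a tuple of dtypes.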

Note: in many cases, "processing a matrix row by row" can be done in a single vectorized operation, e.g. through broadcasting, instead of looping:

import tensorflow as tf

a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
# (3, 2) + (3, 1): the index column is broadcast across both columns of a
res = a + a_rows

with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]
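The broadcasting rule behind this shortcut can be checked in plain NumPy, whose broadcasting semantics TensorFlow's elementwise ops follow. This sketch (NumPy only, not TensorFlow) compares the enumerate loop from the question with the broadcast addition, and shows why the index vector needs the extra axis added by tf.expand_dims: a flat (3,) vector is aligned with the last axis of a, which has size 2, so the addition fails.

```python
import numpy as np

a = np.array([[2, 1], [4, 2], [-1, 2]])

# Row-by-row version, as in the question: add each row's index to that row
looped = np.array([row + i for i, row in enumerate(a)])

# Vectorized version: a (3, 1) column of row indices is broadcast
# across the 2 columns of a, which has shape (3, 2)
row_ids = np.arange(a.shape[0])[:, np.newaxis]
broadcast = a + row_ids
print(row_ids.shape, broadcast.shape)  # (3, 1) (3, 2)
print(np.array_equal(looped, broadcast))  # True

# Without the extra axis, the (3,) vector is matched against the
# trailing axis of size 2, so NumPy refuses to broadcast
try:
    a + np.arange(a.shape[0])
except ValueError as err:
    print("broadcast error:", err)
```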
benjaminplanche answered Oct 27 '25 04:10