I have a tensor that I process row by row using tf.map_fn. Now I want to include the index as an argument in the function that I am passing to tf.map_fn. In numpy I could use enumerate to get that information and pass it in my lambda function. Here's an example in numpy where I add 0 to the first row, 1 to the second row and so on:
import numpy as np
a = np.array([[2, 1], [4, 2], [-1, 2]])
def some_function(x, i):
    return x + i
# enumerate yields (index, row) pairs; the list comprehension works in Python 2 and 3
res = [some_function(row, i) for i, row in enumerate(a)]
print(res)
> [array([2, 1]), array([5, 3]), array([1, 4])]
I haven't been able to find an equivalent of enumerate in TensorFlow, though, and I don't know how to achieve the same result there. Does anyone know how to make this work in TensorFlow? Here's some sample code where I add 1 to each row of a:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
with tf.Session() as sess:
    res = tf.map_fn(lambda row: some_function(row, 1), a)
    print(res.eval())
> [[3 2]
   [5 3]
   [0 3]]
Thanks to anyone who can help me with this problem.
tf.map_fn() can take several inputs and return several outputs, passed as tuples of tensors. You can therefore use tf.range() to build a tensor of row indices and map over it together with a:
import tensorflow as tf
def some_function(x, i):
    return x + i
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res, _ = tf.map_fn(lambda x: (some_function(x[0], x[1]), x[1]),
                   (a, a_rows), dtype=(tf.int32, tf.int32))
with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]
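To make it clearer what mapping over the tuple (a, a_rows) does, here is a rough NumPy emulation, not TensorFlow itself. Both arrays are sliced along their first axis in lockstep and each pair of slices is handed to the function. The zip and np.stack calls are only my stand-in for tf.map_fn, under the assumption that it behaves like a per-row loop; they say nothing about how TensorFlow implements it internally.

```python
import numpy as np

def some_function(x, i):
    return x + i

a = np.array([[2, 1], [4, 2], [-1, 2]])
# Column of row indices, shape (3, 1), mirroring a_rows in the answer
a_rows = np.arange(a.shape[0]).reshape(-1, 1)

# Slice both arrays along axis 0 together, apply the function to each
# (row, index) pair, then stack the per-row results back into one array.
res = np.stack([some_function(row, idx) for row, idx in zip(a, a_rows)])
print(res)
```

Each row has shape (2,) and each index slice has shape (1,), so the addition inside some_function broadcasts the single index across the row.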
Note: In many cases, "processing a matrix row by row" does not need a loop at all. The whole operation can be done in one step, e.g. through broadcasting:
import tensorflow as tf
a = tf.constant([[2, 1], [4, 2], [-1, 2]])
a_rows = tf.expand_dims(tf.range(tf.shape(a)[0], dtype=tf.int32), 1)
res = a + a_rows
with tf.Session() as sess:
    print(res.eval())
    # [[2 1]
    #  [5 3]
    #  [1 4]]
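The expand_dims call is what makes the broadcast line up. Here is a small sketch in NumPy, which uses the same broadcasting rules for this case. The variable names idx, col and flat_ok are mine, chosen for illustration. It shows that a flat index vector of shape (3,) cannot be added to a (3, 2) matrix, while a column of shape (3, 1) can.

```python
import numpy as np

a = np.array([[2, 1], [4, 2], [-1, 2]])  # shape (3, 2)
idx = np.arange(a.shape[0])              # shape (3,)

# Broadcasting compares trailing dimensions first: 2 (from a) vs 3 (from idx).
# They differ and neither is 1, so the addition is rejected.
try:
    a + idx
    flat_ok = True
except ValueError:
    flat_ok = False

# Adding a trailing axis gives shape (3, 1); the size-1 axis is stretched
# across the two columns, so row k gets k added to every element.
col = idx[:, np.newaxis]
res = a + col

print(flat_ok)
print(res)
```

If the index should be added per column instead of per row, the flat vector would need length 2 and no extra axis.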