I want to loop over a tensor that contains a list of ints and apply a function to each element.
Inside the function, each element should be looked up in a Python dict.
I tried the easy way with tf.map_fn, which works for a simple add function, as in the following code:
import tensorflow as tf

def trans_1(x):
    return x + 10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]
But the following code throws the exception KeyError: <tf.Tensor 'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32>
import tensorflow as tf

kv_dict = {1: 11, 2: 12, 3: 13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
My TensorFlow version is 1.13.1. Thanks in advance.
There is a simple way to achieve what you are trying to do.
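The problem is that the function passed to map_fn must take tensors as its parameters and return a tensor. Your function trans_2, however, is written for a plain Python int: it uses its argument as a dict key and returns another Python int. When map_fn calls it with a tensor, the lookup kv_dict[x] uses the tensor object itself as the key, which is not in the dict. That's why your code raises the KeyError.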
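To illustrate why the lookup fails, here is a sketch in plain Python with no TensorFlow involved. The class SymbolicElement is a hypothetical stand-in for the tensor that map_fn hands to the function; like a TF 1.x tensor, it hashes and compares by identity, so it can never match an int key.

```python
kv_dict = {1: 11, 2: 12, 3: 13}


class SymbolicElement:
    """Hypothetical stand-in for the tensor that map_fn passes in.

    It inherits Python's default identity-based __hash__ and __eq__,
    so it is hashable but never equal to an int.
    """

    def __init__(self, name):
        self.name = name


def trans_2(x):
    return kv_dict[x]


# With a real int the lookup works.
plain_result = trans_2(2)

# With the symbolic stand-in the key is the object itself, not its value.
element = SymbolicElement("map/while/TensorArrayReadV3:0")
try:
    trans_2(element)
    raised_key_error = False
except KeyError:
    raised_key_error = True

print(plain_result, raised_key_error)
```

The point of the sketch is only that the dict is consulted once, while the graph is being built, with an object that stands for "some element", not with the values 1, 2 and 3.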
However, TensorFlow provides a simple way to wrap ordinary Python functions, tf.py_func. You can use it in your case as follows:
import tensorflow as tf

kv_dict = {1: 11, 2: 12, 3: 13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
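As you can see, I have added a wrapper function that expects a tensor parameter and returns a tensor, which is why it can be used in map_fn. The cast is needed because the Python int returned by trans_2 typically comes back from py_func as a 64-bit integer (hence tf.int64), whereas the input tensor is int32 and map_fn, by default, expects the function's output to have the same dtype as its input.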
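A possible variation, sketched here under assumptions rather than taken from the answer above: since tf.py_func hands the wrapped function NumPy values, the whole tensor could be looked up in one call instead of once per element. The helper name lookup_batch is hypothetical, and the block only exercises the NumPy side.

```python
import numpy as np

kv_dict = {1: 11, 2: 12, 3: 13}


def lookup_batch(xs):
    # xs is assumed to be a 1-D NumPy integer array, which is what
    # tf.py_func would pass in for a rank-1 integer tensor.
    # int(v) turns each NumPy scalar into a plain Python int key, and the
    # explicit dtype keeps the result at int32 so no later cast is needed.
    return np.array([kv_dict[int(v)] for v in xs], dtype=np.int32)


result = lookup_batch(np.array([1, 2, 3], dtype=np.int32))
print(result)  # [11 12 13]
```

In the graph this could presumably replace the map_fn call with b = tf.py_func(lookup_batch, [a], tf.int32); that line is not run here, so treat it as an untested suggestion.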