Loop over a tensor and apply function to each element

I want to loop over a tensor that contains a list of ints and apply a function to each element. In the function, each element is mapped to its value in a Python dict. I tried the easy way with tf.map_fn, which works for an addition function such as the following:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

But the following code raises KeyError: <tf.Tensor 'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32>

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

My TensorFlow version is 1.13.1. Thanks in advance.

asked Jun 28 '19 10:06 by xuhai


1 Answer

There is a simple way to achieve what you are trying to do.

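The problem is that tf.map_fn calls your function with each element as a symbolic tensor, not a plain Python int, and expects a tensor back. trans_1 works because x + 10 is itself a TensorFlow operation. trans_2, however, uses the tensor as a key into a Python dict whose keys are plain ints; the tensor is not equal to any of them, so the lookup raises KeyError, with the tensor shown in the message. That's why your code doesn't work.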
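To see why the lookup fails, here is a pure-Python analogy that needs no TensorFlow. It assumes, as in TensorFlow 1.x, that a graph-mode tensor hashes by object identity rather than by the value it will hold at run time; SymbolicInt and try_lookup are hypothetical names made up for illustration, not TensorFlow APIs.

```python
# Pure-Python analogy. SymbolicInt is a hypothetical stand-in for a TF 1.x
# graph tensor: it only names a value that will exist when the graph runs.
class SymbolicInt:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return f"<SymbolicInt '{self.name}'>"
    # No __eq__/__hash__ override: Python's default identity-based versions
    # are used, so an instance never equals the int keys 1, 2, 3.


kv_dict = {1: 11, 2: 12, 3: 13}


def trans_2(x):
    return kv_dict[x]


def try_lookup(x):
    """Call trans_2 and report a KeyError as a string instead of raising it."""
    try:
        return trans_2(x)
    except KeyError as err:
        return f"KeyError: {err}"


print(try_lookup(2))  # 12 (a real int key works)
print(try_lookup(SymbolicInt("map/while/TensorArrayReadV3:0")))
# KeyError: <SymbolicInt 'map/while/TensorArrayReadV3:0'>
```

The placeholder object is hashable, so Python does not complain about the key type; it simply finds no matching key, which is the same kind of KeyError seen in the question.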

However, TensorFlow provides a simple way to wrap ordinary Python functions: tf.py_func, which runs the wrapped function on NumPy values when the graph executes. You can use it in your case as follows:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

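def wrapper(x):
    # x is a scalar tensor; tf.py_func runs trans_2 on its NumPy value and
    # returns an int64 tensor, which is cast back to int32 for map_fn.
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)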

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

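As you can see, I have added a wrapper function that takes a tensor and returns a tensor, which is why it can be used in map_fn. The cast is needed because the Python int returned by trans_2 is converted through NumPy to a 64-bit integer on most platforms (hence Tout=tf.int64), whereas a is an int32 tensor (the default for tf.constant([1, 2, 3])), and map_fn, when no dtype argument is given, expects fn to return the same dtype as its input.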
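In TensorFlow 2.x, tf.Session is gone and tf.py_func is only available as tf.compat.v1.py_func. A minimal sketch of the same idea, assuming TensorFlow 2.x with eager execution, uses tf.py_function instead. Its wrapped function receives eager tensors (hence the .numpy() call), and because the declared Tout converts the returned Python int directly, Tout can be tf.int32 and no cast is needed.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

kv_dict = {1: 11, 2: 12, 3: 13}


def trans_2(x):
    # Inside tf.py_function, x is an eager tensor; convert it to a Python int
    # so it can be used as a dict key.
    return kv_dict[int(x.numpy())]


def wrapper(x):
    # Tout=tf.int32 matches the dtype of `a`, so map_fn needs neither a cast
    # nor an explicit output dtype.
    return tf.py_function(trans_2, [x], Tout=tf.int32)


a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
print(b.numpy())  # [11 12 13]
```

As with tf.py_func, the lookup still runs in the Python interpreter one element at a time, and the function body is not serialized with the graph, so this is convenient rather than fast or portable.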

answered Oct 27 '22 00:10 by Addy