
How to find an index of the first matching element in TensorFlow

Tags:

tensorflow

I am looking for a TensorFlow way of implementing something similar to Python's list.index() function.

Given a matrix and a value to find, I want to know the first occurrence of the value in each row of the matrix.

For example,

m is a <batch_size, 100> matrix of integers
val = 23

result = [0] * batch_size
for i, row_elems in enumerate(m):
  result[i] = row_elems.index(val)

I cannot assume that 'val' appears only once in each row; otherwise I would have implemented it using tf.argmax(m == val). In my case, it is important to get the index of the first occurrence of 'val', not just any occurrence.
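For reference, the behavior described above can be pinned down in plain Python before moving to tensors. This is only a minimal sketch: the function name first_index_per_row and the missing=-1 sentinel are illustrative choices, not part of TensorFlow. Unlike list.index(), which raises ValueError when there is no match, the sketch returns the sentinel.

```python
def first_index_per_row(m, val, missing=-1):
    """Return, for each row of m, the index of the first element equal to val.

    `missing` is a hypothetical sentinel returned for rows that do not
    contain val (list.index() would raise ValueError there).
    """
    result = []
    for row in m:
        try:
            result.append(list(row).index(val))
        except ValueError:
            result.append(missing)
    return result


m = [[0, 0, 23, 0, 23],
     [23, 0, 23, 23, 0],
     [0, 0, 0, 0, 0]]
print(first_index_per_row(m, 23))  # [2, 0, -1]
```

Any tensor-based solution can be checked against this function on small inputs.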

Igor Tsvetkov asked Feb 12 '17 05:02




2 Answers

It seems that tf.argmax works like np.argmax (according to the test), which returns the first index when the maximum value occurs more than once. You can therefore use tf.argmax(tf.cast(tf.equal(m, val), tf.int32), axis=1) to get what you want. However, the behavior of tf.argmax is currently undefined when the maximum value occurs more than once, so this result is not guaranteed.
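One way to avoid depending on how ties are broken is to make the maximum unique: weight each match by how early it occurs, so the first match always scores highest. The sketch below models that arithmetic in plain Python so it runs without TensorFlow. The function name and the missing=-1 sentinel are hypothetical, and expressing the same steps with TensorFlow ops (an element-wise multiply followed by tf.argmax) is an assumption that is not tested here.

```python
def first_match_by_weighted_argmax(m, val, missing=-1):
    """Model of a tie-free argmax: earlier matches get strictly larger scores."""
    result = []
    for row in m:
        n = len(row)
        # 1 where the element matches, 0 elsewhere (the equal-then-cast step).
        mask = [1 if x == val else 0 for x in row]
        # Position j is weighted by n - j, so scores of matches are all distinct
        # and the earliest match has the largest score.
        scores = [hit * (n - j) for j, hit in enumerate(mask)]
        best = max(scores, default=0)
        # best == 0 means the row has no match at all; otherwise the maximum
        # is unique, so any argmax implementation would return this index.
        result.append(scores.index(best) if best > 0 else missing)
    return result


val = 3
m = [[0, 0, val, 0, val],
     [val, 0, val, val, 0],
     [0, val, 0, 0, 0],
     [0, 0, 0, 0, 0]]
print(first_match_by_weighted_argmax(m, val))  # [2, 0, 1, -1]
```

The weighting also makes rows without a match detectable, because their best score is 0, whereas a plain argmax over the 0/1 mask cannot distinguish "no match" from "match at index 0".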

If you are worried about undefined behavior, you can take a minimum over the indices returned by tf.where, as @Igor Tsvetkov suggested. The example below does this per row with tf.segment_min:

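# Tested with TensorFlow r1.0 (graph mode: placeholder + Session)
import tensorflow as tf

val = 3
m = tf.placeholder(tf.int32)
m_feed = [[0  ,   0, val,   0, val],
          [val,   0, val, val,   0],
          [0  , val,   0,   0,   0]]

# Each row of tmp_indices is a (row, column) pair where m == val, in row-major order.
tmp_indices = tf.where(tf.equal(m, val))
# Minimum column index per row: columns are the data, row numbers are the segment ids.
result = tf.segment_min(tmp_indices[:, 1], tmp_indices[:, 0])

with tf.Session() as sess:
    print(sess.run(result, feed_dict={m: m_feed})) # [2, 0, 1]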

Note that tf.segment_min will raise InvalidArgumentError when some row contains no val. In your code, row_elems.index(val) will also raise an exception (ValueError) when row_elems does not contain val.
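If rows without a match need to be tolerated rather than treated as an error, the where-then-minimum logic can be modeled as follows. This is a plain-Python sketch of the idea, not TensorFlow code: where_equal mimics the (row, column) pairs that tf.where produces for a 2-D mask, and the function names and the missing=-1 sentinel are hypothetical.

```python
def where_equal(m, val):
    """(row, col) pairs of all matches in row-major order."""
    return [(i, j)
            for i, row in enumerate(m)
            for j, x in enumerate(row)
            if x == val]


def first_col_per_row(m, val, missing=-1):
    """Per-row minimum of the matching column indices; `missing` if none."""
    first = {}
    for i, j in where_equal(m, val):
        # Keep the smallest column index seen so far for row i.
        first[i] = min(j, first.get(i, j))
    # Rows that never appeared among the matches get the sentinel.
    return [first.get(i, missing) for i in range(len(m))]


val = 3
m = [[0, 0, val, 0, val],
     [0, 0, 0, 0, 0],
     [0, val, 0, 0, 0]]
print(where_equal(m, val))        # [(0, 2), (0, 4), (2, 1)]
print(first_col_per_row(m, val))  # [2, -1, 1]
```

The middle row has no match, so its row number never appears among the pairs. The sketch handles that case by filling in the sentinel after the reduction.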

Jenny answered Oct 11 '22 02:10


Looks a little ugly but works (assuming m and val are both tensors):

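# tf.unstack/tf.stack are the names that replaced tf.unpack/tf.pack as of TensorFlow 1.0
idx = []
for t in tf.unstack(m, axis=0):  # one 1-D tensor per row; needs a statically known row count
    # tf.where lists the positions where t == val; their minimum is the first match
    idx.append(tf.reduce_min(tf.where(tf.equal(t, val))))
idx = tf.stack(idx, axis=0)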

EDIT: As Yaroslav Bulatov mentioned, you could achieve the same result with tf.map_fn:

def index1d(t):
    return tf.reduce_min(tf.where(tf.equal(t, val)))

idx = tf.map_fn(index1d, m, dtype=tf.int64)
Dmitriy Danevskiy answered Oct 11 '22 02:10