Elegant Way to Select one Element per Row in Tensorflow

Tags: tensorflow

Given...

  • a matrix A of shape [m, n]
  • a tensor I of shape [m]

I want to get a list J of elements from A where J[i] = A[i, I[i]].

That is, I holds the index of the element to select from each row in A.
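
To make the requirement concrete, here is a minimal pure-Python sketch of the intended semantics, using nested lists in place of tensors. The helper name `select_per_row` and the sample values are illustrative assumptions, not part of the original question.

```python
def select_per_row(A, I):
    """Return J with J[i] = A[i][I[i]] for a nested list A and an index list I."""
    return [row[idx] for row, idx in zip(A, I)]


A = [[1, 5, 2],
     [7, 3, 4]]
I = [1, 0]  # column to pick from each row

J = select_per_row(A, I)
print(J)  # → [5, 7]
```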

Context: I already have argmax(A, 1) and now I also want the max. I know that I can just use reduce_max. After some experimenting, I also came up with this:

J = tf.gather_nd(A,
    tf.transpose(tf.pack([tf.to_int64(tf.range(A.get_shape()[0])), I])))

The to_int64 cast is needed because range produces int32 while argmax produces int64.
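
For readers on current TensorFlow: tf.pack was renamed tf.stack in TensorFlow 1.0, and tf.to_int64 is no longer part of the TF 2 API, so the snippet above will not run as written there. The following is a rough TF 2 sketch of the same gather_nd idea, assuming eager execution. The sample matrix is made up for illustration.

```python
import tensorflow as tf

A = tf.constant([[1.0, 5.0, 2.0],
                 [7.0, 3.0, 4.0]])
I = tf.argmax(A, axis=1)  # int64, shape [m]

# Row numbers 0..m-1, cast to int64 so they can be stacked with I.
rows = tf.cast(tf.range(tf.shape(A)[0]), tf.int64)

# Pair each row number with its column index: [[0, I[0]], [1, I[1]], ...]
indices = tf.stack([rows, I], axis=1)  # shape [m, 2]

J = tf.gather_nd(A, indices)  # shape [m], J[i] = A[i, I[i]]
print(J.numpy())
```

Stacking along axis=1 builds the [m, 2] index matrix directly, which replaces the transpose in the original.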

Neither of the two strikes me as particularly elegant. reduce_max has runtime overhead (probably about a factor of n), and the gather_nd version has an unknown amount of cognitive overhead. Am I missing something here?

black_puppydog asked May 04 '16 11:05


2 Answers

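The tf.gather() function with batch_dims=1 provides a way to do it:

import tensorflow as tf

r = tf.random.uniform([4, 5], 0, 9, dtype=tf.int32)  # matrix of shape [m, n]
i = tf.random.uniform([4], 0, 5, dtype=tf.int32)     # one column index per row (maxval is exclusive)
tf.gather(r, i, axis=1, batch_dims=1)                # shape [4], element k is r[k, i[k]]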
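
Because the snippet above uses random inputs, the result is hard to check by eye. Here is a small deterministic sketch of the same call, with made-up values, that may make the effect of batch_dims=1 easier to see: the first axis is treated as a batch axis, so row k of the matrix is indexed with the k-th index.

```python
import tensorflow as tf

A = tf.constant([[1, 5, 2],
                 [7, 3, 4],
                 [0, 6, 9]])
I = tf.constant([1, 0, 2])  # column to pick from each row

# batch_dims=1 pairs row k of A with I[k]; axis=1 gathers along the columns.
J = tf.gather(A, I, axis=1, batch_dims=1)
print(J.numpy())  # → [5 7 9]
```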

Cyril Marlin answered Nov 13 '22 03:11


This is a rather late answer, but would the following accomplish what you're looking for?

mask = tf.one_hot(I, depth=n, dtype=tf.bool, on_value=True, off_value=False)
elements = tf.boolean_mask(A, mask)

Here n is the number of columns of A.

Edit: I should point out that this is NOT a good idea if A is already a very large tensor, as the mask is a dense matrix of the same shape as A.
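
A self-contained sketch of the masking approach, with made-up values and n read from the static shape of A (an assumption; the answer leaves n undefined). It relies on boolean_mask returning the selected elements in row-major order, so with exactly one True per row the result lines up with the rows of A.

```python
import tensorflow as tf

A = tf.constant([[1, 5, 2],
                 [7, 3, 4]])
I = tf.constant([2, 0])  # column to pick from each row

n = A.shape[1]  # static number of columns

# Dense boolean mask of shape [m, n] with a single True per row.
mask = tf.one_hot(I, depth=n, on_value=True, off_value=False, dtype=tf.bool)

# Selected elements come back in row-major order, one per row.
J = tf.boolean_mask(A, mask)
print(J.numpy())  # → [2 7]
```

One caveat to keep in mind: one_hot yields an all-False row for an out-of-range index, so in that case the result would have fewer than m elements.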

Peter Kim answered Nov 13 '22 04:11