 

Slice multiple slices at once with tensorflow

I am preparing the input tensor for a TensorFlow RNN.
Currently I am doing the following:

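# take one (max_steps, 10) window per batch element, starting at row `each`
rnn_format = list()
for each in range(batch_size):
    rnn_format.append(tf.slice(input2Dpadded, [each, 0], [max_steps, 10]))
# stack the windows into a (batch_size, max_steps, 10) tensor
lstm_input = tf.stack(rnn_format)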

Would it be possible to do this all at once, without a loop, using some TensorFlow function?

TinyEpic asked Feb 02 '17 14:02


1 Answer

As suggested by Peter Hawkins, you can use tf.gather_nd with an appropriate index array to extract all the slices in a single call.

The cropping of the inner dimension (keeping the first 10 columns) is the same for every slice, so it can simply be done once, before the call to gather_nd.

Example:

import tensorflow as tf
import numpy as np

# tf.InteractiveSession and Tensor.eval() are TF 1.x graph-mode APIs (tf.compat.v1 in TF 2)
sess = tf.InteractiveSession()

# integer image simply because it is more readable to me
im0 = np.random.randint(10, size=(20,20))
im = tf.constant(im0)

max_steps = 3
batch_size = 10

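# indices[b, s, 0] = b + s, i.e. row s of slice b; shape (batch_size, max_steps, 1)
indices = (np.arange(max_steps) +
    np.arange(batch_size)[:,np.newaxis])[...,np.newaxis]
# crop to the first 10 columns, then gather whole rows:
# res has shape (batch_size, max_steps, 10)
res = tf.gather_nd(im[:,:10], indices).eval()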

# check that the resulting tensors are equal to what you had previously
for each in range(batch_size):
  assert(np.all(tf.slice(im, [each,0],[max_steps,10]).eval() == res[each]))
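To check the index construction without a TensorFlow session, here is a minimal NumPy sketch. It relies on the fact that when the last dimension of the index array has size 1, gather_nd selects whole rows of params, which NumPy fancy indexing on the first axis (cropped[indices[..., 0]]) reproduces. The array padded is a hypothetical stand-in for input2Dpadded, and its 20 by 12 shape is an arbitrary choice.

```python
import numpy as np

# Hypothetical stand-in for input2Dpadded: 20 rows, 12 columns.
padded = np.arange(20 * 12).reshape(20, 12)
max_steps, batch_size, width = 3, 10, 10

# Loop version from the question: one (max_steps, width) window per start row.
looped = np.stack([padded[b:b + max_steps, :width] for b in range(batch_size)])

# Index array of shape (batch_size, max_steps, 1); entry [b, s, 0] is row b + s.
indices = (np.arange(max_steps) + np.arange(batch_size)[:, np.newaxis])[..., np.newaxis]

# Crop the inner dimension once, then gather whole rows.
# With a trailing index dimension of size 1, gather_nd selects full rows,
# which NumPy fancy indexing on the first axis emulates here.
cropped = padded[:, :width]
gathered = cropped[indices[..., 0]]

print(gathered.shape)  # (10, 3, 10)
assert np.array_equal(gathered, looped)
```

The output shape is indices.shape[:-1] + cropped.shape[1:], which is exactly the (batch_size, max_steps, 10) layout the RNN input needs.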

EDIT

If your slice indices are in a tensor, simply replace the NumPy operations with TensorFlow operations when creating indices:

# start row of each slice, stored in a 1D tensor
my_indices = tf.constant([1, 8, 3, 0, 0])
indices = (tf.range(max_steps) +
    my_indices[:,tf.newaxis])[...,tf.newaxis]
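The same construction also works for arbitrary, non-consecutive start rows, as long as every start + max_steps stays within the number of rows. The NumPy sketch below reuses the start rows [1, 8, 3, 0, 0] from the snippet above; the input array is a hypothetical stand-in, and NumPy fancy indexing again stands in for tf.gather_nd.

```python
import numpy as np

padded = np.arange(20 * 12).reshape(20, 12)  # hypothetical input
max_steps, width = 3, 10

# Arbitrary start row for each slice (same values as my_indices above).
starts = np.array([1, 8, 3, 0, 0])

# (max_steps,) + (len(starts), 1) broadcasts to (len(starts), max_steps);
# the trailing axis turns it into a gather_nd-style index array.
indices = (np.arange(max_steps) + starts[:, np.newaxis])[..., np.newaxis]

# Check that every selected row exists before gathering.
assert indices.max() < padded.shape[0]

gathered = padded[:, :width][indices[..., 0]]
expected = np.stack([padded[s:s + max_steps, :width] for s in starts])

print(gathered.shape)  # (5, 3, 10)
assert np.array_equal(gathered, expected)
```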

Further remarks:

  • indices is created by taking advantage of broadcasting during the addition: the arrays are virtually tiled so that their shapes match. NumPy and TensorFlow support broadcasting in a similar fashion.
  • The ellipsis ... is part of standard NumPy slicing notation; it stands for all the dimensions not covered by the other slicing indices. So [..., newaxis] is equivalent to expand_dims(·, -1).
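A small NumPy check of both remarks: adding a (3,) array and a (4, 1) array gives the same result as adding explicitly tiled (4, 3) arrays, and [..., np.newaxis] matches np.expand_dims(x, -1). The array sizes are arbitrary choices for illustration; in TensorFlow, tf.newaxis and tf.expand_dims play the same roles.

```python
import numpy as np

steps = np.arange(3)                  # shape (3,)
starts = np.arange(4)[:, np.newaxis]  # shape (4, 1)

# Broadcasting virtually tiles both operands to shape (4, 3).
grid = steps + starts
print(grid.shape)  # (4, 3)
assert np.array_equal(grid, np.tile(steps, (4, 1)) + np.tile(starts, (1, 3)))

# [..., np.newaxis] appends a trailing axis, like expand_dims(x, -1).
expanded = grid[..., np.newaxis]
print(expanded.shape)  # (4, 3, 1)
assert np.array_equal(expanded, np.expand_dims(grid, -1))
```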
P-Gn answered Sep 29 '22 09:09