Theano advanced indexing for tensor, shared index

Tags: python, theano

I have a tensor probs with probs.shape = (max_time, num_batches, num_labels).

And I have a tensor targets with targets.shape = (max_seq_len, num_batches) whose values are label indices, i.e. indices into the third dimension of probs.

Now I want to get a tensor probs_y with probs_y.shape = (max_time, num_batches, max_seq_len), where the third dimension runs over the sequence positions in targets. Basically

probs_y[:,i,:] = probs[:,i,targets[:,i]]

for all 0 <= i < num_batches.
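To pin down the intended semantics, here is a plain-NumPy reference version of the loop above. It is only a sketch for checking results on small inputs: NumPy stands in for Theano, the toy sizes are arbitrary, and the helper name `probs_y_reference` is hypothetical. It is not meant to be fast.

```python
import numpy as np

def probs_y_reference(probs, targets):
    """Hypothetical helper: loop form of probs_y[:, i, :] = probs[:, i, targets[:, i]]."""
    max_time, num_batches, _ = probs.shape
    max_seq_len = targets.shape[0]
    out = np.empty((max_time, num_batches, max_seq_len), dtype=probs.dtype)
    for i in range(num_batches):
        # For batch i, gather the label columns listed in targets[:, i].
        out[:, i, :] = probs[:, i, targets[:, i]]
    return out

probs = np.arange(3 * 4 * 6).reshape(3, 4, 6)   # (max_time, num_batches, num_labels)
targets = np.arange(4 * 5).reshape(5, 4) % 3     # (max_seq_len, num_batches)
probs_y = probs_y_reference(probs, targets)
print(probs_y.shape)  # (3, 4, 5)
```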

How can I achieve this?

A similar problem with solution was posted here.

The solution there, if I understand correctly, would be:

probs_y = probs[:,T.arange(targets.shape[1])[None,:],targets]

But that doesn't seem to work. I get: IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices.
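For comparison, the same indexing pattern can be tried on plain NumPy arrays (NumPy substituted for Theano, toy sizes assumed). The two index arrays, of shapes (1, num_batches) and (max_seq_len, num_batches), broadcast to (max_seq_len, num_batches). Because both advanced indices are adjacent, that broadcast shape replaces axes 1 and 2 in place, so the result has shape (max_time, max_seq_len, num_batches): the last two axes are swapped relative to the desired probs_y.

```python
import numpy as np

max_time, num_batches, num_labels, max_seq_len = 3, 4, 6, 5
probs = np.arange(max_time * num_batches * num_labels).reshape(
    max_time, num_batches, num_labels)
targets = np.arange(max_seq_len * num_batches).reshape(
    max_seq_len, num_batches) % 3

batch_idx = np.arange(targets.shape[1])[None, :]      # shape (1, num_batches)
# The two advanced indices broadcast against each other.
index_shape = np.broadcast(batch_idx, targets).shape  # (max_seq_len, num_batches)

# Adjacent advanced indices: their broadcast shape replaces axes 1 and 2 in place.
proposed = probs[:, batch_idx, targets]
print(index_shape)     # (5, 4)
print(proposed.shape)  # (3, 5, 4)
```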

Also, isn't creating the temporary T.arange a bit costly? Especially when I try to work around this by actually making it a full dense integer array. There should be a better way.
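On the cost question, broadcasting means the batch index does not have to be expanded into a full dense array. This NumPy sketch (NumPy standing in for Theano, arbitrary sizes) compares a broadcast (1, num_batches) index with an explicitly materialized (max_seq_len, num_batches) copy. Both give the same result, and np.broadcast_to returns a read-only view with a zero stride along the repeated axis instead of copying data. Whether Theano's compiled graph avoids materializing the arange in the same way is not shown here.

```python
import numpy as np

max_time, num_batches, num_labels, max_seq_len = 3, 4, 6, 5
probs = np.arange(max_time * num_batches * num_labels).reshape(
    max_time, num_batches, num_labels)
targets = np.arange(max_seq_len * num_batches).reshape(
    max_seq_len, num_batches) % 3

small = np.arange(num_batches)[None, :]        # (1, num_batches): only num_batches ints
dense = np.tile(small, (max_seq_len, 1))       # (max_seq_len, num_batches): full copy
view = np.broadcast_to(small, targets.shape)   # same shape as dense, but no copy

same = (np.array_equal(probs[:, small, targets], probs[:, dense, targets])
        and np.array_equal(probs[:, view, targets], probs[:, dense, targets]))
print(same)             # True
print(view.strides[0])  # 0: every row of the view reads the same memory
```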

Maybe theano.map? But as far as I understand, that doesn't parallelize the code, so this is also not a solution.

asked Oct 20 '22 07:10 by Albert


1 Answer

This works for me:

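import numpy as np
import theano
import theano.tensor as T

max_time, num_batches, num_labels = 3, 4, 6
max_seq_len = 5

# probs: (max_time, num_batches, num_labels)
probs_ = np.arange(max_time * num_batches * num_labels).reshape(
    max_time, num_batches, num_labels)

# targets: (max_seq_len, num_batches), label indices 0..num_batches-2
targets_ = np.arange(num_batches * max_seq_len).reshape(max_seq_len,
    num_batches) % (num_batches - 1)  # mix stuff up

probs, targets = map(theano.shared, (probs_, targets_))

print(probs_)
print(targets_)

# Batch index (num_batches, 1) broadcasts against targets.T (num_batches, max_seq_len),
# so probs_y has shape (max_time, num_batches, max_seq_len).
probs_y = probs[:, T.arange(targets.shape[1])[:, np.newaxis], targets.T]

print(probs_y.eval())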
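Since Theano's advanced indexing is modelled on NumPy's, the same expression can be sanity-checked on plain NumPy arrays with the answer's toy sizes. This sketch (NumPy instead of Theano; the element-wise loop is added only for verification) confirms that probs_y[t, i, j] == probs[t, i, targets[j, i]], which is the layout asked for in the question.

```python
import numpy as np

max_time, num_batches, num_labels = 3, 4, 6
max_seq_len = 5
probs = np.arange(max_time * num_batches * num_labels).reshape(
    max_time, num_batches, num_labels)
targets = np.arange(num_batches * max_seq_len).reshape(
    max_seq_len, num_batches) % (num_batches - 1)

# Same expression as in the answer, evaluated on NumPy arrays:
# (num_batches, 1) and (num_batches, max_seq_len) broadcast to (num_batches, max_seq_len).
probs_y = probs[:, np.arange(targets.shape[1])[:, np.newaxis], targets.T]

# Check every element against the definition from the question.
ok = all(probs_y[t, i, j] == probs[t, i, targets[j, i]]
         for t, i, j in np.ndindex(*probs_y.shape))
print(probs_y.shape, ok)  # (3, 4, 5) True
```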

The code above uses a transposed version of your indices. Your exact proposal also works; it just returns the last two axes swapped:

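# Batch index (1, num_batches) broadcasts against targets (max_seq_len, num_batches),
# so probs_y2 has shape (max_time, max_seq_len, num_batches).
probs_y2 = probs[:, T.arange(targets.shape[1])[np.newaxis, :], targets]

print(probs_y2.eval())
# Swapping the last two axes recovers probs_y; this prints all zeros.
print((probs_y2.dimshuffle(0, 2, 1) - probs_y).eval())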
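Here dimshuffle(0, 2, 1) acts as a transpose of the last two axes, like NumPy's transpose(0, 2, 1). A short NumPy sketch (standing in for Theano, with toy data) shows that the two index layouts differ only by that swap.

```python
import numpy as np

probs = np.arange(3 * 4 * 6).reshape(3, 4, 6)   # (max_time, num_batches, num_labels)
targets = np.arange(4 * 5).reshape(5, 4) % 3     # (max_seq_len, num_batches)
batch = np.arange(targets.shape[1])

probs_y = probs[:, batch[:, np.newaxis], targets.T]  # (max_time, num_batches, max_seq_len)
probs_y2 = probs[:, batch[np.newaxis, :], targets]   # (max_time, max_seq_len, num_batches)

# Swapping the last two axes of probs_y2 recovers probs_y exactly.
print(probs_y.shape, probs_y2.shape)                          # (3, 4, 5) (3, 5, 4)
print(np.array_equal(probs_y2.transpose(0, 2, 1), probs_y))  # True
```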

So maybe your problem is somewhere else.

As for speed, I am at a loss as to what could be faster than this. map, which is a specialization of scan, almost certainly is not. I do not know to what extent the arange is actually built rather than simply iterated over.

answered Oct 23 '22 09:10 by eickenberg