 

Theano switch row-wise efficiently

I have the following code

output = T.switch(cond, a, b)

where cond is an (N, 1) boolean tensor, while a and b are (N, M) numeric tensors, with M being quite large. The condition applies row-wise: each entry of cond selects a whole row from either a or b.

I can easily make the switch work by running T.repeat() on cond, but this is quite slow. Is there a way I can efficiently make the bools in cond decide whether a or b should be returned?
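The row-wise semantics can be pinned down with a small NumPy stand-in. This is an assumption: NumPy is used only because it follows the same broadcasting rules as Theano's elementwise ops, and the sizes and values are made up. It contrasts the repeat-based workaround with letting the (N, 1) mask broadcast:

```python
import numpy as np

# Hypothetical small sizes; in the question M is large.
N, M = 4, 6
rng = np.random.default_rng(0)
a = rng.uniform(size=(N, M)).astype(np.float32)
b = rng.uniform(size=(N, M)).astype(np.float32)
cond = np.array([[True], [False], [True], [False]])  # shape (N, 1), one flag per row

# Workaround from the question: materialise a full (N, M) mask first.
cond_full = np.repeat(cond, M, axis=1)
out_repeat = np.where(cond_full, a, b)

# Broadcasting: the (N, 1) mask is stretched along axis 1 without copying it.
out_broadcast = np.where(cond, a, b)

assert np.array_equal(out_repeat, out_broadcast)
```

Both give the same result; the broadcast version just never allocates the (N, M) mask.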

pir asked May 17 '17 02:05


1 Answer

Is there a way I can efficiently make the bools in cond decide whether a or b should be returned?

Yes, you could do

cond * a + (1-cond) * b

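cond will be broadcast to (N, M) shape, as long as Theano knows that its second dimension is broadcastable (for example, cond was created with T.col() or passed through T.addbroadcast(cond, 1)).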
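A minimal NumPy sketch of the same arithmetic select, assuming NumPy as a stand-in for the Theano graph (the function name `rowwise_select` is hypothetical):

```python
import numpy as np

def rowwise_select(cond, a, b):
    """Return row a[i] where cond[i] is true, else row b[i].

    cond: (N, 1) boolean array; a, b: (N, M) arrays of the same dtype.
    """
    c = cond.astype(a.dtype)      # 1.0 / 0.0 in the data dtype
    # c broadcasts from (N, 1) to (N, M); no full-size mask is built.
    return c * a + (1 - c) * b

a = np.arange(6, dtype=np.float32).reshape(3, 2)
b = -a
cond = np.array([[True], [False], [True]])
print(rowwise_select(cond, a, b))
```

One caveat of the arithmetic form compared with a true switch: if the unselected input holds inf or NaN, the product 0 * inf is NaN, so the bad value leaks into the output.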

This should be close to the theoretical limit, which is set by memory bandwidth: an ideal row-wise select needs to read about N*M elements and write N*M.

This version instead reads 2*N*M elements (all of a and all of b), but it removes the conditional logic.
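The bandwidth argument can be put into numbers. This back-of-envelope sketch assumes 4-byte float32 elements and ignores the N-element cond vector and any caching; `traffic_bytes` is a hypothetical helper:

```python
def traffic_bytes(N, M, itemsize=4):
    """Rough memory traffic in bytes for one row-wise select."""
    ideal = (N * M + N * M) * itemsize            # read the chosen element, write it
    arithmetic = (2 * N * M + N * M) * itemsize   # read all of a and b, write output
    return ideal, arithmetic

ideal, arith = traffic_bytes(1000, 10000)
print(ideal / 1e6, arith / 1e6)  # megabytes
```

For N = 1000 and M = 10000 this gives 80 MB versus 120 MB, so the arithmetic version moves about 1.5x the ideal amount of data.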

(I don't have Theano in front of me, so I am not sure whether it's faster than T.switch, but it should be about as good as it gets. Also, I'd try casting cond to the same dtype as a and b.)
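The casting advice matters because a boolean mask mixed with integer literals can be promoted to a wider type, and the exact promotion depends on the library and its version. This NumPy sketch (NumPy is an assumption standing in for Theano, where the counterpart would be T.cast(cond, a.dtype)) compares the two paths:

```python
import numpy as np

a = np.ones((3, 4), dtype=np.float32)
b = np.zeros((3, 4), dtype=np.float32)
cond = np.array([[True], [False], [True]])  # boolean (N, 1) mask

# Uncast: `1 - cond` turns the boolean mask into an integer array, and
# integer * float32 may be promoted to a wider float type.
uncast = cond * a + (1 - cond) * b

# Cast once, up front, so every intermediate stays float32.
c = cond.astype(a.dtype)
cast = c * a + (1 - c) * b

print(uncast.dtype, cast.dtype)
```

Casting keeps the result in float32, which halves the memory traffic compared with an accidental float64 intermediate.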


If you want to update a in-place, you can do it using T.set_subtensor:

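import numpy as np
import theano
import theano.tensor as T

N, M = 1000, 5000  # example sizes (illustrative)

a = np.random.uniform(size=(N, M)).astype(np.float32)
b = np.random.uniform(size=(N, M)).astype(np.float32)

# Shared variables persist between calls, so `updates` can write the
# result back into a (Theano may do this in place).
a = theano.shared(a)
b = theano.shared(b)

# One entry per row, nonzero where the row should be taken from b.
# Mostly 0; presumably (1 - cond) as a length-N array of theano.config.floatX.
c = T.vector()

nz = T.nonzero(c)  # tuple holding one vector of row indices

# Copy the selected rows of b into a and store the result back into a.
s = T.set_subtensor(a[nz], b[nz])
fn = theano.function([c], [], updates=[(a, s)])

...

fn(1-cond)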
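The same nonzero-then-assign pattern, written as a NumPy stand-in (an assumption: it mirrors the Theano graph above but does not run Theano; sizes and cond values are made up), shows that it produces the row-wise select:

```python
import numpy as np

N, M = 5, 3  # hypothetical sizes
rng = np.random.default_rng(1)
a = rng.uniform(size=(N, M)).astype(np.float32)
b = rng.uniform(size=(N, M)).astype(np.float32)
cond = np.array([1, 1, 0, 1, 1], dtype=np.float32)  # 1 = keep the row of a

expected = np.where(cond[:, None] != 0, a, b)  # new array, computed before a changes

c = 1 - cond        # nonzero where the row should come from b
nz = np.nonzero(c)  # tuple holding one index array of row numbers
a[nz] = b[nz]       # overwrite only those rows, in place

assert np.array_equal(a, expected)
```

Only the rows listed in nz are touched, which is why this can win when c is mostly 0.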

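It may or may not be faster than the first approach, depending on N, M, how many entries of c are nonzero, and other factors. Since set_subtensor only touches the selected rows, it is most likely to pay off when few rows change.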
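Measuring is the reliable way to choose. The sketch below times NumPy stand-ins for the two approaches, not the compiled Theano functions, so treat it as a template and swap in your own theano.function calls to benchmark the real thing. The helper `compare` and the `frac_replaced` parameter are hypothetical:

```python
import timeit
import numpy as np

def compare(N, M, frac_replaced, repeat=5):
    """Time a broadcast arithmetic select against an in-place row update.

    frac_replaced is the approximate fraction of rows taken from b.
    Returns (arith_seconds, inplace_seconds), best of `repeat` runs each.
    """
    rng = np.random.default_rng(0)
    a = rng.uniform(size=(N, M)).astype(np.float32)
    b = rng.uniform(size=(N, M)).astype(np.float32)
    cond = (rng.uniform(size=N) >= frac_replaced).astype(np.float32)
    a_work = a.copy()  # the in-place variant mutates this copy, not `a`

    def arithmetic():
        c = cond[:, None]  # (N, 1), broadcasts over columns
        return c * a + (1 - c) * b

    def inplace():
        nz = np.nonzero(1 - cond)
        a_work[nz] = b[nz]  # idempotent, so repeated runs are fine
        return a_work

    # Both strategies must agree before their timings mean anything.
    assert np.array_equal(arithmetic(), inplace())
    t_arith = min(timeit.repeat(arithmetic, number=1, repeat=repeat))
    t_inplace = min(timeit.repeat(inplace, number=1, repeat=repeat))
    return t_arith, t_inplace

print(compare(1000, 2000, frac_replaced=0.05))
```

Varying frac_replaced shows where the crossover lies on a given machine.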

MWB answered Sep 27 '22 19:09