Theano: Why does indexing fail in this case?

I'm trying to get the max of the elements of a vector that are selected by a boolean mask.

With Numpy:

>>> this = np.arange(10)
>>> this[~(this>=5)].max()
4

But with Theano:

>>> that = T.arange(10, dtype='int32')
>>> that[~(that>=5)].max().eval()
9
>>> that[~(that>=5).nonzero()].max().eval()
Traceback (most recent call last):
  File "<pyshell#146>", line 1, in <module>
    that[~(that>=5).nonzero()].max().eval()
AttributeError: 'TensorVariable' object has no attribute 'nonzero'

Why does this happen? Is there a subtle nuance that I'm missing?

Noob Saibot asked May 31 '13 01:05


1 Answer

You are using a version of Theano that is too old. In fact, tensor_var.nonzero() isn't in any released version. You need to update to the development version.

With the development version I get this:

>>> that[~(that>=5).nonzero()].max().eval()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: bad operand type for unary ~: 'tuple'

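This is because a pair of parentheses is missing in your line: .nonzero() binds more tightly than ~, so ~ is applied to the tuple that nonzero() returns. Here is the corrected line: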

>>> that[(~(that>=5)).nonzero()].max().eval()
array(9, dtype=int32)
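
The precedence issue can be checked without Theano at all. The following is a minimal sketch using only the standard `ast` module; it parses the original expression without evaluating it, so `that` is just a name here and nothing Theano-specific is assumed.

```python
import ast

# Parse (but do not evaluate) the expression from the question.
tree = ast.parse("~(that >= 5).nonzero()", mode="eval")
top = tree.body

# The outermost node is the unary invert and its operand is the method call,
# so Python reads the expression as ~((that >= 5).nonzero()).
outer_is_invert = isinstance(top, ast.UnaryOp) and isinstance(top.op, ast.Invert)
operand_is_call = isinstance(top.operand, ast.Call)
print(outer_is_invert, operand_is_call)

# Applying ~ to a tuple fails in plain Python in the same way.
indices = (1, 2)
try:
    ~indices
    error = None
except TypeError as exc:
    error = str(exc)
print(error)
```

Wrapping the inverted mask in its own parentheses, as in the corrected line above, makes `~` apply to the mask before `.nonzero()` is called.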

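But the result is still unexpected! The problem is that Theano does not support bool. A comparison returns int8, and ~ on int8 is a bitwise invert of all 8 bits, not of 1 bit, so 0 becomes -1 and 1 becomes -2. Both values are nonzero, so every element is selected. It gives this result: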

>>> (that>=5).eval()
array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], dtype=int8)
>>> (~(that>=5)).eval()
array([-1, -1, -1, -1, -1, -2, -2, -2, -2, -2], dtype=int8)
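
The same arithmetic can be reproduced in plain Python, whose integers follow the same rule for `~` (`~x == -x - 1`). This is only a sketch that emulates the int8 mask with a list of ints; the names `values`, `mask`, `inverted` and `selected` are illustrative and not part of Theano.

```python
# Plain-Python stand-in for the int8 mask shown above for (that >= 5).
values = list(range(10))
mask = [int(v >= 5) for v in values]   # zeros for 0..4, ones for 5..9

# Bitwise invert on a signed integer: 0 becomes -1 and 1 becomes -2.
inverted = [~m for m in mask]

# nonzero() keeps every index whose entry is not 0; here that is every index.
selected = [i for i, m in enumerate(inverted) if m != 0]
result = max(values[i] for i in selected)
print(inverted)
print(selected, result)
```

Because no entry of the inverted mask is zero, the selection covers the whole vector and the maximum is 9, matching the Theano output.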

You can avoid the ~ entirely by inverting the comparison instead:

>>> that[(that<5).nonzero()].max().eval()
array(4, dtype=int32)
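
If the mask already exists and the comparison cannot simply be rewritten, a logical negation of a 0/1 mask can be built by comparing it to zero or by subtracting it from 1. The sketch below shows both in plain Python on the same stand-in data. The Theano counterparts would presumably be `T.eq(mask, 0)` and `1 - mask`; that mapping is an assumption and was not tested here.

```python
# Plain-Python stand-in for a 0/1 integer mask such as (that >= 5).
values = list(range(10))
mask = [int(v >= 5) for v in values]

# Two ways to negate a 0/1 mask without a bitwise invert.
logical_not = [int(m == 0) for m in mask]   # compare each entry to zero
arithmetic_not = [1 - m for m in mask]      # works only for 0/1 entries

# Both give the same mask; select with it and take the maximum.
selected = [i for i, m in enumerate(logical_not) if m != 0]
result = max(values[i] for i in selected)
print(logical_not == arithmetic_not, result)
```

Both negations give a mask that is 1 exactly where the original is 0, so the selection covers 0 through 4 and the maximum is 4.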
nouiz answered Nov 15 '22 20:11