 

Why do I get a Theano TypeError when trying to update a shared variable?

Tags:

python

theano

I'm trying to run a very simple gradient descent on the function y=x^2. I've tried implementing it using the following code:

import theano
from theano import tensor as T
x = theano.shared(2)
y = x ** 2

dy_dx = T.grad(y, x)
learning_rate = 1
updates = [(x, x - learning_rate * dy_dx)]
fn = theano.function([], [y], updates = updates)

But when I try to compile the function "fn", I get the following error:

TypeError: ('An update must have the same type as the original shared 
variable (shared_var=<TensorType(int64, scalar)>, 
shared_var.type=TensorType(int64, scalar), 
update_val=Elemwise{sub,no_inplace}.0, 
update_val.type=TensorType(float64, scalar)).', 'If the difference is 
related to the broadcast pattern, you can call the 
tensor.unbroadcast(var, axis_to_unbroadcast[, ...]) function to remove 
broadcastable dimensions.')

I thought this might be a problem with the learning_rate variable, since it may not have the same type as the shared variable x, but if I modify the code as follows:

updates = [(x, x - dy_dx)]

I still get the same error.

I'm stuck :( Any ideas?

orrymr asked Mar 15 '23 10:03


1 Answer

The problem is that your shared variable x has no type specified, so one is being inferred. Since the value you provide is a Python integer literal, the inferred type is an integer type: int64 here, as the error message shows (on some platforms the default integer is int32 instead). This is a problem because gradients are not computed in integer arithmetic, so dy_dx is actually a float64. This in turn makes the update value a float64 too. A shared variable can only be updated with a value of the same type (this is what the error message says), so you have a problem: the shared variable is an int64 but the update is a float64.
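The same inference and upcasting can be reproduced with NumPy, whose dtype rules Theano largely follows in this case. This is a minimal sketch that uses NumPy only as a stand-in and does not call Theano; the gradient value is hard-coded as 2x at x = 2, and the exact integer width printed depends on the platform.

```python
import numpy as np

# An integer literal is inferred as an integer dtype, like the value
# passed to theano.shared(2) (int64 on most 64-bit Linux/macOS builds).
x_value = np.asarray(2)

# The gradient is floating point; model it as a float64 scalar.
# d(x**2)/dx = 2x, evaluated by hand at x = 2.
dy_dx = np.asarray(2 * 2.0)

# Subtracting a float64 from an integer upcasts the result to float64,
# so the update no longer matches the shared variable's integer dtype.
update = x_value - 1 * dy_dx
print(x_value.dtype, update.dtype)
print(update.dtype == x_value.dtype)  # False: the mismatch Theano rejects

# Starting from a float literal keeps both sides float64.
x_float = np.asarray(2.)
print(x_float.dtype, (x_float - dy_dx).dtype)  # float64 float64
```

Dropping learning_rate does not change anything: the float64 comes from dy_dx itself, which is why the modified update in the question fails the same way.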

One solution is to make the shared variable a float as well. This can be achieved by simply adding a decimal point to the initial value of x.

x = theano.shared(2.)
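With x starting as a float, the update compiles. As a plain-Python sketch of what repeated calls to fn() would then compute (no Theano involved; the gradient 2x of y = x**2 is written out by hand, and the helper name `descend` is hypothetical), note that with learning_rate = 1 the update is x - 2x = -x, so x flips between 2 and -2 instead of converging. A smaller rate such as 0.1 multiplies x by 0.8 on every step, moving it toward the minimum at 0.

```python
def descend(x, learning_rate, steps):
    """Apply x <- x - learning_rate * dy/dx for y = x**2, `steps` times.

    Hypothetical helper mirroring the `updates` list passed to
    theano.function; dy/dx = 2 * x is computed by hand here.
    """
    history = [x]
    for _ in range(steps):
        dy_dx = 2 * x
        x = x - learning_rate * dy_dx
        history.append(x)
    return history


print(descend(2., 1, 4))  # [2.0, -2.0, 2.0, -2.0, 2.0]
print(descend(2., 0.1, 3))  # shrinks by a factor of 0.8 each step
```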
Daniel Renshaw answered Apr 26 '23 11:04