
Updating variable values in tensorflow

Tags:

tensorflow

I have a basic question about updating the values of tensors via the TensorFlow Python API.

Consider the code snippet:

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable( randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)

Now let's assume I want to implement some sort of algorithm that iteratively updates the value of W (e.g. some optimization algo). This will involve steps like:

for i in range(max_its):
    resid = yhat - y
    W = f(W, resid)  # some update

The problem here is that W on the LHS is a new tensor, not the W that is used in yhat = tf.matmul(x, W)! That is, the Python name W is rebound to a newly created tensor, and the value of the W used in my "model" doesn't update.

Now one way around this would be

for i in range(max_its):
    resid = yhat - y
    W = f(W, resid)  # some update
    yhat = tf.matmul(x, W)

which results in the creation of a new "model" on every iteration of my loop!

Is there a better way to implement this (in Python) that avoids creating a new model on each iteration of the loop, and instead updates the original tensor W "in place", so to speak?

firdaus asked Dec 24 '22 23:12



2 Answers

Variables have an assign method. Try: W.assign(f(W, resid)). Note that this returns an op, and W only changes when that op is run in a session.
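As a minimal sketch of what that looks like end to end, assuming TF1-style graph mode (reached here through tensorflow.compat.v1, which is my assumption and not something the question specifies), with a made-up "add one" update standing in for f:

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: graph mode with sessions, as in the question

# A small variable; the 2x2 zero start value is only for illustration.
W = tf.Variable(np.zeros((2, 2)), dtype=tf.float32)

# assign() does not change W here; it returns an op that writes to W when run.
update = W.assign(W + 1.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update)  # W is now all ones
    sess.run(update)  # W is now all twos
    W_value = sess.run(W)
```

The same variable W is modified each time, so anything built on top of W (such as yhat in the question) sees the new value on its next run.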

aarbelle answered Dec 30 '22 22:12



@aarbelle's terse answer is correct; I'll expand on it a bit in case someone needs more info. The W.assign and sess.run calls at the end of the loop below are what update W.

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable(randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)

...

for i in range(max_its):
    resid = yhat - y
    # W.assign(...) only builds a tf op; W changes when the op is run.
    update = W.assign(f(W, resid))
    # Do not forget to initialize tf variables before running the op.
    sess.run(update)
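
For a version that can be run as is, here is a hedged sketch under several assumptions of mine: TF1-style graph mode via tensorflow.compat.v1, made-up data (x_data, y_data, W_true), y given the same shape as yhat so the residual is well defined, and a hypothetical f that takes one gradient step on the squared residual. It also builds the update op once, before the loop, and only runs it inside the loop.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: graph mode with sessions, as in the question

# Made-up data for illustration only.
rng = np.random.RandomState(0)
x_data = rng.randn(50, 10).astype(np.float32)
W_true = rng.randn(10, 10).astype(np.float32)
y_data = x_data @ W_true

x = tf.placeholder(tf.float32, shape=(None, 10))
y = tf.placeholder(tf.float32, shape=(None, 10))  # assumption: same shape as yhat
W = tf.Variable(rng.randn(10, 10), dtype=tf.float32)
yhat = tf.matmul(x, W)
resid = yhat - y


def f(W, resid, step=0.1):
    # Hypothetical update rule: one gradient step on the mean squared residual.
    # It reads the placeholder x from the enclosing scope.
    n = tf.cast(tf.shape(x)[0], tf.float32)
    grad = tf.matmul(x, resid, transpose_a=True) / n
    return W - step * grad


# Build the update op once; running it repeatedly reuses the same graph nodes.
update = W.assign(f(W, resid))
loss = tf.reduce_mean(tf.square(resid))

max_its = 200
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {x: x_data, y: y_data}
    loss_before = sess.run(loss, feed_dict=feed)
    n_ops = len(tf.get_default_graph().get_operations())
    for i in range(max_its):
        sess.run(update, feed_dict=feed)  # updates W in place
    loss_after = sess.run(loss, feed_dict=feed)
    graph_grew = len(tf.get_default_graph().get_operations()) != n_ops
```

Calling W.assign(...) inside the loop adds a new assign op to the graph on every pass, which is a smaller version of the "new model per iteration" problem from the question. Building the op once and feeding data through feed_dict keeps the graph at a fixed size.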
Lance answered Dec 30 '22 22:12
