What is a TensorFlow float_ref?

While trying to run the following basic example of a conditional calculation, I got this error message:

'x' was passed float incompatible with expected float_ref

What is a TensorFlow float_ref, and how does the code have to be modified?

import tensorflow as tf
from tensorflow.python.ops.control_flow_ops import cond

a = tf.Variable(tf.constant(0.),name="a")
b = tf.Variable(tf.constant(0.),name="b")
x = tf.Variable(tf.constant(0.),name="x")

def add():
    x.assign( a + b)
    return x

def last():
    return x

calculate= cond(x==0.,add,last)

with tf.Session() as s:
    val = s.run([calculate], {a: 1., b: 2., x: 0.})
    print(val) # 3
    val=s.run([calculate],{a:4.,b:5.,x:val})
    print(val) # 3
Anona112 asked Dec 14 '15 11:12


3 Answers

FYI, I got a similar error:

node GradientDescent/update_input/ApplyGradientDescent was passed float from _arg_input_0_1:0 incompatible with expected float_ref.

This happened because somewhere in my graph I had a tf.Variable where I needed a tf.placeholder. After replacing the variable with a placeholder, it worked.
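Below is a minimal sketch of the placeholder fix. It assumes the TF1-style API from tensorflow.compat.v1 (TF 1.15 or 2.x). Data that is fed at run time goes into a tf.placeholder, and only the trained parameter is a tf.Variable. The names inputs and w and the toy objective are hypothetical. The node name update_input in the error suggests a likely cause: minimize() tries to update every trainable variable, including a fed input variable, and ApplyGradientDescent needs that variable as a float_ref.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode with legacy (ref) variables

# Data is fed at run time, so it is a placeholder, not a Variable (hypothetical name).
inputs = tf.placeholder(tf.float32, shape=[None], name="input")
# Only the trained parameter is a Variable; the optimizer updates it in place.
w = tf.Variable(0.5, name="w")

# Toy objective (hypothetical): learn w such that w * inputs == 2 * inputs.
loss = tf.reduce_mean(tf.square(w * inputs - 2.0 * inputs))
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(loss)

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    for _ in range(100):
        s.run(train_op, {inputs: [1.0, 2.0, 3.0]})
    w_val = s.run(w)
    print(w_val)  # close to 2.0
```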

codesmith answered Nov 15 '22 14:11


float_ref here refers to a reference to a float, i.e. your TensorFlow float variable x.
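You can see the distinction directly with the small sketch below. It assumes tensorflow.compat.v1 with legacy ref variables, which disable_v2_behavior() turns on; newer releases otherwise use resource variables, which have no ref dtype. In Python the reference type is named float32_ref. The error message uses the runtime's name for the same type, float_ref.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # legacy ref variables, as in the question's TF version

x = tf.Variable(0., name="x")

# The variable itself is backed by a reference-typed tensor...
ref_dtype = x.dtype.name
# ...while reading it yields an ordinary float32 tensor.
value_dtype = x.value().dtype.name
# State-mutating ops such as Assign take the reference as their first input.
assign_input_dtype = x.assign(1.).op.inputs[0].dtype.name

print(ref_dtype, value_dtype, assign_input_dtype)  # float32_ref float32 float32_ref
```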

As explained here, you get this error because you cannot feed a variable through feed_dict and assign to it in the same session run, which is what this statement does:

val = s.run([calculate], {a: 1., b: 2., x: 0.})

This becomes clearer once you expand calculate in that statement, which gives:

val = s.run([x.assign( a + b)], {a: 1., b: 2., x: 0.})
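The sketch below reproduces the error and the fix under the same assumptions (tensorflow.compat.v1 with legacy ref variables). Feeding x replaces its reference with a plain float, which Assign cannot accept. Feeding only a and b works because their values are read and never assigned. The sketch catches the broad tf.errors.OpError so that it does not depend on the exact error subclass.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # legacy ref variables reproduce the original error

a = tf.Variable(0., name="a")
b = tf.Variable(0., name="b")
x = tf.Variable(0., name="x")
update = x.assign(a + b)  # Assign needs x as a float_ref input

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    # Feeding x swaps its reference for a plain float, so Assign is rejected.
    try:
        s.run(update, {a: 1., b: 2., x: 0.})
        fed_x_failed = False
    except tf.errors.OpError as err:
        fed_x_failed = True
        print(type(err).__name__)
    # Feeding only a and b is fine: they are only read, never assigned.
    result = s.run(update, {a: 1., b: 2.})
    print(result)  # 3.0
```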
reubenjohn answered Nov 15 '22 14:11


This doesn't explain what a float_ref is, but it fixes the issues:

1) Variables need to be created in the session.
2) The assignment op was not doing what we expected.

This fixed code works:

import tensorflow as tf
from tensorflow.python.ops.control_flow_ops import cond

def add():
    print("add")
    x = a + b
    return x

def last():
    print("last")
    return x

with tf.Session() as s:
    a = tf.Variable(tf.constant(0.),name="a")
    b = tf.Variable(tf.constant(0.),name="b")
    x = tf.constant(-1.)
    calculate= cond(x.eval()==-1.,add,last)
    val = s.run([calculate], {a: 1., b: 2.})
    print(val) # 3
    print(s.run([calculate],{a:3.,b:4.})) # 7
    print(val) # 3
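The code above recomputes a + b on every run, so its second call prints 7. The question describes different behavior: store a + b in x the first time, then keep returning the stored value. The sketch below does that under the same assumptions (tensorflow.compat.v1, legacy ref variables). Because a and b are fed, they become placeholders. x stays a Variable and is never fed. The predicate uses tf.equal because x == 0. does not compare values in TF1 graph mode.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1 graph mode and TF1 tf.cond

# a and b are fed on every run, so they are placeholders.
a = tf.placeholder(tf.float32, shape=[], name="a")
b = tf.placeholder(tf.float32, shape=[], name="b")
# x keeps its value between runs, so it is a Variable that is never fed.
x = tf.Variable(0., name="x")

def add():
    # Store a + b in x and return the stored value as a plain float32 tensor.
    return tf.identity(x.assign(a + b))

def last():
    # Return the current value of x without changing it.
    return tf.identity(x)

# tf.equal builds a boolean tensor for the predicate.
calculate = tf.cond(tf.equal(x, 0.), add, last)

with tf.Session() as s:
    s.run(tf.global_variables_initializer())
    first = s.run(calculate, {a: 1., b: 2.})   # x was 0, so x becomes 3
    second = s.run(calculate, {a: 4., b: 5.})  # x is 3, so it is returned as is
    print(first, second)  # 3.0 3.0
```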
Anona112 answered Nov 15 '22 13:11