I am trying to use conditionals (`tf.cond`) with TensorFlow and I am getting the error:
ValueError: Shapes (1,) and () are not compatible
Below is the code that throws the error. The traceback says the error is in the conditional (`tf.cond`):
import tensorflow as tf
import numpy as np
# Flag vectors passed to stretch() as `dim`. Each element is compared with
# WORKING to decide whether that row of `points` is scaled: X = [1, 0] flags
# the first row. Y and BOTH are not used in the visible code; presumably they
# flag the second row and both rows respectively.
X = tf.constant([1, 0])
Y = tf.constant([0, 1])
BOTH = tf.constant([1, 1])
# Flag value meaning "stretch this axis"; stretch() compares each flag to it.
WORKING = tf.constant(1)
def create_mult_func(tf, amount, list):
    """Return a zero-argument callable that scales ``list`` by ``amount``.

    The result is handed to ``tf.cond`` as a branch function, which takes
    callables rather than tensors.

    Args:
        tf: The TensorFlow module, passed in explicitly by the caller.
        amount: Scalar multiplier forwarded to ``tf.scalar_mul``.
        list: Tensor to scale. (Shadows the builtin name; kept because it is
            part of the public signature.)

    Returns:
        A function of no arguments. ``tf.scalar_mul(amount, list)`` is not
        evaluated until that function is invoked.
    """
    return lambda: tf.scalar_mul(amount, list)
def create_no_op_func(tensor):
    """Return a zero-argument callable that yields ``tensor`` unchanged.

    Used as the pass-through branch of ``tf.cond``.

    Args:
        tensor: Value to hand back when the callable is invoked.

    Returns:
        A function of no arguments returning ``tensor`` itself.
    """
    return lambda: tensor
def stretch(tf, points, dim, amount):
    """Conditionally scale the x and/or y rows of ``points`` by ``amount``.

    This is the failing version from the question, kept as written so that it
    still reproduces the ValueError in the stack trace.

    Args (as described by the author):
        tf: The TensorFlow module, passed in explicitly.
        points: A 2 by ??? tensor (row 0 = x values, row 1 = y values).
        dim: Flag tensor such as X = [1, 0] selecting which rows to scale.
        amount: Scalar multiplier.
    """
    x_list, y_list = tf.split(0, 2, points)
    # NOTE(review): `dim` receives a 1-D constant such as X = [1, 0] (shape
    # (2,)), so splitting along axis 1 looks wrong; axis 0 is presumably
    # intended.
    x_stretch, y_stretch = tf.split(1, 2, dim)
    # Per the traceback, these comparisons produce (1,)-shaped tensors, not
    # scalars.
    is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x")
    is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y")
    # BUG: tf.cond requires a scalar predicate. Passing the (1,) tensor raises
    # "ValueError: Shapes (1,) and () are not compatible"; wrapping it in
    # tf.reshape(..., []) makes it a scalar.
    x_list_stretched = tf.cond(is_stretch_X,
        create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
    y_list_stretched = tf.cond(is_stretch_Y,
        create_mult_func(tf, amount, y_list), create_no_op_func(y_list))
    # NOTE(review): each row is presumably (1, N) after the axis-0 split, so
    # concatenating along axis 1 would give one (1, 2N) row rather than a
    # 2 x N result. Confirm.
    return tf.concat(1, [x_list_stretched, y_list_stretched])
# 3 x 2 example input. NOTE(review): stretch() splits `points` into two pieces
# along axis 0, which presumably cannot divide 3 rows evenly. Confirm.
example_points = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.float32)
# No static shape is given; the concrete values arrive through feed_dict.
example_point_list = tf.placeholder(tf.float32)
# Graph construction fails on this call: stretch() raises the ValueError
# from tf.cond (see the stack trace).
result = stretch(tf, example_point_list, X, 1)
# NOTE(review): this Session is never closed and is immediately shadowed by
# the `with` block below; it appears to be unnecessary.
sess = tf.Session()
with tf.Session() as sess:
    result = sess.run(result, feed_dict={example_point_list: example_points})
    print(result)
Stack trace:
File "/path/test2.py", line 36, in <module>
result = stretch(tf, example_point_list, X, 1)
File "/path/test2.py", line 28, in stretch
create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
File "/path/tensorflow/python/ops/control_flow_ops.py", line 1142, in cond
p_2, p_1 = switch(pred, pred)
File "/path/tensorflow/python/ops/control_flow_ops.py", line 203, in switch
return gen_control_flow_ops._switch(data, pred, name=name)
File "/path/tensorflow/python/ops/gen_control_flow_ops.py", line 297, in _switch
return _op_def_lib.apply_op("Switch", data=data, pred=pred, name=name)
File "/path/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
op_def=op_def)
File "/path/tensorflow/python/framework/ops.py", line 2156, in create_op
set_shapes_for_outputs(ret)
File "/path/tensorflow/python/framework/ops.py", line 1612, in set_shapes_for_outputs
shapes = shape_func(op)
File "/path/tensorflow/python/ops/control_flow_ops.py", line 2032, in _SwitchShape
unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
File "/path/tensorflow/python/framework/tensor_shape.py", line 554, in merge_with
(self, other))
ValueError: Shapes (1,) and () are not compatible
I have tried changing `WORKING` to be an array instead of a scalar.
I believe the problem is that `tf.equal` is returning an int32 instead of the bool it is supposed to return according to the documentation.
The problem is in the first argument to `tf.cond`. From the documentation of `tf.cond`, about the type of its first argument:
pred: A scalar determining whether to return the result of fn1 or fn2.
Note that it has to be a scalar. You are passing `tf.equal(x_stretch, WORKING)`, which compares a one-element tensor with a scalar tensor and therefore gives you a (1,)-shaped tensor, NOT a scalar. You can convert it to a scalar with `tf.reshape` (reshaping to the empty shape `[]`) as follows:
# tf.equal here yields a (1,)-shaped bool tensor; tf.cond needs a scalar
# predicate, so reshape it to the empty shape [] (a scalar) first.
t = tf.equal(x_stretch, WORKING, name="is_stretch_x")
x_list_stretched = tf.cond(tf.reshape(t, []),
    create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
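Complete working program. Besides the `tf.reshape` fix, it also splits `dim` along axis 0, uses a 2 × 2 example input so that `points` splits evenly into two rows, and combines the two rows with `tf.pack` instead of `tf.concat`: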
import tensorflow as tf
import numpy as np
# Flag vectors passed to stretch() as `dim`. Each element is compared with
# WORKING to decide whether that row of `points` is scaled: X = [1, 0] flags
# the first row. Y and BOTH are not used in the visible code; presumably they
# flag the second row and both rows respectively.
X = tf.constant([1, 0])
Y = tf.constant([0, 1])
BOTH = tf.constant([1, 1])
# Flag value meaning "stretch this axis"; stretch() compares each flag to it.
WORKING = tf.constant(1)
def create_mult_func(tf, amount, list):
    """Return a zero-argument callable that scales ``list`` by ``amount``.

    Used as the "stretch" branch of ``tf.cond``, which takes callables
    rather than tensors.

    Args:
        tf: The TensorFlow module, passed in explicitly by the caller.
        amount: Scalar multiplier forwarded to ``tf.scalar_mul``.
        list: Tensor to scale. (Shadows the builtin name; kept because it is
            part of the public signature.)

    Returns:
        A function of no arguments. ``tf.scalar_mul(amount, list)`` is not
        evaluated until that function is invoked.
    """
    return lambda: tf.scalar_mul(amount, list)
def create_no_op_func(tensor):
    """Return a zero-argument callable that yields ``tensor`` unchanged.

    Used as the pass-through branch of ``tf.cond``.

    Args:
        tensor: Value to hand back when the callable is invoked.

    Returns:
        A function of no arguments returning ``tensor`` itself.
    """
    return lambda: tensor
def stretch(tf, points, dim, amount):
    """Conditionally scale the x and/or y rows of ``points`` by ``amount``.

    Args:
        tf: The TensorFlow module, passed in explicitly by the caller.
        points: A 2 by ??? tensor; row 0 holds x values, row 1 holds y values.
        dim: Length-2 flag tensor such as X = [1, 0]. A flag equal to
            WORKING selects the corresponding row for scaling.
        amount: Scalar multiplier forwarded to ``tf.scalar_mul``.

    Returns:
        ``tf.pack`` of the two (possibly scaled) rows.
        NOTE(review): tf.split presumably leaves each row as (1, N), so the
        packed result would be (2, 1, N) rather than the input's (2, N);
        tf.concat along axis 0 would keep (2, N). Confirm before relying on
        the output shape.
    """
    # Separate the point rows and the matching per-axis flags.
    x_row, y_row = tf.split(0, 2, points)
    x_flag, y_flag = tf.split(0, 2, dim)

    # Each comparison yields a (1,)-shaped bool tensor rather than a scalar.
    x_selected = tf.equal(x_flag, WORKING, name="is_stretch_x")
    y_selected = tf.equal(y_flag, WORKING, name="is_stretch_Y")

    # tf.cond needs a scalar predicate, so collapse each result to shape [].
    x_pred = tf.reshape(x_selected, [])
    x_out = tf.cond(x_pred,
                    create_mult_func(tf, amount, x_row),
                    create_no_op_func(x_row))
    y_pred = tf.reshape(y_selected, [])
    y_out = tf.cond(y_pred,
                    create_mult_func(tf, amount, y_row),
                    create_no_op_func(y_row))

    rows = [x_out, y_out]
    return tf.pack(rows)
# 2 x 2 example input: row 0 is the x row, row 1 the y row, so the axis-0
# split inside stretch() divides it evenly.
example_points = np.array([[1, 1], [2, 2]], dtype=np.float32)
# No static shape is given; the concrete values arrive through feed_dict.
example_point_list = tf.placeholder(tf.float32)
# X = [1, 0] flags only the x row for scaling (by 1 here, so the values are
# unchanged).
result = stretch(tf, example_point_list, X, 1)
# A single context-managed session, so it is closed when the block exits.
with tf.Session() as sess:
    result = sess.run(result, feed_dict={example_point_list: example_points})
    print(result)
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With