Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

TensorFlow conditional throwing ValueError

I am trying to use conditionals in TensorFlow and I am getting the error:

ValueError: Shapes (1,) and () are not compatible

Below is the code that throws the error. The traceback says the error is in the conditional (`tf.cond`).

import tensorflow as tf
import numpy as np

# Flag vectors intended as stretch()'s ``dim`` argument: X selects the x row,
# Y the y row, BOTH selects both rows.
X = tf.constant([1, 0])
Y = tf.constant([0, 1])
BOTH = tf.constant([1, 1])
# 0-d (scalar) flag value that marks an axis as "to be stretched".
WORKING = tf.constant(1)

def create_mult_func(tf, amount, list):
    """Return a zero-argument branch function for tf.cond.

    The returned function computes ``tf.scalar_mul(amount, list)``.
    NOTE(review): the parameter name ``list`` shadows the builtin.
    """
    def f1():
        return tf.scalar_mul(amount, list)
    return f1

def create_no_op_func(tensor):
    """Return a zero-argument branch function for tf.cond that returns ``tensor`` unchanged."""
    def f1():
        return tensor
    return f1

def stretch(tf, points, dim, amount):
    """Multiply the x and/or y row of ``points`` by ``amount``.

    points is a 2 by ??? tensor (row 0 = x values, row 1 = y values),
    dim is a flag vector such as X = [1, 0] selecting which rows to stretch,
    and amount is a scalar factor.

    As written, building the first tf.cond raises
    ``ValueError: Shapes (1,) and () are not compatible`` (see the traceback).
    """
    x_list, y_list = tf.split(0, 2, points)
    # NOTE(review): dim is rank-1 at the call site (X = [1, 0]), so splitting
    # along axis 1 looks wrong; the answer below splits along axis 0.
    x_stretch, y_stretch = tf.split(1, 2, dim)
    is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x")
    is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y")
    # tf.cond requires a scalar (shape ()) predicate, but these comparisons
    # produce shape-(1,) bool tensors -- the "(1,) and ()" mismatch in the error.
    x_list_stretched = tf.cond(is_stretch_X,
                               create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
    y_list_stretched = tf.cond(is_stretch_Y,
                               create_mult_func(tf, amount, y_list), create_no_op_func(y_list))
    # NOTE(review): each row is 1 x N, so joining along axis 1 gives a 1 x 2N
    # tensor rather than the 2 x N layout of ``points``.
    return tf.concat(1, [x_list_stretched, y_list_stretched])

# NOTE(review): this example has 3 rows, while stretch() splits ``points``
# into two along axis 0 and documents a 2 by ??? input; the answer below
# uses a 2 x 2 array instead.
example_points = np.array([[1, 1], [2, 2], [3, 3]], dtype=np.float32)
# Unshaped placeholder, so graph construction presumably cannot detect the
# row-count mismatch above -- TODO confirm.
example_point_list = tf.placeholder(tf.float32)

# Builds the graph; per the traceback, the ValueError is raised inside this call.
result = stretch(tf, example_point_list, X, 1)
# NOTE(review): this session is never closed and is immediately shadowed by
# the ``with`` session below.
sess = tf.Session()

with tf.Session() as sess:
    result = sess.run(result, feed_dict={example_point_list: example_points})
    print(result)

Stack trace:

  File "/path/test2.py", line 36, in <module>
    result = stretch(tf, example_point_list, X, 1)
  File "/path/test2.py", line 28, in stretch
    create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
  File "/path/tensorflow/python/ops/control_flow_ops.py", line 1142, in cond
    p_2, p_1 = switch(pred, pred)
  File "/path/tensorflow/python/ops/control_flow_ops.py", line 203, in switch
    return gen_control_flow_ops._switch(data, pred, name=name)
  File "/path/tensorflow/python/ops/gen_control_flow_ops.py", line 297, in _switch
    return _op_def_lib.apply_op("Switch", data=data, pred=pred, name=name)
  File "/path/tensorflow/python/ops/op_def_library.py", line 655, in apply_op
    op_def=op_def)
  File "/path/tensorflow/python/framework/ops.py", line 2156, in create_op
    set_shapes_for_outputs(ret)
  File "/path/tensorflow/python/framework/ops.py", line 1612, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/path/tensorflow/python/ops/control_flow_ops.py", line 2032, in _SwitchShape
    unused_pred_shape = op.inputs[1].get_shape().merge_with(tensor_shape.scalar())
  File "/path/tensorflow/python/framework/tensor_shape.py", line 554, in merge_with
    (self, other))
ValueError: Shapes (1,) and () are not compatible

I have tried changing WORKING to be an array instead of a scalar.

I believe the problem is that `tf.equal` is returning an int32 instead of the bool that it is supposed to return according to the documentation.

like image 449
dtracers Avatar asked Jun 07 '26 01:06

dtracers


1 Answers

The problem is the first argument to `tf.cond`. From the documentation, regarding the type of the first argument to `tf.cond`:

pred: A scalar determining whether to return the result of fn1 or fn2.

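Note that it has to be a scalar. You are passing the result of comparing a shape-(1,) tensor with a scalar tensor, which gives a shape-(1,) boolean tensor, NOT a scalar. (So the dtype is fine: `tf.equal` does return `bool`. The error message "Shapes (1,) and () are not compatible" is about the shape.) You can convert it to a scalar using `tf.reshape` as follows: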

t = tf.equal(x_stretch, WORKING, name="is_stretch_x")
x_list_stretched = tf.cond(tf.reshape(t, []),
                           create_mult_func(tf, amount, x_list), create_no_op_func(x_list))

Complete working program (note that it also splits `dim` along axis 0, since `X` is a rank-1 tensor):

import tensorflow as tf
import numpy as np

# Flag vectors for stretch()'s ``dim`` argument: X stretches the x row,
# Y the y row, BOTH both rows.
X = tf.constant([1, 0])
Y = tf.constant([0, 1])
BOTH = tf.constant([1, 1])
# 0-d (scalar) flag value that marks an axis for stretching.
WORKING = tf.constant(1)

def create_mult_func(tf, amount, list):
    """Return a no-argument callable that computes ``tf.scalar_mul(amount, list)``.

    tf.cond takes its branches as parameterless callables, so the
    multiplication is deferred behind a lambda rather than evaluated here.
    """
    return lambda: tf.scalar_mul(amount, list)

def create_no_op_func(tensor):
    """Return a no-argument callable yielding ``tensor`` as-is (the identity branch for tf.cond)."""
    return lambda: tensor

def stretch(tf, points, dim, amount):
    """Multiply the x and/or y row of ``points`` by ``amount``.

    Args:
        tf: the TensorFlow module, passed in explicitly by the caller.
        points: a 2 by N tensor; row 0 holds the x values, row 1 the y values.
        dim: a length-2 flag vector such as ``X = [1, 0]``; a flag equal to
            ``WORKING`` selects that row for stretching.
        amount: the scalar factor the selected rows are multiplied by.

    Returns:
        A 2 by N tensor with the same row layout as ``points``.
    """
    x_list, y_list = tf.split(0, 2, points)      # two 1 x N rows
    x_stretch, y_stretch = tf.split(0, 2, dim)   # two shape-(1,) flags
    is_stretch_X = tf.equal(x_stretch, WORKING, name="is_stretch_x")
    is_stretch_Y = tf.equal(y_stretch, WORKING, name="is_stretch_Y")
    # tf.cond needs a scalar predicate, but the comparisons above are
    # shape (1,); reshape to a 0-d tensor before branching.
    x_list_stretched = tf.cond(tf.reshape(is_stretch_X, []),
                               create_mult_func(tf, amount, x_list), create_no_op_func(x_list))
    y_list_stretched = tf.cond(tf.reshape(is_stretch_Y, []),
                               create_mult_func(tf, amount, y_list), create_no_op_func(y_list))
    # Re-join the two 1 x N rows along axis 0 so the result is 2 x N like the
    # input; stacking them (tf.pack) would add a spurious axis (2 x 1 x N).
    return tf.concat(0, [x_list_stretched, y_list_stretched])

example_points = np.array([[1, 1], [2, 2]], dtype=np.float32)
example_point_list = tf.placeholder(tf.float32)

# Graph tensor for the stretched points (stretch x only, by a factor of 1).
stretched = stretch(tf, example_point_list, X, 1)

# A single session, used as a context manager so it is closed on exit.
with tf.Session() as sess:
    result = sess.run(stretched, feed_dict={example_point_list: example_points})
    print(result)
like image 64
keveman Avatar answered Jun 08 '26 15:06

keveman