
TensorFlow tuples with different shapes

I have a problem returning a tuple of two variables, v and wt, where v has shape=(20,20) and wt has shape=(1,). wt is a weight value. I want to return the tuple (v, wt) from inside a map_fn.

My code looks roughly like this:

tf.map_fn(fn, nonzeros(Matrix, dim, row))

nonzeros(Matrix, dim, row) returns an (index, value) tuple.

fn returns a tuple, but the error I get is:

ValueError: The two structures don't have the same number of elements. First 
structure: <dtype: 'int64'>, second structure: (<tf.Tensor 
'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>, <tf.Tensor 
'map_2/while/Sub:0' shape=() dtype=int64>).
Marcus Lagerstedt asked Mar 28 '17 09:03



1 Answer

You are returning the result of a tf.while_loop here. tf.while_loop returns one value per loop variable; in your case, we can see that your while loop returned a value of interest and a counter value as a tuple:

(<tf.Tensor 'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>, <tf.Tensor 'map_2/while/Sub:0' shape=() dtype=int64>)

What you mean to pass back from map_fn is probably just the first of these two values. So in the code that you haven't shown here you should have something like:

value, counter = tf.while_loop(...)
return value

What you have is:

return tf.while_loop(...)

So the error is complaining that a single <dtype: 'int64'> doesn't match the two-element tuple you're returning. Once you fix the while loop, you'll be comparing <dtype: 'int64'> to <tf.Tensor 'map_2/while/while/Exit_1:0' shape=(20,) dtype=float32>, which are both single elements, so the structures will match (though you might still end up with an int/float issue, since the expected output is int64 and the tensor is float32).
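
The fix described above can be sketched as follows. This is only an illustration under assumptions: run_loop is a hypothetical stand-in for the loop the question does not show, and elems is made-up input. The dtype argument of tf.map_fn is what declares the output structure when it differs from the input's (newer TensorFlow releases prefer the name fn_output_signature for the same purpose). The second call shows one way to return both values as a tuple, which is what the question originally asked for.

```python
import tensorflow as tf

def run_loop(n):
    # Hypothetical stand-in for the loop not shown in the question:
    # adds 1.0 to a (20,) float32 vector while the int64 counter is positive.
    cond = lambda value, counter: counter > 0
    body = lambda value, counter: (value + 1.0, counter - 1)
    # tf.while_loop returns one result per loop variable, so unpack both.
    value, counter = tf.while_loop(cond, body, [tf.zeros([20]), n])
    return value, counter

# Made-up int64 input, matching the <dtype: 'int64'> in the error message.
elems = tf.constant([1, 2, 3], dtype=tf.int64)

# Option 1: return only the value of interest. The output is float32 while
# the input is int64, so the output dtype is stated explicitly.
values = tf.map_fn(lambda n: run_loop(n)[0], elems, dtype=tf.float32)

# Option 2: return both results. dtype must then be a tuple with the same
# structure as what fn returns.
values2, counters = tf.map_fn(run_loop, elems, dtype=(tf.float32, tf.int64))
```

Each output is stacked along a new leading dimension, so the per-element (20,) vectors become one (3, 20) tensor and the scalar counters become one (3,) tensor.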

David Parks answered Sep 25 '22 15:09