Clarification on tf.Tensor.set_shape()

Tags: tensorflow

I have an image with 478 x 717 x 3 = 1028178 elements, stored as a tensor of rank 1. I verified this by calling tf.shape and tf.rank.

When I call image.set_shape([478, 717, 3]), it throws the following error:

    "Shapes %s and %s must have the same rank" % (self, other))
ValueError: Shapes (?,) and (478, 717, 3) must have the same rank

I tested again after first setting the shape to (1028178,), but the error persists:

ValueError: Shapes (1028178,) and (478, 717, 3) must have the same rank 

Well, that does make sense, because one shape has rank 1 and the other has rank 3. However, why is it necessary to throw an error when the total number of elements still matches?

I could of course use tf.reshape and it works, but I think that's not optimal.

As stated in the TensorFlow FAQ:

What is the difference between x.set_shape() and x = tf.reshape(x)?

The tf.Tensor.set_shape() method updates the static shape of a Tensor object, and it is typically used to provide additional shape information when this cannot be inferred directly. It does not change the dynamic shape of the tensor.

The tf.reshape() operation creates a new tensor with a different dynamic shape.

Creating a new tensor involves memory allocation and that could potentially be more costly when more training examples are involved. Is this by design, or am I missing something here?

Asked by jkschin on Feb 17 '16 at 08:02


1 Answer

As far as I know (and I wrote that code), there isn't a bug in Tensor.set_shape(). I think the misunderstanding stems from the confusing name of that method.

To elaborate on the FAQ entry you quoted, Tensor.set_shape() is a pure-Python function that improves the shape information for a given tf.Tensor object. By "improves", I mean "makes more specific".

For example, if you have a Tensor object t with shape (?,), it is a one-dimensional tensor of unknown length. You can call t.set_shape((1028178,)), and then t will have shape (1028178,) when you call t.get_shape(). This doesn't affect the underlying storage, or indeed anything on the backend: it merely means that subsequent shape inference using t can rely on the assertion that it is a vector of length 1028178.
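
To make the "makes more specific" idea concrete, here is a minimal pure-Python sketch of how a static shape could be refined. This is a toy model, not TensorFlow's implementation: merge_static_shape is a hypothetical helper, None stands in for an unknown dimension ("?"), and the case of a tensor whose rank is itself unknown is left out.

```python
def merge_static_shape(known, requested):
    """Toy model of static-shape refinement (hypothetical, not TensorFlow API).

    Both arguments are tuples; None stands for an unknown dimension ('?').
    """
    if len(known) != len(requested):
        raise ValueError(
            f"Shapes {known} and {requested} must have the same rank")
    merged = []
    for k, r in zip(known, requested):
        # Two known dimensions must agree; an unknown one takes the other's value.
        if k is not None and r is not None and k != r:
            raise ValueError(f"Dimensions {k} and {r} are not compatible")
        merged.append(k if k is not None else r)
    return tuple(merged)


# Refining an unknown-length vector to a known length succeeds.
print(merge_static_shape((None,), (1028178,)))

# Asking for a rank-3 shape on a rank-1 tensor is rejected, even though
# the element counts match: refinement never changes the rank.
try:
    merge_static_shape((None,), (478, 717, 3))
except ValueError as err:
    print(err)
```

The point of the sketch is that refinement only fills in unknowns. It never rearranges elements, which is why a matching element count is not enough.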

If t has shape (?,), a call to t.set_shape((478, 717, 3)) will fail, because TensorFlow already knows that t is a vector, so it cannot have shape (478, 717, 3). If you want to make a new Tensor with that shape from the contents of t, you can use reshaped_t = tf.reshape(t, (478, 717, 3)). This creates a new tf.Tensor object in Python; the actual implementation of tf.reshape() does this using a shallow copy of the tensor buffer, so it is inexpensive in practice.
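
As a rough analogy for why a reshape that shares the buffer is cheap, the sketch below uses only Python's built-in memoryview to view one flat buffer as a 3-D array without copying its elements. It illustrates the idea of a shallow copy only; it does not show how tf.reshape() is implemented, and the names flat and view are chosen here for illustration.

```python
shape = (478, 717, 3)
flat = bytearray(478 * 717 * 3)  # one flat buffer of 1028178 bytes

# Reinterpret the same buffer as a 3-D array; no element data is copied.
view = memoryview(flat).cast("B", shape=shape)

# A write through the flat buffer is visible through the 3-D view,
# because both names refer to the same underlying storage.
i, j, k = 1, 2, 0
flat[(i * 717 + j) * 3 + k] = 255
print(view.shape, view[i, j, k])
```

Only the shape metadata is new here, so the cost does not grow with the number of elements.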

One analogy is that Tensor.set_shape() is like a run-time cast in an object-oriented language like Java. For example, if you have a pointer to an Object but know that, in fact, it is a String, you might do the cast (String) obj in order to pass obj to a method that expects a String argument. However, if you have a String s and try to cast it to a java.util.Vector, the compiler will give you an error, because these two types are unrelated.

Answered by mrry on Sep 22 '22 at 06:09