tf.newaxis operation in TensorFlow

x_train = x_train[..., tf.newaxis].astype("float32")

x_test = x_test[..., tf.newaxis].astype("float32")

Can someone please explain how tf.newaxis works?

I found a brief mention in the documentation

https://www.tensorflow.org/api_docs/python/tf/strided_slice

but I could not fully understand it.

asked Oct 12 '20 05:10 by Celi



2 Answers

Check this example:

import tensorflow as tf

a = tf.constant([100])
print(a.shape)  # (1,)
expanded_1 = tf.expand_dims(a, axis=1)
print(expanded_1.shape)  # (1, 1)
expanded_2 = a[:, tf.newaxis]
print(expanded_2.shape)  # (1, 1)

It is similar to tf.expand_dims(): indexing with tf.newaxis inserts a new axis of length 1 at that position in the index.
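
For reference, tf.newaxis is simply an alias for Python's None, so inside an index it behaves exactly like None (or NumPy's np.newaxis). The sketch below redefines the tensor a from the example above and checks that the three spellings agree; it assumes TensorFlow 2.x with eager execution.

```python
import numpy as np
import tensorflow as tf

# tf.newaxis is an alias for Python's None, just like NumPy's np.newaxis.
print(tf.newaxis is None, np.newaxis is None)  # True True

a = tf.constant([100])  # shape (1,)

via_newaxis = a[:, tf.newaxis]          # shape (1, 1)
via_none = a[:, None]                   # shape (1, 1), identical indexing
via_expand = tf.expand_dims(a, axis=1)  # shape (1, 1)

# All three give the same shape and the same values: only a length-1 axis
# is inserted, the data itself is unchanged.
print(via_newaxis.shape == via_expand.shape)           # True
print(bool(tf.reduce_all(via_newaxis == via_none)))    # True
print(bool(tf.reduce_all(via_newaxis == via_expand)))  # True
```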

If you want to add a new axis at the beginning of the tensor, use

expanded_2 = a[tf.newaxis, :]

otherwise, to add it at the end of this 1-D tensor:

expanded_2 = a[:,tf.newaxis]
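
Note that ":" covers only one axis, so a[:, tf.newaxis] places the new axis at the end only when the tensor is 1-D. A small sketch with a hypothetical 2x3 tensor t (assuming TensorFlow 2.x eager mode) shows where the new axis lands in each case:

```python
import tensorflow as tf

t = tf.zeros([2, 3])  # hypothetical rank-2 tensor, shape (2, 3)

# New axis at the beginning; the unmentioned trailing axis is kept as-is.
print(t[tf.newaxis, :].shape)    # (1, 2, 3)

# ":" only covers the first axis, so for a rank-2 tensor this inserts the
# new axis at position 1, not at the end.
print(t[:, tf.newaxis].shape)    # (2, 1, 3)

# "..." covers all remaining axes, so this appends the new axis at the end
# regardless of the tensor's rank.
print(t[..., tf.newaxis].shape)  # (2, 3, 1)
```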
answered Oct 26 '22 22:10 by Timbus Calin


You can also use tf.newaxis to add dimensions to a tensor while keeping the same data: the values are unchanged, only the shape gains an axis of length 1.

import tensorflow as tf

# Create a rank 2 tensor (2 dimensions)
rank_2_tensor = tf.constant([[10, 7],
                             [3, 4]])

print("dimension: ", rank_2_tensor.ndim)
print("shape    : ", rank_2_tensor.shape)

output:

dimension:  2
shape    :  (2, 2)

# Add an extra dimension (to the end)
rank_3_tensor = rank_2_tensor[..., tf.newaxis]
# in Python, "..." (Ellipsis) stands for "all the remaining dimensions",
# so the new axis is appended after the existing ones

print("dimension: ", rank_3_tensor.ndim)
print("shape    : ", rank_3_tensor.shape)

output:

dimension:  3
shape    :  (2, 2, 1)

You can achieve the same using tf.expand_dims().

rank_new_3_tensor = tf.expand_dims(rank_2_tensor, axis=-1)  # axis=-1: the new axis becomes the last one
print("dimension: ", rank_new_3_tensor.ndim)
print("shape    : ", rank_new_3_tensor.shape)

output:

dimension:  3
shape    :  (2, 2, 1)
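
To make the "..." (Ellipsis) comment concrete and tie it back to the question: for a rank-2 tensor, [..., tf.newaxis] is the same as [:, :, tf.newaxis]. Because .astype("float32") is a NumPy method, the x_train in the question is presumably a NumPy array, and NumPy accepts tf.newaxis because it is just None. The (5, 28, 28) shape below is a hypothetical stand-in for a small batch of grayscale images; the sketch assumes TensorFlow 2.x.

```python
import numpy as np
import tensorflow as tf

rank_2_tensor = tf.constant([[10, 7],
                             [3, 4]])

# "..." expands to as many ":" as needed for the axes not otherwise named,
# so for a rank-2 tensor these two indexings are equivalent.
via_ellipsis = rank_2_tensor[..., tf.newaxis]
via_colons = rank_2_tensor[:, :, tf.newaxis]
print(via_ellipsis.shape, via_colons.shape)  # (2, 2, 1) (2, 2, 1)

# Hypothetical stand-in for the question's x_train: 5 grayscale 28x28
# images in a NumPy uint8 array, with no channel axis yet.
x_train = np.zeros((5, 28, 28), dtype=np.uint8)

# tf.newaxis is None, so plain NumPy indexing understands it; "..." keeps
# the three existing axes and the new length-1 axis is appended last.
x_train = x_train[..., tf.newaxis].astype("float32")
print(x_train.shape, x_train.dtype)  # (5, 28, 28, 1) float32
```

A common reason for this reshape is that image layers such as tf.keras.layers.Conv2D expect an explicit channel axis; whether that was the intent in the question is an assumption.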

answered Oct 26 '22 22:10 by Furkan Gulsen