How to cast int32 tensor to float32

How can I cast an int32 tensor to float32 in TensorFlow? I don't understand what tf.cast does; it does not seem to do anything.

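import tensorflow as tf

# TensorFlow 1.x only: turn on eager execution so tensors print their values.
# In TensorFlow 2.x eager execution is the default and this call is not needed.
tf.enable_eager_execution()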

a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = tf.cast(a, dtype=tf.float32)

print(tf.shape(a))
print(tf.shape(b))

Output:

tf.Tensor([2 4], shape=(2,), dtype=int32)  # a
tf.Tensor([2 4], shape=(2,), dtype=int32)  # b
CrazyBrazilian asked Jul 29 '19 00:07



1 Answer

If you print the tensors themselves instead of their shapes:

print(a)
print(b)

you get the expected result:

tf.Tensor(
[[1 2 3 4]
 [1 2 3 4]], shape=(2, 4), dtype=int32) #a
tf.Tensor(
[[1. 2. 3. 4.]
 [1. 2. 3. 4.]], shape=(2, 4), dtype=float32) #b

So tf.cast() works as intended!
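
If you only want to confirm the element type, a minimal sketch (assuming TensorFlow 2.x, where eager execution is already on) is to read the dtype attribute of each tensor rather than printing the whole tensor:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x, so no enable_eager_execution() call is needed.
a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = tf.cast(a, dtype=tf.float32)

# The dtype attribute reports the element type of the tensor itself.
print(a.dtype)  # <dtype: 'int32'>
print(b.dtype)  # <dtype: 'float32'>
```

The cast returns a new tensor; a itself stays int32.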


tf.shape() returns a new tensor that describes the shape of its input, so what you printed was that shape tensor, not a or b. From the tf.shape documentation:

Returns: A Tensor of type out_type.

out_type: (Optional) The specified output type of the operation (int32 or int64). Defaults to tf.int32.

So the dtype shown in the tf.shape() output is the dtype of the resulting shape tensor, not of a or b.
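
To see that the dtype of the shape tensor is independent of the input's dtype, here is a small sketch (again assuming TensorFlow 2.x; the variable names shape_default and shape_int64 are only for illustration) that changes out_type while the input stays float32:

```python
import tensorflow as tf

# Assumption: TensorFlow 2.x with eager execution enabled by default.
a = tf.constant([[1, 2, 3, 4], [1, 2, 3, 4]])
b = tf.cast(a, dtype=tf.float32)

# tf.shape returns a 1-D tensor of dimension sizes; out_type controls its dtype.
shape_default = tf.shape(b)
shape_int64 = tf.shape(b, out_type=tf.int64)

print(shape_default.dtype)  # <dtype: 'int32'>
print(shape_int64.dtype)    # <dtype: 'int64'>
print(b.dtype)              # <dtype: 'float32'>
```

In both cases the shape tensor holds the values [2, 4]; only its integer type differs, and b remains float32 throughout.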

buræquete answered Sep 27 '22 21:09