
tf.cast equivalent in PyTorch?

I am new to PyTorch. TensorFlow has the APIs tf.cast() and tf.shape(). tf.cast has a specific purpose in TensorFlow; is there anything equivalent in torch? I have a tensor x with shape (128, 64, 32, 32). tf.shape(x) creates a 1-D tensor holding the dimensions, whereas x.shape gives the dimensions themselves. I need the equivalent of tf.shape(x) in torch.

tf.cast has a different role than just changing a tensor's dtype in torch.

Does anyone know of an equivalent API in torch/PyTorch?

SarwarKhan asked Dec 02 '19 04:12



1 Answer

Check out the PyTorch Documentation

As they mentioned:

import torch

x = torch.arange(4)  # Example tensor; integer tensors default to int64
print(x.dtype) # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype) # Prints "torch.float32", now 32-bit float
print(x.float().dtype) # Still "torch.float32"
print(x.type(torch.DoubleTensor)) # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor)) # Cast back to int-64, prints "tensor([0, 1, 2, 3])"
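
The question also asks about tf.shape(x), which the snippet above does not cover. Here is a minimal sketch, assuming all you need is a 1-D integer tensor holding the sizes of x. The variable names are illustrative, and the use of Tensor.to with a dtype as a rough counterpart of tf.cast is my reading of the question, not something the documentation quoted above states.

```python
import torch

# Tensor with the shape given in the question.
x = torch.zeros(128, 64, 32, 32)

# Rough analogue of tf.cast(x, tf.int32): Tensor.to accepts a target dtype.
x_int = x.to(torch.int32)

# x.shape is a torch.Size (a tuple subclass), not a tensor.
static_shape = x.shape

# Rough analogue of tf.shape(x): wrap the sizes in a 1-D integer tensor.
shape_tensor = torch.tensor(x.shape)

print(x_int.dtype)
print(static_shape)
print(shape_tensor)
```

If the shape only feeds calls such as reshape or view, x.shape can usually be passed directly; wrapping it in a tensor is only needed when the sizes must take part in tensor arithmetic.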
furkanayd answered Oct 17 '22 16:10