 

How to use double as the default type for floating numbers in PyTorch

I want all floating-point numbers in my PyTorch code to be of type double (float64) by default. How can I do that?

Asked by Alaya, Jan 02 '23 21:01


1 Answer

You are looking for torch.set_default_tensor_type:

torch.set_default_tensor_type(torch.DoubleTensor)
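A minimal sketch of what this call changes, assuming a PyTorch 2.x install. In PyTorch 2.1 and later this function is deprecated and emits a warning, which the sketch silences only for the demonstration. After the call, tensors built from Python floats come out as float64, and torch.get_default_dtype() reports float64 as well.

```python
import warnings

import torch

# set_default_tensor_type is deprecated as of PyTorch 2.1 and warns there;
# the warning is suppressed here only to keep the demo output clean.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    torch.set_default_tensor_type(torch.DoubleTensor)

x = torch.tensor([1.0, 2.0])      # float literals now become float64
print(x.dtype)                    # torch.float64
print(torch.get_default_dtype())  # torch.float64
```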

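Alternatively, you can use torch.set_default_dtype, which is the preferred option in recent releases (set_default_tensor_type is deprecated as of PyTorch 2.1):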

torch.set_default_dtype(torch.float64)
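A short sketch, assuming a PyTorch 2.x install, of which newly created objects pick up the new default. Floating-point literals, factory functions such as torch.zeros, and the parameters of newly constructed modules like nn.Linear all become float64. Integer tensors keep their int64 type.

```python
import torch
import torch.nn as nn

torch.set_default_dtype(torch.float64)

a = torch.tensor([1.5, 2.5])  # float literals -> float64
b = torch.zeros(3)            # factory functions -> float64
c = torch.tensor([1, 2])      # integer literals are unaffected -> int64
layer = nn.Linear(4, 2)       # parameters of newly built modules -> float64

print(a.dtype, b.dtype, c.dtype, layer.weight.dtype)
# torch.float64 torch.float64 torch.int64 torch.float64
```

The default applies only to tensors created after the call. Tensors and models that already exist keep their dtype and still need an explicit .double(). Tensors built from NumPy arrays also keep the array's dtype, so a float32 array stays float32.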
Answered by Shai, Jan 06 '23 09:01