 

PyTorch [1 if x > 0.5 else 0 for x in outputs ] with tensors

I have a list of outputs from a sigmoid function, stored as a tensor in PyTorch.

E.g.:

output (type) = torch.Size([4]) tensor([0.4481, 0.4014, 0.5820, 0.2877], device='cuda:0',

As I'm doing binary classification, I want to turn all values below 0.5 into 0 and all values above 0.5 into 1.

Traditionally, with a NumPy array, you can use a list comprehension:

output_prediction = [1 if x > 0.5 else 0 for x in outputs]

This works, but I then have to convert output_prediction back to a tensor in order to use

torch.sum(output_prediction == labels.data)

Where labels.data is a binary tensor of labels.
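A minimal CPU sketch of the round trip described above, assuming a hypothetical `labels` tensor (the question does not show its values). The list comprehension produces a plain Python list, which has to be wrapped in `torch.tensor` again before it can be compared with the labels.

```python
import torch

# Sigmoid outputs from the question, kept on the CPU for this sketch
outputs = torch.tensor([0.4481, 0.4014, 0.5820, 0.2877])
# Hypothetical binary labels, for illustration only
labels = torch.tensor([0, 1, 1, 0])

# Iterating over a 1-D tensor yields 0-d tensors; the comprehension builds a Python list
output_prediction = [1 if x > 0.5 else 0 for x in outputs]

# Convert back to a tensor on the same device as the outputs before comparing
output_prediction = torch.tensor(output_prediction, device=outputs.device)
correct = torch.sum(output_prediction == labels)
print(correct.item())  # 3
```

Every element goes through a Python-level `if`, so this gets slow for large tensors; on a CUDA tensor, each `if x > 0.5` also has to copy a value back to the host.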

Is there a way to use list comprehensions with tensors?

asked Sep 19 '19 at 02:09 by Brian Formento


People also ask

How do you compare two tensors in PyTorch?

To compare two tensors element-wise in PyTorch, use the torch.eq() method. It compares the corresponding elements and returns True where the two elements are equal, and False otherwise.
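A short sketch of `torch.eq` on two small example tensors; the `==` operator gives the same element-wise result, and summing the boolean result counts the matches.

```python
import torch

a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 0, 3, 0])

# Element-wise equality: a bool tensor with the same shape as the inputs
mask = torch.eq(a, b)
print(mask.tolist())  # [True, False, True, False]

# The == operator is equivalent
assert torch.equal(mask, a == b)

# Summing a bool tensor counts the True entries
print(mask.sum().item())  # 2
```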

How do you access and modify the values of a tensor in PyTorch?

Define a PyTorch tensor. Access the value of a single element at a particular index using indexing, or access a sequence of elements using slicing. Modify the accessed values using the assignment operator. Finally, print the tensor to check that it has been modified.
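The steps above, sketched on a small example tensor (the values are arbitrary):

```python
import torch

# Define a tensor
t = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Indexing a single element returns a 0-d tensor
print(t[0].item())      # 1.0
# Slicing returns the elements in the half-open range [1, 3)
print(t[1:3].tolist())  # [2.0, 3.0]

# Assign a new value to one element, and new values to a slice
t[0] = 10.0
t[1:3] = torch.tensor([20.0, 30.0])

# Print to confirm the tensor was modified in place
print(t.tolist())  # [10.0, 20.0, 30.0, 4.0]
```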

What does .item() do in PyTorch?

.item() converts a single-element tensor into a Python number. This ensures that you append only the numeric values to a list, rather than the tensor itself. For a tensor on the GPU, however, .item() copies the value back to the host, which forces a synchronization and can slow down a loop that calls it at every step.
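A minimal sketch of collecting per-step values with `.item()`, assuming hypothetical loss values. `.item()` only works on tensors with exactly one element, and the Python type it returns follows the tensor's dtype.

```python
import torch

losses = []
for step in range(3):
    # Stand-in for a loss computed during training (hypothetical values)
    loss = torch.tensor(0.5 * step)
    # .item() returns a Python number, so the list holds floats, not tensors
    losses.append(loss.item())

print(losses)  # [0.0, 0.5, 1.0]

# The Python type follows the tensor's dtype
print(type(torch.tensor(3).item()))    # <class 'int'>
print(type(torch.tensor(2.5).item()))  # <class 'float'>
```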

What does torch.tensor() mean?

torch.tensor() constructs a tensor from data. A torch.Tensor is a multi-dimensional matrix containing elements of a single data type.
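A short sketch of building tensors with `torch.tensor()` from nested Python lists; the dtype is inferred, and because a tensor holds a single data type, mixing ints and floats promotes everything to the default float dtype.

```python
import torch

# torch.tensor() builds a tensor from (nested) Python data and infers the dtype
m = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(m.shape)  # torch.Size([2, 3])
print(m.dtype)  # torch.int64

# All elements share one dtype; mixing ints and floats gives a float tensor
f = torch.tensor([1, 2.5])
print(f.dtype)  # torch.float32

print(isinstance(m, torch.Tensor))  # True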




1 Answer

import torch

prob = torch.tensor([0.3, 0.4, 0.6, 0.7])

out = (prob > 0.5).float()
# tensor([0., 0., 1., 1.])

Explanation: In PyTorch, prob > 0.5 directly returns a tensor of dtype torch.bool, with no Python loop. You can then convert it to float with .float().
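Applied to the question's setup, a minimal sketch: `labels` is hypothetical, and `.long()` is used instead of `.float()` so the predictions share the integer dtype of the labels (comparing a float tensor with an integer one also works, through type promotion). The same code runs on `'cuda:0'` as long as both tensors live on the same device.

```python
import torch

outputs = torch.tensor([0.4481, 0.4014, 0.5820, 0.2877])
labels = torch.tensor([0, 1, 1, 0])  # hypothetical labels

# Vectorized threshold: bool mask -> integer 0/1 predictions, no Python loop
output_prediction = (outputs > 0.5).long()

# Count correct predictions; .item() turns the 0-d result into a Python int
correct = torch.sum(output_prediction == labels).item()
accuracy = correct / labels.numel()
print(output_prediction.tolist(), correct, accuracy)  # [0, 0, 1, 0] 3 0.75
```

Note that `> 0.5` maps a value of exactly 0.5 to 0; use `>= 0.5` if it should count as positive.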

answered Oct 27 '22 at 09:10 by zihaozhihao