Filter out np.nan values from pytorch 1d tensor

Tags: python, pytorch

I have a 1D tensor that looks like this:

import numpy as np
import torch

my_list = [0, 1, 2, np.nan, np.nan, 4]
tensor = torch.Tensor(my_list)

How do I filter out the NaN values so that it becomes a tensor of size 4?

Mathias Byskov asked Mar 05 '26 10:03


1 Answer

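You can use torch.isnan to build a boolean mask that is True at every NaN, then invert it with ~ and index the tensor with it, which keeps only the non-NaN values: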

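import numpy as np
import torch

my_list = [0, 1, 2, np.nan, np.nan, 4]
tensor = torch.Tensor(my_list)

# torch.isnan is True at NaN entries; ~ inverts the mask so only non-NaN values are kept
tensor[~torch.isnan(tensor)]
# tensor([0., 1., 2., 4.])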
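The same boolean-mask idea can be wrapped in a small reusable helper. This is a minimal sketch, assuming a 1D floating-point tensor; the function name drop_nan and its drop_inf flag are hypothetical, not part of PyTorch. When drop_inf is True it uses torch.isfinite instead of ~torch.isnan, because torch.isfinite is False for NaN and also for +inf and -inf. It builds the input with the torch.tensor factory function rather than the legacy torch.Tensor constructor.

```python
import math

import torch


def drop_nan(t: torch.Tensor, drop_inf: bool = False) -> torch.Tensor:
    """Return a new 1D tensor with NaN entries removed (hypothetical helper).

    If drop_inf is True, +inf and -inf are removed as well, since
    torch.isfinite is False for NaN and for both infinities.
    """
    if drop_inf:
        mask = torch.isfinite(t)
    else:
        mask = ~torch.isnan(t)
    # Boolean-mask indexing returns a new tensor holding only the kept values
    return t[mask]


values = torch.tensor([0.0, 1.0, 2.0, math.nan, math.nan, 4.0, math.inf])
print(drop_nan(values))                 # NaNs removed, inf kept
print(drop_nan(values, drop_inf=True))  # NaNs and inf removed
```

Integer tensors cannot hold NaN in the first place, so this filtering is only meaningful for floating-point tensors.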
Dishin H Goyani answered Mar 07 '26 22:03