
Masking tensor of same shape in PyTorch

Tags:

python

pytorch

Given a tensor and a boolean mask of the same shape, I want a masked output that keeps the same shape and contains 0 wherever the mask is False.

For example,

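import torch

# input array (random values; the printed output below is one example)
img = torch.randn(2, 2)
print(img)
# tensor([[0.4684, 0.8316],
#        [0.8635, 0.4228]])
print(img.shape)
# torch.Size([2, 2])

# mask (torch.BoolTensor(2, 2) would leave the values uninitialized,
# so the example mask is built explicitly)
mask = torch.tensor([[False, True],
                     [True, True]])
print(mask)
# tensor([[False,  True],
#        [ True,  True]])
print(mask.shape)
# torch.Size([2, 2])

# expected masked output of shape 2x2
# tensor([[0.0000, 0.8316],
#        [0.8635, 0.4228]])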

Issue: indexing with the mask changes the shape of the output. It returns a 1-D tensor containing only the elements where the mask is True:

# shape changed to 1-D
img[mask]
# tensor([0.8316, 0.8635, 0.4228])
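A minimal sketch of why the shape changes, assuming the same 2x2 example values (fixed here instead of `torch.randn` so the output is reproducible): boolean indexing gathers the True positions in row-major order into a new 1-D tensor, so its length equals the number of True entries in the mask.

```python
import torch

img = torch.tensor([[0.4684, 0.8316],
                    [0.8635, 0.4228]])
mask = torch.tensor([[False, True],
                     [True, True]])

# Boolean indexing keeps only the True positions and flattens them into 1-D.
selected = img[mask]
print(selected)        # tensor([0.8316, 0.8635, 0.4228])
print(selected.shape)  # torch.Size([3])

# The length of the result equals the number of True entries in the mask.
assert selected.shape == (int(mask.sum()),)
```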
Asked by kHarshit, Oct 23 '19 at 11:10





2 Answers

Type-cast the boolean mask to an integer mask and then to float so that it has the same type as img, then multiply element-wise. Positions where the mask is False become 0, and the output keeps the shape of img.

masked_output = img * mask.int().float()
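A runnable sketch of the multiplication approach, assuming img is a floating-point tensor. Casting with `mask.to(img.dtype)` is shown as an alternative to `mask.int().float()` that also matches float64 or float16 inputs. One caveat: if img contains `inf` or `NaN`, multiplying by 0 gives `NaN`, so those masked positions are not zeroed.

```python
import torch

img = torch.tensor([[0.4684, 0.8316],
                    [0.8635, 0.4228]])
mask = torch.tensor([[False, True],
                     [True, True]])

# True -> 1.0 keeps the value, False -> 0.0 zeroes it; shape stays 2x2.
masked_output = img * mask.to(img.dtype)
print(masked_output.shape)  # torch.Size([2, 2])

# Caveat: 0 * inf is nan, so non-finite values are not zeroed by multiplication.
bad = torch.tensor([[float("inf"), 1.0],
                    [2.0, 3.0]])
print(bad * mask.to(bad.dtype))  # top-left entry is nan, not 0
```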

Answered by Anant Mittal, Oct 12 '22 at 23:10



One way I found to solve it:

img[mask==False] = 0

or using

img[~mask] = 0

Note that this modifies img in place.
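If the original tensor must be kept, the same idea can be applied out of place with `Tensor.masked_fill` or `torch.where`, both of which return a new tensor of the same shape. A minimal sketch, reusing the example values from the question (fixed instead of random so the result is reproducible):

```python
import torch

img = torch.tensor([[0.4684, 0.8316],
                    [0.8635, 0.4228]])
mask = torch.tensor([[False, True],
                     [True, True]])

# masked_fill writes the value where its mask is True, so invert the mask with ~
out_fill = img.masked_fill(~mask, 0)

# torch.where takes img where mask is True, otherwise the zero tensor
out_where = torch.where(mask, img, torch.zeros_like(img))

print(out_fill)
print(torch.equal(out_fill, out_where))  # True
print(img[0, 0])  # img itself is left unchanged
```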

Answered by kHarshit, Oct 13 '22 at 01:10
