Suppose I have a 3D tensor X
X = torch.arange(24).view(4, 3, 2)
print(X)
and need to mask it using a 2D tensor
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.uint8
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
Calling PyTorch's masked_select with this mask raises the following error.
torch.masked_select(X, (mask == 1))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-72-fd6809d2c4cc> in <module>
12
13 # Select based on new mask
---> 14 Y = torch.masked_select(X, (mask == 1))
15 #Y = X * mask_
16 print(Y)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
How can I mask a 3D tensor with a 2D mask and keep the dimensions of the original tensor? Any hints will be appreciated.
Essentially, the mask must have the same number of dimensions as the tensor being masked, so that the two shapes line up (or can be broadcast together).
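To make the shape mismatch concrete, here is a minimal sketch of my own (not part of the original question). It assumes a bool mask rather than the int64 one above, and relies on masked_select accepting any mask that is broadcastable to the input. Broadcasting aligns shapes from the last dimension backwards, so a (4, 3) mask gets compared against the trailing (3, 2) of X and fails, while a (4, 3, 1) mask lines up.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

# (4, 3) vs (4, 3, 2): trailing dimensions 3 and 2 do not match, so this raises.
try:
    torch.masked_select(X, mask)
    failed = False
except RuntimeError:
    failed = True

# (4, 3, 1) vs (4, 3, 2): the trailing singleton dimension broadcasts over the last axis.
selected = torch.masked_select(X, mask.unsqueeze(-1))

print(failed)    # True
print(selected)  # tensor([ 0,  1,  8,  9, 18, 19])
```

The explicit expand used in the approaches below is therefore optional as far as I can tell; it mainly makes the full-size mask visible when printed.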
There are two ways to do it.
Approach 1: Does not preserve the original tensor dimensions (the result is flattened to 1D).
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.uint8
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Select based on the new expanded mask
Y = torch.masked_select(X, (mask_ == 1)) # does not preserve the dims
print(Y)
The output for approach 1:
tensor([ 0, 1, 8, 9, 18, 19])
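If the goal is to keep the last dimension rather than flatten everything, boolean indexing may be a useful middle ground. This is a sketch of an alternative and not what the approach above does: indexing X with a 2D boolean mask selects whole rows along the last axis, giving one row of length 2 per True entry.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.int64)
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1

# Convert to bool first: an int64 tensor used as an index would be treated
# as positions to gather, not as a mask.
rows = X[mask.bool()]

print(rows.shape)  # torch.Size([3, 2])
print(rows)        # rows [0, 1], [8, 9] and [18, 19]
```

The first two dimensions are still collapsed into one, so the position of each selected row in the original (4, 3) grid is lost.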
Approach 2: Preserves the original tensor dimensions (by zeroing out the unselected entries).
X = torch.arange(24).view(4, 3, 2)
print(X)
mask = torch.zeros((4, 3), dtype=torch.int64) # or dtype=torch.uint8
mask[0, 0] = 1
mask[1, 1] = 1
mask[3, 0] = 1
print('Mask: ', mask)
# Add a dimension to the mask tensor and expand it to the size of original tensor
mask_ = mask.unsqueeze(-1).expand(X.size())
print(mask_)
# Zero out the unselected entries by multiplying with the expanded mask
Y = X * mask_
print(Y)
The output for approach 2:
tensor([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
Mask: tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 0],
[1, 0, 0]])
tensor([[[1, 1],
[0, 0],
[0, 0]],
[[0, 0],
[1, 1],
[0, 0]],
[[0, 0],
[0, 0],
[0, 0]],
[[1, 1],
[0, 0],
[0, 0]]])
tensor([[[ 0, 1],
[ 0, 0],
[ 0, 0]],
[[ 0, 0],
[ 8, 9],
[ 0, 0]],
[[ 0, 0],
[ 0, 0],
[ 0, 0]],
[[18, 19],
[ 0, 0],
[ 0, 0]]])
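The same shape-preserving result can be sketched without the explicit expand, since elementwise operations broadcast a (4, 3, 1) mask over the last axis. The snippet below is my own illustration, assuming a bool mask; the names keep, Y_mul, Y_where and Y_fill are made up for the example. It also shows masked_fill with a sentinel value of -1, which is an arbitrary choice on my part.

```python
import torch

X = torch.arange(24).view(4, 3, 2)
mask = torch.zeros((4, 3), dtype=torch.bool)
mask[0, 0] = True
mask[1, 1] = True
mask[3, 0] = True

keep = mask.unsqueeze(-1)  # shape (4, 3, 1), broadcasts against (4, 3, 2)

# Multiplication, as in approach 2, but relying on broadcasting.
Y_mul = X * keep

# torch.where picks from X where keep is True and from zeros elsewhere.
Y_where = torch.where(keep, X, torch.zeros_like(X))

# masked_fill overwrites the positions where its mask is True, so invert keep.
Y_fill = X.masked_fill(~keep, -1)

print(torch.equal(Y_mul, Y_where))  # True
print(Y_fill[0])
```

A sentinel can help when 0 is also a legitimate value in X (here X[0, 0, 0] is 0, so it is indistinguishable from a masked entry after multiplication). For float tensors containing NaN or inf, multiplying by 0 gives NaN, so torch.where or masked_fill is probably the safer choice there.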