
How can I change part of a PyTorch tensor based on the values of another tensor?

Tags: python, pytorch

This question may not be clear, so please ask for clarification in the comments and I will expand.

I have tensors with the following shapes:

```python
mask.size() == torch.Size([1, 400])
clean_input_spectrogram.size() == torch.Size([1, 400, 161])
output.size() == torch.Size([1, 400, 161])
```

mask contains only 0s and 1s. Since it is a mask, I want to set the elements of output equal to the corresponding elements of clean_input_spectrogram wherever the relevant mask value (along the 400-element dimension) is 1.
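A minimal reproduction of this setup may make the shape relationship explicit. The values below are random placeholders (only the shapes and the 0/1 mask come from the description above). The mask covers only the first two dimensions, so each of its 400 entries corresponds to a whole row of 161 values.

```python
import torch

torch.manual_seed(0)

# Placeholder data; only the shapes come from the question.
mask = torch.randint(0, 2, (1, 400))  # 0/1 values, one per time step
clean_input_spectrogram = torch.rand(1, 400, 161)
output = torch.rand(1, 400, 161)

# The mask has no frequency dimension: mask[0, t] refers to the
# whole row output[0, t, :] (161 values).
print(mask.shape, clean_input_spectrogram.shape, output.shape)

# Adding a trailing singleton dimension lets the mask broadcast
# against the [1, 400, 161] tensors.
print(mask.unsqueeze(-1).shape)  # torch.Size([1, 400, 1])
```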

How would I do that?

Shamoon asked Jan 19 '26 05:01


1 Answer

You can do something like this, where:

  • m is your mask;
  • x is your spectrogram;
  • o is your output.
```python
import torch
torch.manual_seed(2020)

# Toy example: mask of shape [1, 3], tensors of shape [1, 3, 2]
m = torch.tensor([[0, 1, 0]]).to(torch.int32)
x = torch.rand((1, 3, 2))
o = torch.rand((1, 3, 2))

print(o)
# tensor([[[0.5899, 0.8105],
#          [0.2512, 0.6307],
#          [0.5403, 0.8033]]])
print(x)
# tensor([[[0.4869, 0.1052],
#          [0.5883, 0.1161],
#          [0.4949, 0.2824]]])

# m[0] is the mask of the first (and here only) batch element; as a bool
# tensor it selects whole rows along dim 1, which are copied from x into o
o[:, m[0].to(torch.bool), :] = x[:, m[0].to(torch.bool), :]
# or
# o[:, m[0] == 1, :] = x[:, m[0] == 1, :]

print(o)
# tensor([[[0.5899, 0.8105],
#          [0.5883, 0.1161],
#          [0.5403, 0.8033]]])
```
Berriel answered Jan 20 '26 19:01
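The answer above modifies the output tensor in place. If a new tensor is preferred, `torch.where` can select element-wise instead; the sketch below assumes the shapes from the question and uses random placeholder data. The mask gets a trailing dimension so it broadcasts across the 161 frequency bins. Indexing with the full 2-D boolean mask (instead of `m[0]`) is shown as an in-place alternative that would also work for batch sizes larger than 1, since each batch element then uses its own mask row.

```python
import torch

torch.manual_seed(0)

# Placeholder data in the question's shapes (values are arbitrary).
mask = torch.randint(0, 2, (1, 400))
clean_input_spectrogram = torch.rand(1, 400, 161)
output = torch.rand(1, 400, 161)

keep = mask.bool()  # [1, 400], True where the clean value should be used

# Out-of-place: broadcast the mask to [1, 400, 1] and pick element-wise.
merged = torch.where(keep.unsqueeze(-1), clean_input_spectrogram, output)

# In-place alternative: a [B, T] boolean mask indexes the first two
# dimensions of a [B, T, F] tensor, selecting whole rows of F values.
# Unlike m[0], this uses a separate mask row for every batch element.
output_inplace = output.clone()
output_inplace[keep] = clean_input_spectrogram[keep]

print(torch.equal(merged, output_inplace))  # True
```

Both versions produce the same values; `torch.where` leaves `output` untouched, which can be convenient if the original tensor is still needed afterwards.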