I have a mask, active, that tracks which batch elements have not yet terminated in a recurrent process. Its shape is [batch_full], and its True entries mark the elements that still need to be processed in the current step. Each recurrent step produces another mask, terminated, with as many elements as there are True values in active. Now I want to write the values of ~terminated back into active at the right indices. Basically I want to do:
import torch
active = torch.ones([4,], dtype=torch.bool)
active[:2] = False
terminated = torch.tensor([True, False])
active[active] = ~terminated
print(active) # expected [F, F, F, T]
However, I get this error:
RuntimeError: unsupported operation: some elements of the input tensor and the written-to tensor refer to a single memory location. Please clone() the tensor before performing the operation.
How can I perform the operation described above efficiently?
There are a few solutions; I will also give their speed as measured by timeit (10k repetitions, on a 2021 MacBook Pro).
The simplest solution, taking 0.260s:
active[active.clone()] = ~terminated
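As a quick sanity check, here is the setup from the question with the clone-based fix applied. Cloning the index mask breaks the aliasing between the tensor being read (the boolean index) and the tensor being written to:

```python
import torch

# reproduce the setup from the question
active = torch.ones([4], dtype=torch.bool)
active[:2] = False                        # first two elements already done
terminated = torch.tensor([True, False])  # one entry per still-active element

# clone() gives the indexing operation its own copy of the mask,
# so the read and the write no longer share memory
active[active.clone()] = ~terminated
print(active)  # tensor([False, False, False,  True])
```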
We can use the masked_scatter_ in-place operation for about a 2x speedup (0.136s):
active.masked_scatter_(
    active,
    ~terminated,
)
The out-of-place operation, taking 0.161s, would be:
active = torch.masked_scatter(
    active,
    active,
    ~terminated,
)