I know how to update a tensor after indexing into part of it like this:
import torch
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)
But is there a way to update the original tensor after indexing into it twice? For example:
i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)
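Why the chained assignment has no effect: Python evaluates b[b][i] = 2 as b[b].__setitem__(i, 2). Mask indexing on the right-hand side is advanced indexing, which returns a new tensor rather than a view, so the write lands in that temporary copy. A single subscript on the left-hand side (b[b] = 2) is instead b.__setitem__(b, 2), which writes into b itself. The minimal sketch below shows this. It assumes a recent PyTorch, where boolean masks are preferred over uint8 masks, and the name selected is only for illustration.

```python
import torch

b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
mask = b.bool()  # bool mask; indexing with uint8 masks is deprecated in recent PyTorch

# Chained form: b[mask] allocates a NEW tensor, and [1] = 2 writes into it.
b[mask][1] = 2
after_chained = b.clone()  # b is unchanged: [0, 1, 0, 1]

# Same thing spelled out: the selection is a separate copy with its own storage.
selected = b[mask]
selected[1] = 2            # selected is [1, 2], b is still untouched

# Single subscript on the left-hand side: b.__setitem__(mask, 2) writes into b.
b[mask] = 2                # b is now [0, 2, 0, 2]
```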
What I'd like is for b to be tensor([0, 1, 0, 2]) at the end. Is there a way to do this?
I know that I can do
masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)
But is there a better way? It seems that this must be inefficient: if masked is very large, I'm updating many locations in b when I've really only changed one.
(In case a different approach than indexing twice would work better, the general problem I have is how to change the value in an original tensor at the ith location of a masked version of that tensor.)
I adopted another solution from here, and compared it to your solution:
Solution:
b[b.nonzero()[i]] = 2
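This works because b.nonzero() returns an (n, d) tensor holding the coordinates of the nonzero entries in row-major order, so row i is the position in b of the i-th element of the masked version. For a 1-D tensor, b.nonzero()[i] has shape (1,) and can be used directly as an index. For d > 1 it would be treated as d separate indices along dimension 0, so it has to be unpacked into a tuple first. Below is a sketch of a general helper; the name set_ith_masked is hypothetical, and it assumes a bool mask with the same shape as the tensor and an i within range.

```python
import torch


def set_ith_masked(t: torch.Tensor, mask: torch.Tensor, i: int, value) -> None:
    """Hypothetical helper: set the i-th element of t[mask] in place, inside t.

    Assumes mask is a bool tensor with the same shape as t.
    """
    # Coordinates of all True entries, in row-major order: shape (n, t.dim()).
    coords = mask.nonzero()
    # Row i holds the full coordinates of the i-th masked element; unpack it
    # into a tuple so that each component indexes its own dimension.
    idx = tuple(coords[i].tolist())
    t[idx] = value


# 1-D case from the question: the 2nd masked element is at position 3.
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
set_ith_masked(b, b.bool(), 1, 2)

# 2-D case: the 3rd True entry in row-major order is at (1, 1).
m = torch.tensor([[1, 0, 1], [0, 1, 1]])
set_ith_masked(m, m > 0, 2, 9)
```

This still scans the whole mask to build coords, but it writes only one element of the original tensor and never materializes and writes back the masked copy.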
Runtime comparison:
import torch as t
import numpy as np
import timeit

if __name__ == "__main__":
    np.random.seed(12345)
    b = t.tensor(np.random.randint(0, 2, [1000]), dtype=t.uint8)
    # position, within the masked version b[b], of the middle nonzero element
    halfway = int(b.sum()) // 2
    runs = 100000
    # timeit returns the total time in seconds, so elapsed / runs is seconds per call
    elapsed1 = timeit.timeit("mask = b[b]; mask[halfway] = 2; b[b] = mask",
                             "from __main__ import b, halfway", number=runs)
    print("Time taken (original): {:.6f} s per call".format(elapsed1 / runs))
    elapsed2 = timeit.timeit("b[b.nonzero()[halfway]] = 2",
                             "from __main__ import b, halfway", number=runs)
    print("Time taken (improved): {:.6f} s per call".format(elapsed2 / runs))
Results:
Time taken (original): 0.000096 s per call
Time taken (improved): 0.000047 s per call
Results for a vector of length 100000:
Time taken (original): 0.010284 s per call
Time taken (improved): 0.003667 s per call
So the two solutions differ by a factor of about 2 for 1000 elements and about 2.8 for 100000 elements. I'm not sure whether this is the optimal solution, but depending on the size of your tensor (and how often you call the function), it should give you a rough idea of what to expect.