
How do I update a tensor in Pytorch after indexing twice?

Tags: python, pytorch

I know how to update a tensor after indexing into part of it like this:

import torch

b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b] = 2
b
# tensor([0, 2, 0, 2], dtype=torch.uint8)

but is there a way I can update the original tensor after indexing into it twice? E.g.

i = 1
b = torch.tensor([0, 1, 0, 1], dtype=torch.uint8)
b[b][i] = 2
b
# tensor([0, 1, 0, 1], dtype=torch.uint8)

What I'd like is for b to be tensor([0, 1, 0, 2]) at the end. Is there a way to do this?

I know that I can do

masked = b[b]
masked[i] = 2
b[b] = masked
b
# tensor([0, 1, 0, 2], dtype=torch.uint8)

but is there any better way? It seems that this must be inefficient; if masked is very large, I'm updating many locations in b when I've really only changed one.

(In case a different approach than indexing twice would work better, the general problem I have is how to change the value in an original tensor at the ith location of a masked version of that tensor.)

Nathan asked Feb 04 '26 19:02


1 Answer

I adopted another solution from here, and compared it to your solution:

Solution:

b[b.nonzero()[i]] = 2
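As far as I understand it, the reason b[b][i] = 2 has no effect is that indexing with a mask returns a copy, so the assignment lands on a temporary tensor that is thrown away. nonzero() instead gives the positions of the masked elements in the original tensor, and indexing with one of those positions writes straight into it. Here is a minimal sketch of that idea; the names values, mask and positions are mine, and I keep the mask as a separate bool tensor (an assumption, since the question uses one uint8 tensor for both roles):

```python
import torch

values = torch.tensor([0, 1, 0, 1])
mask = values.bool()  # tensor([False, True, False, True])
i = 1

# Mask indexing returns a copy, so this assignment modifies a temporary
# tensor and leaves `values` untouched.
values[mask][i] = 2
unchanged = values.clone()  # still tensor([0, 1, 0, 1])

# nonzero() lists the positions where the mask is True; the i-th entry is
# the position in `values` of the i-th masked element.
positions = mask.nonzero().squeeze(1)  # tensor([1, 3])
values[positions[i]] = 2

print(values)  # tensor([0, 1, 0, 2])
```

Only one element of the original tensor is written, although nonzero() still has to scan the whole mask to find the positions.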

Runtime comparison:

import torch as t
import numpy as np
import timeit


if __name__ == "__main__":

    np.random.seed(12345)
    b = t.tensor(np.random.randint(0,2, [1000]), dtype=t.uint8)
    # inconvenient way to think of a random index halfway that is 1.
    halfway = np.array(list(range(len(b))))[b == 1][len(b[b == 1]) //2]

    runs = 100000

    elapsed1 = timeit.timeit("mask=b[b]; mask[halfway] = 2; b[b] = mask", 
                             "from __main__ import b, halfway", number=runs)

    # timeit.timeit returns the total time in seconds, so this is seconds per call
    print("Time taken (original): {:.6f} s per call".format(elapsed1 / runs))

    elapsed2 = timeit.timeit("b[b.nonzero()[halfway]]=2",
                             "from __main__ import b, halfway", number=runs)

    print("Time taken (improved): {:.6f} s per call".format(elapsed2 / runs))

Results:

Time taken (original): 0.000096 s per call
Time taken (improved): 0.000047 s per call

Results for a vector of length 100000:

Time taken: 0.010284 s per call
Time taken: 0.003667 s per call

So the two solutions differ only by a factor of about 2 to 3. I'm not sure whether this is the optimal solution, but depending on the size of your tensor (and how often you call the function) it should give you a rough idea of what you're looking at.
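If you want to repeat the comparison on your own data, something along these lines might be a starting point. It is only a sketch under my own assumptions: the function names set_via_roundtrip and set_via_nonzero are made up for illustration, the mask is a separate bool tensor, and each call works on a fresh clone so that earlier writes do not affect later runs (the clone cost is included in both timings). timeit.timeit returns the total time in seconds, which is converted to microseconds per call here.

```python
import timeit

import torch


def set_via_roundtrip(values, mask, i, new_value):
    """Copy out the masked elements, change one, write them all back."""
    masked = values[mask]
    masked[i] = new_value
    values[mask] = masked
    return values


def set_via_nonzero(values, mask, i, new_value):
    """Look up the position of the i-th masked element and write only there."""
    values[mask.nonzero()[i]] = new_value
    return values


torch.manual_seed(0)
mask = torch.rand(1000) < 0.5
mask[0] = True                 # make sure at least one element is selected
values = mask.to(torch.int64)  # 1 where the mask is True, 0 elsewhere
i = int(mask.sum()) // 2       # index into the masked elements, not into `values`

runs = 1000
for fn in (set_via_roundtrip, set_via_nonzero):
    total_seconds = timeit.timeit(lambda: fn(values.clone(), mask, i, 2), number=runs)
    print(f"{fn.__name__}: {total_seconds / runs * 1e6:.1f} us per call")
```

The numbers will depend on your machine, the tensor size and the device, so treat them as relative rather than absolute.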

dennlinger answered Feb 06 '26 10:02