 

PyTorch, apply different functions element-wise

Tags:

python

pytorch

I defined a tensor like this:

import torch

t_shape = [4, 1]
data = torch.rand(t_shape)

I want to apply different functions to each row.

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]  # one function per row

I can do it with the following code:

d = torch.tensor([f(data[i]) for i, f in enumerate(funcs)])

How can I do this properly, using PyTorch's more advanced APIs?

asked Sep 13 '25 09:09 by GoingMyWay


1 Answer

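I think your solution is good, but it won't work for arbitrary tensor shapes: torch.tensor converts each list element to a Python scalar, which only works for one-element tensors. You can modify the solution slightly, as follows: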

import torch

t_shape = [4, 10, 10]
data = torch.rand(t_shape)

funcs = [lambda x: x+1, lambda x: x**2, lambda x: x-1, lambda x: x*2]

# only these 2 lines change: collect the per-row results, then stack them along dim 0
d = [f(data[i]) for i, f in enumerate(funcs)]
d = torch.stack(d, dim=0)  # shape [4, 10, 10]
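A minimal sketch wrapping the stack-based approach in a reusable function, assuming the first dimension of the tensor equals len(funcs) and every function returns outputs of the same shape (true for the four lambdas here). `apply_rowwise` is a hypothetical helper name, not a PyTorch API; it uses only `torch.stack`, iteration over a tensor's first dimension, and `torch.allclose`.

```python
import torch

# Hypothetical helper (not part of PyTorch): apply funcs[i] to row data[i]
# and stack the results back together along dim 0.
def apply_rowwise(data, funcs):
    if data.shape[0] != len(funcs):
        raise ValueError("need exactly one function per row")
    # Iterating a tensor yields slices along dim 0, each of shape data.shape[1:].
    return torch.stack([f(row) for row, f in zip(data, funcs)], dim=0)

funcs = [lambda x: x + 1, lambda x: x ** 2, lambda x: x - 1, lambda x: x * 2]

# Works for the [4, 10, 10] case from the answer...
data = torch.rand(4, 10, 10)
d = apply_rowwise(data, funcs)
print(d.shape)  # torch.Size([4, 10, 10])

# ...and each output row matches applying its function directly.
for i, f in enumerate(funcs):
    assert torch.allclose(d[i], f(data[i]))

# The original [4, 1] case keeps its column dimension too.
small = torch.rand(4, 1)
print(apply_rowwise(small, funcs).shape)  # torch.Size([4, 1])
```

The explicit shape check is a design choice: without it, `zip` would silently stop at the shorter of the rows and the functions, so a mismatch would drop data instead of raising an error.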
answered Sep 16 '25 00:09 by Wasi Ahmad