 

Apply a function over all combinations of tensor rows in PyTorch

Tags:

python

pytorch

I want to write a function f1(arg_tensor) that takes a PyTorch tensor as its argument.

Inside this function I use another function, f2(tensor_row_1, tensor_row_2), which takes two rows of a PyTorch tensor as arguments and returns a scalar.

f2(..) should be applied to all ordered pairs of the tensor's n rows, i.e. to the row-index pairs [0,0], [0,1], [0,2], ..., [0,n-1], ..., [n-1,0], ..., [n-1,n-1].

The output of f1(..) should be an n×n tensor whose element [i,j] is f2(tensor_rows[i], tensor_rows[j]); for example, element [0,0] holds f2(tensor_rows[0], tensor_rows[0]).
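For concreteness, the requested behavior can be written as the naive double loop. This is only a sketch of the specification: the f2 below is a hypothetical placeholder (the dot product of the two rows), and f1_loop is an illustrative name, not something from the question.

```python
import torch


def f2(row_1: torch.Tensor, row_2: torch.Tensor) -> torch.Tensor:
    # Hypothetical placeholder for f2: dot product of two rows (a scalar).
    return torch.dot(row_1, row_2)


def f1_loop(arg_tensor: torch.Tensor) -> torch.Tensor:
    # Reference semantics: out[i, j] = f2(arg_tensor[i], arg_tensor[j])
    # for every ordered pair of row indices (i, j), including i == j.
    n = arg_tensor.shape[0]
    out = torch.empty(n, n, dtype=arg_tensor.dtype)
    for i in range(n):
        for j in range(n):
            out[i, j] = f2(arg_tensor[i], arg_tensor[j])
    return out


x = torch.randn(4, 3)
print(f1_loop(x).shape)  # torch.Size([4, 4])
```

For this particular placeholder f2 the whole loop collapses to the matrix product x @ x.T, which is the kind of loop-free computation the question is after; an arbitrary f2 has no such closed form in general.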

Is there a way to do this efficiently, without a double for loop?

Gal R asked Dec 20 '25 20:12



1 Answer

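Yes, you can do it with a simple broadcasting trick. The example below uses torch.nn.functional.kl_div as f2 and checks the result against a manual double loop: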

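import torch


def f1(tensor):
    # (n, d) -> (d, n): put the feature dimension first
    tensor = tensor.permute(1, 0)
    # (d, n, 1) and (d, 1, n) broadcast to (d, n, n); element [k, i, j]
    # is the elementwise kl_div term for feature k of rows i and j.
    return torch.nn.functional.kl_div(
        tensor.unsqueeze(dim=2), tensor.unsqueeze(dim=1), reduction="none"
    ).mean(dim=0)  # average over features -> (n, n)


def manual_f1(tensor):
    # Reference double loop: result[i, j] = f2(row_i, row_j)
    result = []
    for row1 in tensor:
        for row2 in tensor:
            # Equivalent to kl_div(row1, row2) with reduction="mean", written
            # explicitly to avoid the warning PyTorch emits for "mean".
            result.append(
                torch.nn.functional.kl_div(row1, row2, reduction="none").mean()
            )
    return torch.stack(result).reshape(tensor.shape[0], -1)


# kl_div takes the log of its target, so use positive data: torch.randn
# gives negative values, which can turn the results into NaN.
data = torch.rand(5, 7)

result = f1(data)
manual_result = manual_f1(data)

print(torch.all(result == manual_result).item())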

Note that with more rows the two results may no longer be exactly equal, because of floating-point rounding differences between the two computations. You can:

  • print the values and inspect them manually
  • use torch.isclose to check that they are approximately equal

In the second case, the last print becomes:

print(torch.all(torch.isclose(result, manual_result)).item())
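The permute/unsqueeze trick above works because this f2 is an elementwise operation along the feature dimension followed by a mean. The sketch below shows the same idea without the permute, and, as a swapped-in alternative for an f2 that is simply written as a per-row function, nested torch.func.vmap (this assumes PyTorch 2.0 or later, where torch.func is available). The names kl_f2, pairwise_broadcast and pairwise_vmap are hypothetical; kl_f2 re-implements the documented kl_div term target * (log(target) - input), averaged over features, so that it can be vmapped.

```python
import torch


def kl_f2(row_1: torch.Tensor, row_2: torch.Tensor) -> torch.Tensor:
    # Hypothetical f2: mean over features of row_2 * (log(row_2) - row_1),
    # i.e. the kl_div(row_1, row_2) value averaged over features.
    # Assumes row_2 > 0 so that log() is defined.
    return (row_2 * (row_2.log() - row_1)).mean()


def pairwise_broadcast(x: torch.Tensor) -> torch.Tensor:
    # x has shape (n, d).
    a = x.unsqueeze(1)  # (n, 1, d): row i, broadcast along dim 1
    b = x.unsqueeze(0)  # (1, n, d): row j, broadcast along dim 0
    # Broadcasting gives (n, n, d); element [i, j, k] uses x[i, k] and x[j, k].
    elementwise = b * (b.log() - a)
    return elementwise.mean(dim=-1)  # reduce over features -> (n, n)


def pairwise_vmap(x: torch.Tensor, f2) -> torch.Tensor:
    # Inner vmap maps over the second argument (row j) with row i fixed;
    # outer vmap maps over the first argument (row i) with all rows fixed.
    inner = torch.func.vmap(f2, in_dims=(None, 0))
    outer = torch.func.vmap(inner, in_dims=(0, None))
    return outer(x, x)  # (n, n), out[i, j] = f2(x[i], x[j])


data = torch.rand(5, 7) + 0.1  # strictly positive so log() is defined
print(torch.allclose(pairwise_broadcast(data), pairwise_vmap(data, kl_f2)))
```

Both versions effectively work on an (n, n, d) intermediate, so memory grows roughly with n²·d; for large n, torch.func.vmap's chunk_size argument or processing rows in blocks can keep that bounded. The vmap version has the advantage that f2 only needs to be written for a single pair of rows, as long as it uses operations vmap supports.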
Szymon Maszke answered Dec 23 '25 10:12



