I want to make a function f1(arg_tensor) which takes a PyTorch tensor as an argument.
In this function I use another function:
f2(tensor_row_1, tensor_row_2), which takes two rows of a PyTorch tensor as arguments and outputs a scalar.
f2(..) should be applied over all combinations of the tensor's rows (i.e. over the row-index pairs [0,0], [0,1], [0,2], ..., [0,n-1], ..., [n-1,0], ..., [n-1,n-1]).
The output of f1(..) should be a tensor such that element [i,j] holds the output value of f2(tensor_rows[i], tensor_rows[j]).
Is there a way to perform it efficiently (and not with double for loop)?
Yes, one can do it with a simple broadcasting trick:
import torch

def f1(tensor):
    # transpose (n, d) -> (d, n) so the two unsqueezed copies broadcast over row pairs
    tensor = tensor.permute(1, 0)
    # (d, n, 1) against (d, 1, n) broadcasts to (d, n, n);
    # averaging over the feature dimension d leaves an (n, n) matrix
    return torch.nn.functional.kl_div(
        tensor.unsqueeze(dim=2), tensor.unsqueeze(dim=1), reduction="none"
    ).mean(dim=0)
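The same broadcasting pattern works for any f2 that is an elementwise operation followed by a reduction over the feature dimension. A minimal sketch, using mean squared difference as a stand-in for f2 (the function name pairwise_apply is illustrative, not part of the original code):

```python
import torch

def pairwise_apply(tensor):
    # tensor: (n, d). Broadcasting (n, 1, d) against (1, n, d)
    # evaluates the elementwise op for every pair of rows at once;
    # reducing over the feature dim then yields an (n, n) matrix.
    a = tensor.unsqueeze(1)  # shape (n, 1, d)
    b = tensor.unsqueeze(0)  # shape (1, n, d)
    return ((a - b) ** 2).mean(dim=2)  # shape (n, n)

x = torch.randn(4, 3)
out = pairwise_apply(x)
print(out.shape)  # torch.Size([4, 4])
```

Element [i,j] here is the mean squared difference between rows i and j, so the diagonal is zero and the matrix is symmetric.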
def manual_f1(tensor):
    # reference implementation: the explicit double loop over row pairs
    result = []
    for row1 in tensor:
        for row2 in tensor:
            result.append(torch.nn.functional.kl_div(row1, row2))
    return torch.stack(result).reshape(tensor.shape[0], -1)
data = torch.randn(5, 7)
result = f1(data)
manual_result = manual_f1(data)
print(torch.all(result == manual_result).item())
Please notice that for larger inputs the two results may differ slightly due to floating-point error. You can either:
- print the values and inspect them manually
- use torch.isclose to verify similarity

In the second case, the last print would become:
print(torch.all(torch.isclose(result, manual_result)).item())
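For reference, torch.isclose exposes rtol and atol parameters (defaults 1e-05 and 1e-08) to control how much numerical drift is tolerated; a small example:

```python
import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-7  # simulate tiny floating-point drift
# |a - b| must be within atol + rtol * |b| to count as close
print(torch.all(torch.isclose(a, b, rtol=1e-5, atol=1e-8)).item())  # True
```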