
Numpy function sum does not work on PyTorch tensors

Is it expected that calling numpy.sum on a PyTorch tensor raises an error? As an aside, it does work on pandas DataFrames.

import numpy as np
import pandas as pd
import torch

data = [[1.1, 2.2], [1.1, 2.2], [1.1, 2.2]]

a = np.array(data)
d = pd.DataFrame(data)
t = torch.tensor(data)

np.sum(a, axis=1) # works obviously
np.sum(d, axis=1) # works
np.sum(t, axis=1) # fails with a TypeError

If you dig a bit inside, you'll see that's quite simply because numpy.sum tries to call torch.Tensor.sum with keyword arguments that it does not accept, as in the following line (https://github.com/numpy/numpy/blob/main/numpy/_core/fromnumeric.py around line 80):

return reduction(axis=axis, out=out, **passkwargs)

which inevitably results in an error.

This seems quite related: https://github.com/numpy/numpy/issues/28024

Can this be considered a bug in numpy? Why can't we just pass the non null arguments to the sum method of the object we are passing to numpy.sum?

I wanted to post an issue on GitHub/numpy but I am not sure I could define this as a bug, and if I click "question" there, I am told to post here...

Pollo asked Jun 29 '26 03:06


1 Answer

Yes, it is expected that numpy.sum() fails on a PyTorch tensor.

The reason is that numpy.sum() tries to call the tensor's own .sum() method and passes along its own specific set of keyword arguments (like axis, out, etc.).

  • Why it works with Pandas: The pandas.DataFrame.sum() method is intentionally designed to be compatible with NumPy and accepts these arguments.

  • Why it fails with PyTorch: The torch.Tensor.sum() method has a different signature and does not accept all of NumPy's arguments, which results in a TypeError.
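The two cases above can be modeled in plain Python. This is only a minimal sketch and not NumPy's actual code: wrapped_sum, NumpyStyle and TorchStyle are hypothetical names invented here, standing in for numpy.sum, an object with a NumPy-compatible sum() and an object whose sum() has no out keyword.

```python
# Simplified, hypothetical model of the delegation described above.
# None of these names exist in NumPy, pandas or PyTorch.

def wrapped_sum(obj, axis=None, out=None):
    """Delegate to obj.sum(), always forwarding both axis and out."""
    return obj.sum(axis=axis, out=out)


class NumpyStyle:
    """Stand-in for an object whose sum() accepts NumPy's keywords."""

    def __init__(self, rows):
        self.rows = rows

    def sum(self, axis=None, out=None, **kwargs):
        if axis == 1:
            return [sum(row) for row in self.rows]
        return sum(sum(row) for row in self.rows)


class TorchStyle:
    """Stand-in for an object whose sum() has no `out` keyword."""

    def __init__(self, rows):
        self.rows = rows

    def sum(self, axis=None):
        if axis == 1:
            return [sum(row) for row in self.rows]
        return sum(sum(row) for row in self.rows)


rows = [[1, 2], [3, 4], [5, 6]]

print(wrapped_sum(NumpyStyle(rows), axis=1))  # [3, 7, 11]

try:
    wrapped_sum(TorchStyle(rows), axis=1)
except TypeError as exc:
    # The unknown `out` keyword is rejected even though its value is None.
    print("TypeError:", exc)
```

The second call fails although out is None, because Python rejects an unknown keyword regardless of the value passed for it.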

This is not considered a bug in NumPy. Instead, it's an API incompatibility issue. For seamless integration, the external library (PyTorch) would need to align its method with NumPy's.

The standard solution is to explicitly convert the tensor to a NumPy array first:


np.sum(t.numpy(), axis=1)
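A slightly more defensive sketch of this workaround, assuming the same data as in the question: Tensor.numpy() raises for tensors that require gradients or live on a GPU, so the conversion below goes through detach() and cpu() first. Reducing inside PyTorch with t.sum(dim=1) avoids the conversion altogether. The variable names are illustrative.

```python
import numpy as np
import torch

data = [[1.1, 2.2], [1.1, 2.2], [1.1, 2.2]]
t = torch.tensor(data)

# Option 1: reduce inside PyTorch; `dim` is PyTorch's name for NumPy's `axis`.
row_sums_torch = t.sum(dim=1)

# Option 2: convert first, then let NumPy reduce a real ndarray.
# detach() drops autograd tracking and cpu() moves the data to host memory;
# neither changes the values of a plain CPU tensor like `t`.
row_sums_numpy = np.sum(t.detach().cpu().numpy(), axis=1)

print(row_sums_torch.shape, row_sums_numpy.shape)
```

Option 1 keeps the result as a tensor (and on the GPU, if that is where t lives), while option 2 returns a NumPy array.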
Pw Wolf answered Jun 30 '26 16:06


