 

How can I use Numba for PyTorch tensors?

I am new to Numba and need to use it to speed up some PyTorch functions, but even a very simple function does not work :(

import torch
import numba

@numba.njit()
def vec_add_odd_pos(a, b):
    res = 0.
    for pos in range(len(a)):
        if pos % 2 == 0:
            res += a[pos] + b[pos]
    return res

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)

But the following error appears:

def vec_add_odd_pos(a, b):
    res = 0.
    ^

This error may have been caused by the following argument(s):

  • argument 0: cannot determine Numba type of <class 'torch.Tensor'>
  • argument 1: cannot determine Numba type of <class 'torch.Tensor'>

Can anyone help me? A link with more examples would also be appreciated. Thanks.

rozyang asked Jul 30 '20 at 09:07




2 Answers

PyTorch GPU tensors now implement the CUDA Array Interface (__cuda_array_interface__), so Numba can consume them directly without copying:

numba.cuda.as_cuda_array(tensor)

The test script provides a few usage examples: https://github.com/pytorch/pytorch/blob/master/test/test_numba_integration.py

pixelou answered Sep 20 '22 at 16:09


As others have mentioned, Numba currently doesn't support torch tensors, only NumPy arrays. However, TorchScript has a similar goal, and your function can be rewritten with it as follows:

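import torch

@torch.jit.script
def vec_add_odd_pos(a: torch.Tensor, b: torch.Tensor) -> float:
    res = 0.
    for pos in range(len(a)):
        if pos % 2 == 0:
            # float(...) keeps res a float: TorchScript variables must keep one static type
            res += float(a[pos] + b[pos])
    return res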

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos(x, y)

Beware: although you said your snippet was just a simple example, Python-style for loops over tensor elements are very slow, and TorchScript may not help much. Avoid them whenever possible and use them only when no vectorized solution exists. That being said, here is a more performant, vectorized implementation of your function:

def vec_add_odd_pos(a, b):
    # Boolean mask selecting the even indices 0, 2, 4, ... (built on a's device)
    evenids = torch.arange(len(a), device=a.device) % 2 == 0
    return (a[evenids] + b[evenids]).sum()
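A variant of the vectorized version (not from the original answer) replaces the boolean mask with strided slicing. a[::2] selects indices 0, 2, 4, ... as a view, so no index tensor has to be built. The name vec_add_odd_pos_slice is hypothetical.

```python
import torch

def vec_add_odd_pos_slice(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # a[::2] and b[::2] are strided views over the even positions
    return (a[::2] + b[::2]).sum()

x = torch.tensor([3, 4, 5.])
y = torch.tensor([-2, 0, 1.])
z = vec_add_odd_pos_slice(x, y)  # tensor(7.)
```

Both versions return a 0-dimensional tensor, so the result stays on the input's device and remains differentiable.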
Johan Dettmar answered Sep 21 '22 at 16:09