
In PyTorch, what makes a tensor have non-contiguous memory?

Tags:

pytorch

According to this SO post and this PyTorch discussion, PyTorch's view function works only on contiguous memory, while reshape does not. In the second link, the author even claims:

[view] will raise an error on a non-contiguous tensor.

But when does a tensor have non-contiguous memory?
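
For reference, here is a minimal sketch of the claim being asked about (a transposed tensor is used as the non-contiguous example, and the exact error message may differ between versions):

import torch

t = torch.randn(3, 4).transpose(0, 1)  # transpose returns a non-contiguous view
print(t.is_contiguous())    # False
print(t.reshape(-1).shape)  # torch.Size([12]) -- reshape copies when it has to
# t.view(-1)                # raises a RuntimeError and suggests using .reshape() instead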

jds asked Mar 05 '23

1 Answer

This is a very good answer, which explains the topic in the context of NumPy; PyTorch works essentially the same way. Its docs don't generally mention whether an operation's output is (non-)contiguous, but that's something you can usually guess from the kind of operation (with some experience and understanding of the implementation). As a rule of thumb, most operations that construct new tensors preserve contiguity. You may see non-contiguous outputs when an operation returns a view of the existing storage with modified strides (for example slicing with a step, or transposing). A couple of examples below:

import torch

t = torch.randn(10, 10)

def check(ten):
    print(ten.is_contiguous())

check(t) # True

# unlike NumPy, torch.flip copies the data instead of returning a view with
# negative strides, so the result is laid out in order and stays contiguous
check(torch.flip(t, (0,))) # True

# if we take every 2nd row, the rows of the result are no longer adjacent
# in the underlying storage, so the result is not contiguous
check(t[::2]) # False

# if we transpose, we lose contiguity, just as in NumPy
check(t.transpose(0, 1)) # False

# if we transpose twice, we first lose and then regain contiguity
check(t.transpose(0, 1).transpose(0, 1)) # True

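Contiguity here just means that the strides match the default row-major layout for the tensor's shape, so inspecting .stride() shows why the cases above behave the way they do (continuing with the same 10 x 10 t; the exact numbers assume that shape):

print(t.stride())                   # (10, 1) -- the row-major layout for a 10x10 tensor
print(t[::2].stride())              # (20, 1) -- every other row, so there are gaps in memory
print(t.transpose(0, 1).stride())   # (1, 10) -- column-major traversal, not contiguous
print(t.transpose(0, 1).transpose(0, 1).stride())  # (10, 1) -- row-major again
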
In general, if you have a non-contiguous tensor t, you can make it contiguous by calling t = t.contiguous(). If t is already contiguous, the call to t.contiguous() is essentially a no-op, so you can do it without risking a big performance hit.
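
A common pattern (a minimal sketch, assuming the transposed 10 x 10 t from above) is to restore contiguity before calling view:

tt = t.transpose(0, 1)            # non-contiguous view of t
# tt.view(100)                    # would raise a RuntimeError because of the strides
flat = tt.contiguous().view(100)  # contiguous() copies the data into row-major order
print(flat.is_contiguous())       # True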

Jatentaki answered Apr 30 '23