 

Repeating a PyTorch tensor without copying memory

Does PyTorch support repeating a tensor without allocating significantly more memory?

Assume we have a tensor

t = torch.ones((1,1000,1000))
t10 = t.repeat(10,1,1)

Repeating t 10 times will require 10x the memory. Is there a way I can create a tensor t10 without allocating significantly more memory?

Here is a related question, but without answers.

asked Jan 15 '20 by mcb

People also ask

How do you repeat a tensor?

You can use the einops repeat function, where b is the number of times you want the tensor to be repeated and h, w are the tensor's existing dimensions. Repeated values are memory heavy; in most cases the best practice is to use broadcasting.
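A minimal sketch of that einops usage (assuming the einops package is installed; the axis names b, h, w in the pattern are illustrative). Depending on the pattern, repeat may materialize the repeated values rather than return a view:

import torch
from einops import repeat

t = torch.ones((1000, 1000))            # a 2-D tensor with dimensions h, w
t10 = repeat(t, 'h w -> b h w', b=10)   # add a new axis b repeated 10 times
print(t10.shape)                        # torch.Size([10, 1000, 1000])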

What does .expand do in PyTorch?

Expanding a tensor does not allocate new memory, but only creates a new view on the existing tensor where a dimension of size one is expanded to a larger size by setting the stride to 0.
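A small sketch that makes this visible: the expanded tensor shares its storage with the original, and its stride along the expanded dimension is 0.

import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

print(t.stride())                      # (1000000, 1000, 1)
print(t10.stride())                    # (0, 1000, 1) -- stride 0 along the expanded dim
print(t.data_ptr() == t10.data_ptr())  # True: same underlying memory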

How do I flatten in PyTorch?

torch.flatten flattens the input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only the dimensions starting with start_dim and ending with end_dim are flattened.
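A quick illustration of torch.flatten with start_dim (resulting shapes shown in the comments):

import torch

x = torch.randn(2, 3, 4)
print(torch.flatten(x).shape)               # torch.Size([24])
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([2, 12]) -- leading dim kept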


1 Answer

You can use torch.Tensor.expand:

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

Keep in mind that t10 is just a non-copying view of t. So, for example, a change to t10[0,0,0] will result in the same change in t[0,0,0] and in every element of t10[:,0,0].
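A short sketch of that aliasing behavior, reusing the same t and t10 as above:

import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

t[0, 0, 0] = 5.0                       # write through the original tensor
print(t10[:, 0, 0])                    # all 10 entries are now 5. -- they alias t[0, 0, 0]
print(t.data_ptr() == t10.data_ptr())  # True: still the same storage, no extra memory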

Other than direct access, most operations performed on t10 will cause memory to be copied, which breaks the view and uses more memory. For example: changing the device (.cpu(), .to(device=...), .cuda()), changing the datatype (.float(), .long(), .to(dtype=...)), or using .contiguous().
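One way to check whether a given operation copied the data is to compare data pointers before and after; this is only a sketch, not an exhaustive list of the operations that copy:

import torch

t = torch.ones((1, 1000, 1000))
t10 = t.expand(10, 1000, 1000)

print(t10.is_contiguous())                  # False: t10 is a strided view of t
t10c = t10.contiguous()                     # forces a full copy of the expanded data
print(t10c.data_ptr() == t10.data_ptr())    # False: new storage was allocated
print(t10c.element_size() * t10c.numel())   # 40000000 bytes (~40 MB) actually held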

answered Oct 16 '22 by jodag