
What's the difference between torch.stack() and torch.cat() functions?

OpenAI's REINFORCE and actor-critic examples for reinforcement learning contain the following code:

REINFORCE:

policy_loss = torch.cat(policy_loss).sum() 

actor-critic:

loss = torch.stack(policy_losses).sum() + torch.stack(value_losses).sum() 

One uses torch.cat; the other uses torch.stack.

As far as I can tell, the documentation does not clearly distinguish between them.

What are the differences between the two functions?

asked Jan 22 '19 11:01 by Gulzar




1 Answer

stack

Concatenates sequence of tensors along a new dimension.

cat

Concatenates the given sequence of seq tensors in the given dimension.

So if A and B are of shape (3, 4), torch.cat([A, B], dim=0) will be of shape (6, 4) and torch.stack([A, B], dim=0) will be of shape (2, 3, 4).
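To make the shape difference concrete, here is a minimal sketch that checks the shapes quoted above. The tensors A and B and the variable names are placeholders for illustration only. It also shows that stack gives the same result as adding a new axis to each tensor and then calling cat. The last part is an aside on the question's loss lists: it assumes the losses are 0-dim (scalar) tensors in one case, which is not confirmed by the snippets in the question.

```python
import torch

# Two placeholder tensors of shape (3, 4), as in the answer above.
A = torch.zeros(3, 4)
B = torch.ones(3, 4)

# cat joins along an EXISTING dimension, so the number of dims stays 2.
cat0 = torch.cat([A, B], dim=0)      # shape (6, 4)
cat1 = torch.cat([A, B], dim=1)      # shape (3, 8)

# stack inserts a NEW dimension at `dim`, so the result has 3 dims.
stack0 = torch.stack([A, B], dim=0)  # shape (2, 3, 4)
stack1 = torch.stack([A, B], dim=1)  # shape (3, 2, 4)

# stack is the same as unsqueezing every input at `dim` and then using cat.
manual = torch.cat([A.unsqueeze(0), B.unsqueeze(0)], dim=0)
same = torch.equal(stack0, manual)

# Assumption for illustration: a list of 0-dim (scalar) loss tensors.
# stack can combine them because it creates the dimension itself;
# cat needs at least one existing dimension to join along.
scalars = [torch.tensor(1.0), torch.tensor(2.0)]
stacked = torch.stack(scalars)       # shape (2,)
try:
    torch.cat(scalars)
    cat_accepts_scalars = True
except RuntimeError:
    cat_accepts_scalars = False

print(cat0.shape, cat1.shape, stack0.shape, stack1.shape)
print(same, stacked.shape, cat_accepts_scalars)
```

If each loss in a list is already at least 1-dimensional, cat works and keeps the number of dimensions; if each is a scalar, stack is the one that applies. After .sum() the result is a single scalar either way.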

answered Oct 19 '22 03:10 by Jatentaki