
Torch squeeze and the batch dimension

Tags: pytorch, torch

Does anyone here know whether the torch.squeeze function respects the batch (i.e. first) dimension? From some inline code it seems it does not, but maybe someone else knows the inner workings better than I do.

Btw, the underlying problem is that I have a tensor of shape (n_batch, channel, x, y, 1). I want to remove the last dimension with a simple function, so that I end up with a shape of (n_batch, channel, x, y).

A reshape is of course possible, or even selecting the last axis. But I want to embed this functionality in a layer so that I can easily add it to a ModuleList or Sequential object.
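For reference, something like this minimal sketch is what I have in mind (SqueezeDim is just a name I made up, not a built-in module):

    import torch
    import torch.nn as nn

    class SqueezeDim(nn.Module):
        """Wraps squeeze on one fixed dimension so it can sit inside nn.Sequential."""
        def __init__(self, dim: int):
            super().__init__()
            self.dim = dim

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Only removes self.dim, so the batch dimension is never touched.
            return x.squeeze(self.dim)

    model = nn.Sequential(
        nn.Conv3d(3, 8, kernel_size=1),
        SqueezeDim(-1),  # (n_batch, channel, x, y, 1) -> (n_batch, channel, x, y)
    )
    print(model(torch.randn(2, 3, 16, 16, 1)).shape)  # torch.Size([2, 8, 16, 16])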

EDIT: I just found out that in TensorFlow (2.5.0) the function tf.linalg.diag DOES respect the batch dimension. Just an FYI that this behavior may differ per function you are using.

Asked Mar 03 '23 by zwep


2 Answers

No! squeeze doesn't respect the batch dimension. It's a potential source of error if you use squeeze when the batch dimension may be 1. Rule of thumb is that only classes and functions in torch.nn respect batch dimensions by default.

This has caused me headaches in the past. I recommend using reshape, or only using squeeze with the optional input dimension argument. In your case you could use .squeeze(4) to remove only the last dimension. That way nothing unexpected happens. Squeeze without the input dimension argument has led me to unexpected results, specifically when (see the sketch after this list):

  1. the input shape to the model may vary,
  2. the batch size may vary, or
  3. nn.DataParallel is being used (in which case the batch size for a particular instance may be reduced to 1).
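
A minimal sketch of the pitfall and the fix (shapes picked arbitrarily):

    import torch

    x = torch.randn(1, 3, 16, 16, 1)  # (n_batch, channel, x, y, 1) with batch size 1

    # Bare squeeze() removes *every* size-1 dimension, including the batch dimension.
    print(x.squeeze().shape)   # torch.Size([3, 16, 16])
    # Passing the dimension keeps the batch dimension intact.
    print(x.squeeze(4).shape)  # torch.Size([1, 3, 16, 16])
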
Answered Mar 29 '23 by jodag


The accepted answer is sufficient for the problem of squeezing the last dimension. However, I had a tensor of shape (batch, 1280, 1, 1) and wanted (batch, 1280). A single squeeze call didn't allow for that: squeeze(tensor, 1).shape -> (batch, 1280, 1, 1) and squeeze(tensor, 2).shape -> (batch, 1280, 1). I could have called squeeze twice, but, you know, aesthetics :).

What helped me was torch.flatten(tensor, start_dim=1) -> (batch, 1280). Trivial, but I forgot about it. A warning though: this function may return a copy instead of a view, so be careful.

https://pytorch.org/docs/stable/generated/torch.flatten.html
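For instance (batch size picked arbitrarily):

    import torch

    t = torch.randn(8, 1280, 1, 1)        # (batch, 1280, 1, 1)
    flat = torch.flatten(t, start_dim=1)  # collapses everything after the batch dim
    print(flat.shape)                     # torch.Size([8, 1280])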

Answered Mar 29 '23 by Krzysztof Przygodzki