PyTorch tensor: how to switch the channel position (RuntimeError)

My training dataset is shown below. X_train is a 3D array with 3 channels in the last dimension:

Shape of X_train: (708, 256, 3)
Shape of y_train: (708, 4)

I then convert the arrays to tensors and pass them to a DataLoader:

import torch

# X_data and y_data are the NumPy arrays with the shapes listed above
X_train = torch.from_numpy(X_data)
y_train = torch.from_numpy(y_data)
training_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=False)

However, when training the model I get the following error:

RuntimeError: Given groups=1, weight of size 24 3 5, expected input[708, 256, 3] to have 3 channels, but got 256 channels instead

I suppose this is due to the position of the channel dimension? In TensorFlow the channel dimension comes last, but in PyTorch the format is "Batch Size x Channel x Height x Width"? So how do I swap the dimensions of the X_train tensor so that the batches coming out of the dataloader match the expected format?

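import torch
from torch import nn

class TwoLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.Conv1d expects input shaped (batch, channels, length)
        self.conv1 = nn.Sequential(
            nn.Conv1d(3, 3*8, kernel_size=5, stride=1),
            nn.Sigmoid(),
            # stride must be non-zero; 2 matches the kernel size (the default)
            nn.AvgPool1d(kernel_size=2, stride=2))
        self.conv2 = nn.Sequential(
            nn.Conv1d(3*8, 12, kernel_size=5, stride=1),
            nn.Sigmoid(),
            nn.AvgPool1d(kernel_size=2, stride=2))
        # forward() uses this layer, so it has to be defined
        self.drop_out = nn.Dropout()

        # Length 256 -> 252 (conv) -> 126 (pool) -> 122 (conv) -> 61 (pool),
        # so the flattened size is 12 channels * 61 = 732, not the batch size 708
        self.fc1 = nn.Linear(12 * 61, 732)
        self.fc2 = nn.Linear(732, 4)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, 732)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out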
Dametime asked Jan 08 '20 14:01


1 Answer

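Use permute. nn.Conv1d expects input shaped (batch, channels, length), so keep the batch dimension first and swap the last two dimensions:

X_train = torch.rand(708, 256, 3)
X_train = X_train.permute(0, 2, 1)
X_train.shape
# => torch.Size([708, 3, 256])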
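
To see the reordering in the context of the question's pipeline, here is a minimal sketch. It assumes float32 data and uses random arrays as a hypothetical stand-in for the asker's X_data and y_data. Only the first convolution from the question is reproduced, to check that a batch from the DataLoader is accepted once the channel dimension sits at position 1.

```python
import numpy as np
import torch
from torch import nn

# Hypothetical stand-in for the asker's arrays: 708 samples, length 256, 3 channels.
X_data = np.random.rand(708, 256, 3).astype(np.float32)
y_data = np.random.rand(708, 4).astype(np.float32)

# Reorder (batch, length, channels) -> (batch, channels, length).
# .contiguous() is optional here; it copies the data into the new layout.
X_train = torch.from_numpy(X_data).permute(0, 2, 1).contiguous()
y_train = torch.from_numpy(y_data)

training_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=50, shuffle=False)

# Same first layer as in the question: 3 input channels, 24 output channels.
conv = nn.Conv1d(3, 3 * 8, kernel_size=5, stride=1)

xb, yb = next(iter(train_loader))
out = conv(xb)

batch_shape = tuple(xb.shape)  # (50, 3, 256)
out_shape = tuple(out.shape)   # (50, 24, 252), since 256 - 5 + 1 = 252
print(batch_shape, out_shape)
```

For a 3D tensor, X_train.transpose(1, 2) produces the same layout as permute(0, 2, 1). Permuting once before building the TensorDataset avoids having to reorder every batch inside the training loop.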
Coolness answered Sep 17 '22 18:09