AttributeError: 'tuple' object has no attribute 'dim' when feeding input to PyTorch LSTM network

I am trying to run the following code:

import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()

        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out


state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]

device = torch.device("cuda")
net = LSTM(5, 3).to(device)

state_v = torch.FloatTensor(state).to(device)

q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())

And that returns this error:

Traceback (most recent call last):
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
    q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
    out = self.hidden2tag(out)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'

Does anyone know how to fix this, i.e. how to stop the tensor from being a tuple so that it can be passed through the network?

asked Oct 28 '18 at 14:10 by JorDik


1 Answer

The PyTorch nn.LSTM module returns a tuple of the form (output, (h_n, c_n)), not a single tensor.
You get this error because your linear layer self.hidden2tag expects a tensor and cannot handle that tuple.

So change:

out = self.lstm(x)

to

out, states = self.lstm(x)

This will fix your error, by splitting up the tuple so that out is just your output tensor.

out then holds the hidden states of the last LSTM layer for every time step, while states is another tuple, (h_n, c_n), that contains the final hidden state and the final cell state.
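To make the two return values concrete, here is a minimal sketch that only prints their shapes. It reuses the layer sizes from the question (5 input features, 12 hidden units); the random input and the choice of 6 time steps with a batch of 1 are my own assumptions for illustration, and it runs on the CPU.

```python
import torch
import torch.nn as nn

# Same sizes as in the question: 5 input features, 12 hidden units.
lstm = nn.LSTM(input_size=5, hidden_size=12)

# Default layout is (seq_len, batch, input_size).
# Assumed for illustration: 6 time steps, batch of 1, random values.
x = torch.randn(6, 1, 5)

# nn.LSTM returns (output, (h_n, c_n)), so unpack it.
out, (h_n, c_n) = lstm(x)

print(out.shape)  # torch.Size([6, 1, 12]) -> one hidden state per time step
print(h_n.shape)  # torch.Size([1, 1, 12]) -> final hidden state
print(c_n.shape)  # torch.Size([1, 1, 12]) -> final cell state

# For a single-layer, unidirectional LSTM the last time step of
# `out` is the same as the final hidden state h_n.
print(torch.allclose(out[-1], h_n[0]))  # True
```

Only out goes into the linear layer; states can be ignored if you do not need to carry the hidden state over to the next call.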

You can also take a look here:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM

You will get another error on the last line, because torch.max() called with a dim argument also returns a tuple, (values, indices), and a tuple has no .item() method. That should be easy to fix, though, and it is a separate error :)
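For that last line, a small sketch of how torch.max() behaves with a dim argument might help. The tensor q below is a made-up stand-in for Q-values (one row, three actions), not the actual output of the network in the question.

```python
import torch

# Hypothetical Q-values for illustration: 1 row, 3 actions.
q = torch.tensor([[0.1, 0.7, 0.2]])

# With a dim argument, torch.max returns (values, indices).
values, indices = torch.max(q, dim=1)

# .item() works on a tensor with exactly one element, not on the tuple.
action = int(indices.item())

print(values)  # tensor([0.7000])
print(action)  # 1
```

Note that with the view(1, 6, 5) from the question the network output has shape (1, 6, 3), so you would first have to select the row you want and take the max over the action dimension before calling .item().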

answered Oct 16 '22 at 19:10 by MBT