I am trying to run the following code:
import matplotlib.pylab as plt
import numpy as np
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(LSTM, self).__init__()
        self.lstm = nn.LSTM(input_shape, 12)
        self.hidden2tag = nn.Linear(12, n_actions)

    def forward(self, x):
        out = self.lstm(x)
        out = self.hidden2tag(out)
        return out

state = [(1,2,3,4,5),(2,3,4,5,6),(3,4,5,6,7),(4,5,6,7,8),(5,6,7,8,9),(6,7,8,9,0)]
device = torch.device("cuda")
net = LSTM(5, 3).to(device)
state_v = torch.FloatTensor(state).to(device)
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
_, action = int(torch.max(q_vals_v, dim=1).item())
And that returns this error:
Traceback (most recent call last):
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 26, in <module>
q_vals_v = net(state_v.view(1, state_v.shape[0], state_v.shape[1]))
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/Documents/PycharmProjects/LSTMReactor/QuestionStackoverflow.py", line 15, in forward
out = self.hidden2tag(out)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/modules/linear.py", line 55, in forward
return F.linear(input, self.weight, self.bias)
File "/home/dikkerj/.local/lib/python3.5/site-packages/torch/nn/functional.py", line 1022, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
Does anyone know how to fix this? (i.e. how to stop the LSTM output from being a tuple so that it can be fed into the linear layer)
The PyTorch LSTM returns a tuple, so you get this error because your linear layer self.hidden2tag cannot handle that tuple.
So change:
out = self.lstm(x)
to
out, states = self.lstm(x)
This fixes the error by unpacking the tuple, so that out is just your output tensor. out then holds the output (the hidden state of the last layer) for every time step, while states is another tuple containing the last hidden state and the last cell state.
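Put together, the forward method from the question would then look like this (a minimal sketch based on the code above):
def forward(self, x):
    # nn.LSTM returns (output, (h_n, c_n)); keep only the output tensor
    out, states = self.lstm(x)
    out = self.hidden2tag(out)
    return out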
You can also take a look here:
https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM
You will get another error for the last line, as max() returns a tuple as well. But that one should be easy to fix, and it is a different error anyway :)
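For example, one possible fix is a sketch like this: torch.max(tensor, dim) returns a (values, indices) pair that has to be unpacked before calling .item(); selecting the last time step here is only an assumption for illustration.
# torch.max returns (values, indices), so unpack it first;
# q_vals_v[0, -1] picks the Q-values of the last time step (an assumption)
_, action_idx = torch.max(q_vals_v[0, -1], dim=0)
action = int(action_idx.item())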