TypeError: view() takes at most 2 arguments (3 given)

I'm trying to use view() in PyTorch, but I can't pass it 3 arguments. I don't know why it keeps giving this error. Can anyone help me with this?

    def forward(self, input):
        lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
Dinh Nha Nguyen asked Dec 22 '22 23:12



1 Answer

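It looks like your input is a NumPy array, not a torch tensor. NumPy's ndarray.view() reinterprets the array's data as a different dtype or array type and accepts at most two arguments, so passing three sizes raises this TypeError. The reshaping view() you want is PyTorch's Tensor.view(), so convert the array to a tensor first, e.g. input = torch.Tensor(input).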
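A minimal sketch of the fix, assuming the model is a small nn.LSTM wrapper. The class name LSTMModel and the sizes used are hypothetical. It first reproduces the error on a NumPy array, then converts the input inside forward() before calling view(). It uses torch.from_numpy(...).float(), which, like torch.Tensor(input), yields a float32 tensor that matches the LSTM's default weight dtype. The exact wording of NumPy's TypeError varies between NumPy versions, so the sketch only checks that one is raised.

```python
import numpy as np
import torch
import torch.nn as nn

# 1) Reproduce the error: on a NumPy array, .view() reinterprets the data as
#    another dtype/array type and accepts at most two arguments, so a
#    reshape-style call with three sizes raises TypeError.
arr = np.zeros((5, 6), dtype=np.float32)
numpy_view_failed = False
try:
    arr.view(len(arr), 2, -1)
except TypeError as exc:
    numpy_view_failed = True
    print("NumPy:", exc)  # exact wording depends on the NumPy version


# 2) Fix: convert NumPy input to a tensor before calling Tensor.view().
class LSTMModel(nn.Module):
    # Hypothetical wrapper; input_dim, hidden_dim and batch_size are illustrative.
    def __init__(self, input_dim, hidden_dim, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.lstm = nn.LSTM(input_dim, hidden_dim)  # expects (seq_len, batch, features)
        self.hidden = None

    def forward(self, input):
        if isinstance(input, np.ndarray):
            # Like torch.Tensor(input), this produces a float32 tensor.
            input = torch.from_numpy(input).float()
        lstm_out, self.hidden = self.lstm(input.view(len(input), self.batch_size, -1))
        return lstm_out


model = LSTMModel(input_dim=3, hidden_dim=4, batch_size=2)
x = np.random.rand(5, 2, 3)  # NumPy array of shape (seq_len, batch, features)
out = model(x)
print(out.shape)  # torch.Size([5, 2, 4])
```

Doing the conversion once, before the data reaches forward(), works just as well; the isinstance check simply lets the same model accept either NumPy arrays or tensors.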

Sergii Dymchenko answered May 12 '23 19:05
