How can I solve "backward() got an unexpected keyword argument 'retain_variables'"?

Tags:

python

pytorch

I wrote the code below, but I get this error:

TypeError: backward() got an unexpected keyword argument 'retain_variables'

My code is:

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma*next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_variables = True)
    self.optimizer.step()
Hmm.. asked Apr 07 '19 23:04

2 Answers

I was having the same problem; this worked for me:

td_loss.backward(retain_graph = True)

The keyword argument was renamed from retain_variables to retain_graph in newer PyTorch releases, so the old name no longer exists.
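The rename can be checked with a minimal, self-contained sketch (assuming a recent PyTorch, where the keyword is retain_graph): calling backward() a second time on the same graph only works if the first call kept the graph alive with retain_graph=True.

```python
import torch

# Tiny computation graph: y = 2 * x
x = torch.ones(1, requires_grad=True)
y = x * 2

# First backward pass; retain_graph=True keeps the graph for reuse.
y.backward(retain_graph=True)

# Second backward pass on the same graph; without retain_graph above,
# this call would raise a RuntimeError about freed buffers.
y.backward()

# Gradients accumulate across calls: 2 + 2 = 4
print(x.grad.item())
```

Passing retain_variables=True instead raises the TypeError from the question, since that keyword has been removed.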

Jagrit Bhupal answered Nov 14 '22 22:11

As a_guest mentions in the comments:

It should be retain_graph=True.

2 revs, 2 users 75% answered Nov 14 '22 22:11