
Truncated Backpropagation Through Time (BPTT) in PyTorch

In PyTorch, I train an RNN/GRU/LSTM network by starting backpropagation (through time) with:

loss.backward()

When the sequence is long, I'd like to use truncated backpropagation through time instead of normal backpropagation through time, where the whole sequence is used.

But I can't find any parameter or function in the PyTorch API to set up truncated BPTT. Did I miss it? Am I supposed to code it myself in PyTorch?

u2gilles asked Nov 07 '22 22:11


1 Answer

Here is an example:

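# assumes x (inputs), the initial state (h, c), criterion and target are already defined
for t in range(T):
    if T - t == k:
        # detach() is not in-place, so reassign; gradients stop here
        h, c = h.detach(), c.detach()
    # x[t:t+1] keeps the leading sequence dimension that nn.LSTM expects
    out, (h, c) = lstm(x[t:t+1], (h, c))
loss = criterion(out, target)
loss.backward()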

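In this example, k controls how many of the final time steps the gradient is backpropagated through. The detach cuts the earlier steps out of the graph, so backward() does not go past that point.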
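A more common variant processes the sequence in chunks of k steps and calls backward() once per chunk, detaching the recurrent state in between. The following is only a minimal sketch of that idea, not an official PyTorch recipe: the sizes, the toy data, the linear `head`, and the helper name `train_tbptt` are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
T, batch, n_in, n_hidden, k = 20, 4, 3, 8, 5

lstm = nn.LSTM(n_in, n_hidden)        # expects input of shape (seq_len, batch, n_in)
head = nn.Linear(n_hidden, 1)         # assumed output layer, not part of the original answer
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)
criterion = nn.MSELoss()

x = torch.randn(T, batch, n_in)       # toy inputs
target = torch.randn(T, batch, 1)     # toy targets


def train_tbptt(x, target, k):
    """One pass over the sequence, backpropagating through at most k steps at a time."""
    state = None                      # nn.LSTM uses a zero initial state when given None
    losses = []
    for start in range(0, x.size(0), k):
        chunk = x[start:start + k]
        out, state = lstm(chunk, state)
        loss = criterion(head(out), target[start:start + k])
        optimizer.zero_grad()
        loss.backward()               # the graph only reaches back to the start of this chunk
        optimizer.step()
        # detach() returns new tensors, so reassign; the next chunk starts a fresh graph
        state = tuple(s.detach() for s in state)
        losses.append(loss.item())
    return losses, state


losses, final_state = train_tbptt(x, target, k)
```

With this layout the hidden state values are still carried across chunks, but gradients never flow through more than k time steps, which keeps memory use bounded for long sequences.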

angerhang answered Nov 23 '22 23:11