In PyTorch, I train an RNN/GRU/LSTM network by starting backpropagation (through time) with:
loss.backward()
When the sequence is long, I'd like to use truncated backpropagation through time (TBPTT) instead of ordinary backpropagation through time, which runs over the whole sequence.
But I can't find any parameters or functions in the PyTorch API for setting up truncated BPTT. Did I miss it, or am I supposed to code it myself in PyTorch?
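For reference, the full BPTT described above can be sketched as follows, assuming a toy `nn.LSTMCell` with made-up sizes, random inputs, and a placeholder loss on the final hidden state. Because the whole unrolled loop is a single autograd graph, `loss.backward()` sends gradient back to every input step, so memory and compute grow with the sequence length.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, batch, n_in, n_hidden = 20, 4, 8, 16   # toy sizes (assumed for illustration)
cell = nn.LSTMCell(n_in, n_hidden)
x = torch.randn(T, batch, n_in, requires_grad=True)  # (time, batch, features)

h = torch.zeros(batch, n_hidden)
c = torch.zeros(batch, n_hidden)
for t in range(T):
    h, c = cell(x[t], (h, c))   # each step extends the same autograd graph

loss = h.pow(2).mean()          # placeholder loss on the final hidden state
loss.backward()                 # full BPTT: gradient flows back through all T steps

# Count the timesteps whose inputs received a nonzero gradient.
steps_with_grad = int((x.grad.abs().sum(dim=(1, 2)) > 0).sum())
print(steps_with_grad)
```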
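There is no built-in switch for this in core PyTorch; you implement it yourself by detaching the recurrent state from the autograd graph. Here is an example: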
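import torch
lstm = torch.nn.LSTMCell(8, 16)            # toy sizes
T, k = 100, 10                             # sequence length, steps to backpropagate through
x = torch.randn(T, 4, 8)                   # (time, batch, features)
h, c = torch.zeros(4, 16), torch.zeros(4, 16)
for t in range(T):
    if T - t == k:
        h, c = h.detach(), c.detach()      # detach() is not in-place: reassign to cut the graph
    h, c = lstm(x[t], (h, c))
h.sum().backward()                         # placeholder loss; gradients cover only the last k steps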
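In this example, k is the number of timesteps to unroll: because the hidden and cell states are detached k steps before the end, backward() propagates gradients through only the last k steps of the sequence.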
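The snippet above runs the full forward pass and backpropagates through only its last k steps. For a long sequence, a common variant is to split it into chunks of k steps, carry the hidden state across chunks, detach it at each chunk boundary, and call backward() plus an optimizer step once per chunk (the official PyTorch word_language_model example does something similar with its `repackage_hidden` helper). A minimal sketch, assuming `nn.LSTM` with the default (time, batch, features) layout, a linear output head, MSE loss, and random data; `train_tbptt` is a hypothetical helper name, not a PyTorch API:

```python
import torch
import torch.nn as nn

def train_tbptt(model, head, x, y, k, optimizer):
    """Hypothetical helper: chunked truncated BPTT over x of shape (T, batch, features).

    The hidden state is carried across chunks but detached at each chunk
    boundary, so every backward() unrolls through at most k steps.
    """
    state = None                                # nn.LSTM uses zeros when hx is None
    losses = []
    for start in range(0, x.size(0), k):
        xc, yc = x[start:start + k], y[start:start + k]
        if state is not None:
            state = tuple(s.detach() for s in state)  # cut the graph at the boundary
        out, state = model(xc, state)
        loss = nn.functional.mse_loss(head(out), yc)
        optimizer.zero_grad()
        loss.backward()                         # gradients stay inside this chunk
        optimizer.step()
        losses.append(loss.item())
    return losses

torch.manual_seed(0)
T, batch, n_in, n_hidden = 100, 4, 8, 16        # toy sizes (assumed)
lstm = nn.LSTM(n_in, n_hidden)
head = nn.Linear(n_hidden, 1)
opt = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.1)
x, y = torch.randn(T, batch, n_in), torch.randn(T, batch, 1)
losses = train_tbptt(lstm, head, x, y, k=10, optimizer=opt)
print(len(losses))
```

With T=100 and k=10 this performs ten weight updates, and because each backward() sees at most k steps, memory no longer scales with the full sequence length.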