I've been confused about how the hidden/cell states are passed along within a batch when batch_size > 1, and across batches.
My assumption currently is that hidden states never transfer from one series to another WITHIN a batch. I.e. when batch_size = 3, the hidden state of the first item is not passed to the second.
Next, setting stateful=True means that the hidden state of the 1st element of the 1st batch gets passed to the 1st element of the 2nd batch, and so on.
This is what the docs state:
"If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch."
I have been struggling to find confirmation of this, and it also doesn't seem very intuitive. Assuming my data is sequential and I've batched it with a batch size of 3, I don't see why I would want the hidden state from the 1st unit to transfer to the 4th unit, the 2nd unit's to go to the 5th unit, and so on.
Yes, it works as you described above.
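And yes, you're right: from the perspective of a data stream t[], it is maybe not very intuitive to order the data so that consecutive samples such as t[0], t[1], t[2] sit at the same index of consecutive batches instead of side by side in one batch.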
But I think it makes complete sense from a data-processing perspective. This data ordering allows you to parallelize/vectorize the processing of the samples within each batch. In other words: [t[0], t[3], t[6]] can be processed in parallel, whereas the sequence [t[0], t[1], t[2]] can only be processed one after another, because each step needs the state left by the previous one.
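To make the ordering concrete, here is a minimal sketch in plain Python. It is not Keras code: the helper name stateful_batches is hypothetical, the integers 0..8 stand in for t[0]..t[8], and the "state" is just a list recording what each batch index has seen. In a real model each t[k] would typically be a whole window of timesteps.

```python
def stateful_batches(stream, batch_size):
    """Arrange `stream` so that index i of consecutive batches walks
    through one contiguous chunk of the stream (hypothetical helper)."""
    steps = len(stream) // batch_size        # number of batches; leftover samples are dropped
    usable = stream[:steps * batch_size]
    # Index i of batch k holds sample i * steps + k.
    return [[usable[i * steps + k] for i in range(batch_size)]
            for k in range(steps)]


t = list(range(9))                           # stand-in for t[0] .. t[8]
batches = stateful_batches(t, batch_size=3)
for k, batch in enumerate(batches):
    print(f"batch {k}: {batch}")

# Toy stand-in for a stateful layer: one state slot per batch index,
# carried over from one batch to the next and never shared between indices.
states = [[] for _ in range(3)]
for batch in batches:
    for i, sample in enumerate(batch):
        states[i] = states[i] + [sample]
print(states)
```

With this layout the first batch is [t[0], t[3], t[6]], and the state slot at index 0 sees t[0], t[1], t[2] in order, which is the continuity that stateful=True relies on. Shuffling the batches would break that continuity, so the order has to be preserved during training.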