
Stateful LSTM - Hidden State transfer between and within batches (Keras)

I've been confused about how the hidden/cell states are transferred within a batch when batch_size > 1, and across batches.

My current assumption is that hidden states never transfer from one series to another WITHIN a batch. I.e. when batch_size = 3, the hidden state of the first item is not passed to the second.

Next, setting stateful=True means that the hidden state of the 1st element of the 1st batch gets passed to the 1st element of the 2nd batch, and so on.

This is what the docs state:

    If True, the last state
    for each sample at index i in a batch will be used as initial
    state for the sample of index i in the following batch.

I have been struggling to find confirmation of this, and it doesn't seem very intuitive either: assuming my data is sequential and I've batched it with a batch size of 3, I don't see why I would want the hidden state from the 1st unit to transfer to the 4th unit, the state from the 2nd unit to go to the 5th unit, and so on.

Michael Du asked Nov 18 '22 14:11


1 Answer

Yes, it works as you describe.
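To make the rule from the docs concrete, here is a minimal toy sketch in plain Python. It is not Keras code: `ToyStatefulLayer`, its `step` update and the sample values are all hypothetical stand-ins for a stateful recurrent layer. It only illustrates that each batch index i keeps its own state slot, which is carried into the next batch and never shared with the other indices.

```python
class ToyStatefulLayer:
    """Toy stand-in for a stateful recurrent layer (hypothetical, not Keras internals)."""

    def __init__(self, batch_size):
        # One independent state slot per position in the batch
        self.states = [0.0] * batch_size

    def step(self, x, h):
        # Hypothetical cell update: any function of (input, previous state) would do
        return 0.5 * h + x

    def process_batch(self, batch):
        # Sample i reads its initial state from slot i and writes its last state back
        for i, sample in enumerate(batch):
            h = self.states[i]
            for x in sample:
                h = self.step(x, h)
            self.states[i] = h
        return list(self.states)


batch1 = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
batch2 = [[3.0, 4.0], [30.0, 40.0], [300.0, 400.0]]

layer = ToyStatefulLayer(batch_size=3)
layer.process_batch(batch1)
final = layer.process_batch(batch2)

# Reference: feed sample 0 of both batches as one long sequence
single = ToyStatefulLayer(batch_size=1)
reference = single.process_batch([batch1[0] + batch2[0]])

print(final[0] == reference[0])
```

In this sketch, index 0 ends up in the same state as if its two chunks had been processed as one uninterrupted sequence, while indices 1 and 2 are unaffected by it.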

And yes, you're right: from the perspective of a data stream t[], ordering the data like this is perhaps not very intuitive:

  • Batch1 = [t[0], t[3], t[6]]
  • Batch2 = [t[1], t[4], t[7]]
  • Batch3 = [t[2], t[5], t[8]]

But I think it makes complete sense from a data-processing perspective. This data ordering allows you to fully parallelize/vectorize the processing of each batch. In other words: [t[0], t[3], t[6]] can be processed in parallel, whereas the sequence [t[0], t[1], t[2]] can only be processed one element after another.
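The ordering above can be produced mechanically. The following is a rough sketch in plain Python, assuming the stream is a list whose items are the individual samples t[k] (in practice each would be a window of timesteps); the helper name `make_stateful_batches` is made up for illustration.

```python
def make_stateful_batches(stream, batch_size):
    """Arrange a stream so that item i of batch k+1 continues item i of batch k."""
    n_batches = len(stream) // batch_size
    # Drop trailing items that would not fill a complete batch
    usable = stream[: n_batches * batch_size]
    # Split the stream into batch_size contiguous sub-streams
    rows = [usable[i * n_batches:(i + 1) * n_batches] for i in range(batch_size)]
    # Batch k takes the k-th item of every sub-stream
    return [[row[k] for row in rows] for k in range(n_batches)]


stream = list(range(9))  # stands in for t[0] .. t[8]
batches = make_stateful_batches(stream, batch_size=3)
for batch in batches:
    print(batch)
```

With 9 items and a batch size of 3, position 0 of the successive batches walks through t[0], t[1], t[2], so the state carried at that index always continues the same contiguous sub-stream.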

Luuu answered Dec 20 '22 10:12