
How can I speed up this Keras Attention computation?

I have written a custom Keras layer, an AttentiveLSTMCell and an AttentiveLSTM(RNN), in line with Keras' new approach to RNNs. The attention mechanism is the one described by Bahdanau: in an encoder/decoder model, a "context" vector is created from all of the encoder's outputs and the decoder's current hidden state. I then append the context vector to the input at every timestep.
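
For reference, the score/weight/context computation follows Bahdanau's additive attention; roughly, W_a, U_a and v_a below correspond to kernel_w, the cached projection _uh and kernel_v in the code further down:

e_{t,i} = v_a^\top \tanh(W_a h_{t-1} + U_a a_i), \qquad
\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_j \exp(e_{t,j})}, \qquad
c_t = \sum_i \alpha_{t,i} a_i

where a_i are the encoder annotations and h_{t-1} is the decoder's previous hidden state.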

The model is being used to build a dialog agent, but it is very similar to NMT models in architecture (the tasks are similar).

However, adding this attention mechanism has slowed the training of my network down 5-fold, and I would really like to know how to write the part of the code that is slowing it down so much in a more efficient way.

The brunt of the computation is done here:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiply the repeated (current) hidden state by the weight matrix
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

# normalize the scores into attention weights (a manual softmax over the timestep axis)
at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # attention weights, shape (batch, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])

This code runs in the call method of the AttentiveLSTMCell, i.e. once per timestep.
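
For reference, here is an untested sketch of the same computation written with broadcasting instead of K.repeat (assuming the TensorFlow backend, where elementwise ops broadcast over the singleton timestep axis):

# project the previous hidden state once: (batch, units) -> (batch, 1, units)
_Wxstm = K.expand_dims(K.dot(h_tm1, self.kernel_w), axis=1)

# broadcasting adds the projection to self._uh, which is (batch, timesteps, units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

# numerically stable softmax over the timestep axis, without K.repeat
at = K.exp(et - K.max(et, axis=1, keepdims=True))
at /= K.sum(at, axis=1, keepdims=True)  # shape (batch, timesteps, 1)

context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)
inputs = K.concatenate([inputs, context])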

The full code can be found here. If necessary, I can provide some data and ways to interact with the model.

Any ideas? I am, of course, training on a GPU, in case there is something clever that can take advantage of that.

asked Mar 08 '18 by modesitt



2 Answers

I would recommend training your model using relu rather than tanh, as this operation is significantly faster to compute. This will save you computation time on the order of your training examples * average sequence length per example * number of epochs.
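
In the snippet from the question that would be a one-line change to the scoring step, for example (an untested sketch, using the same tensors as the question's call method):

et = K.dot(activations.relu(_Wxstm + self._uh), K.expand_dims(self.kernel_v))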

Also, I would evaluate the performance improvement of appending the context vector, keeping in mind that this will slow your iteration cycle on other parameters. If it's not giving you much improvement, it might be worth trying other approaches.

answered Oct 17 '22 by mr_snuffles


You modified the LSTM class, which is fine for CPU computation, but you mentioned that you're training on a GPU.

I recommend looking into the cudnn-recurrent implementation (e.g. Keras' CuDNNLSTM) or further into the TensorFlow code it uses. Maybe you can extend the code there.
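
For example, the encoder side, which does not need the custom cell, could use the cuDNN-backed layer directly. This is only a sketch: encoder_inputs and the unit count are placeholders, and it assumes Keras 2.x with the TensorFlow backend:

from keras.layers import CuDNNLSTM

# the whole sequence runs in one fused cuDNN kernel instead of a per-timestep loop
encoder_outputs = CuDNNLSTM(256, return_sequences=True)(encoder_inputs)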

answered Oct 17 '22 by Benedikt Fuchs