
Computational Complexity of Self-Attention in the Transformer Model

I recently went through the Transformer paper from Google Research describing how self-attention layers could completely replace traditional RNN-based sequence encoding layers for machine translation. In Table 1 of the paper, the authors compare the computational complexities of different sequence encoding layers, and state (later on) that self-attention layers are faster than RNN layers when the sequence length n is smaller than the dimensionality of the vector representations d.

However, if my understanding of the computations is correct, the self-attention layer seems to have a worse complexity than claimed. Let X be the input to a self-attention layer. Then X will have shape (n, d), since there are n word-vectors (corresponding to rows), each of dimension d. Computing the output of self-attention requires the following steps (consider single-headed self-attention for simplicity):

  1. Linearly transforming the rows of X to compute the query Q, key K, and value V matrices, each of which has shape (n, d). This is accomplished by post-multiplying X with 3 learned matrices of shape (d, d), amounting to a computational complexity of O(n d^2).
  2. Computing the layer output, specified in Equation 1 of the paper as softmax(Q K^T / sqrt(d)) V, where the softmax is computed over each row. Computing Q K^T has complexity O(n^2 d), and post-multiplying the result with V has complexity O(n^2 d) as well.

Therefore, the total complexity of the layer is O(n^2 d + n d^2), which is worse than that of a traditional RNN layer. I obtained the same result for multi-headed attention after accounting for the appropriate intermediate representation dimensionalities (d_k, d_v) and multiplying by the number of heads h.
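For concreteness, here is a minimal NumPy sketch of single-headed self-attention following the two steps above (the dimensions, random input and random "learned" weights are placeholders, purely illustrative); the comments note where each complexity term comes from:

    import numpy as np

    n, d = 100, 512                       # example sequence length and model dimension
    X = np.random.randn(n, d)             # input: n word-vectors of dimension d

    # placeholder "learned" projection matrices of shape (d, d)
    W_q, W_k, W_v = (np.random.randn(d, d) for _ in range(3))

    # Step 1: query/key/value projections, three (n, d) @ (d, d) products -> O(n d^2)
    Q, K, V = X @ W_q, X @ W_k, X @ W_v

    # Step 2: scores Q K^T / sqrt(d), an (n, d) @ (d, n) product -> O(n^2 d)
    scores = Q @ K.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)     # row-wise softmax, O(n^2)

    # weighted sum of values, an (n, n) @ (n, d) product -> O(n^2 d) again
    output = weights @ V                  # overall: O(n^2 d + n d^2)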

Why have the authors ignored the cost of computing the Query, Key, and Value matrices while reporting total computational complexity?

I understand that the proposed layer is fully parallelizable across the n positions, but I believe that Table 1 does not take this into account anyway.

Newton asked Jan 13 '21


People also ask

What is the complexity of the self attention layer?

The complexity of the initial convolution is O(n × d^2) and the complexity of the self-attention layers becomes O((n/k)^2 × d), where k is the kernel size of the convolution layer. Hence the overall complexity becomes O(n × d^2 + (n/k)^2 × d).

What is the time complexity of Lstm?

On the other hand, LSTM is local in space and time [23], which means that the input length does not affect the storage requirements of the network and for each time step, the time complexity per weight is O(1).

What is difference between attention and self attention?

The attention mechanism allows the output to focus on the input while the output is being produced, whereas the self-attention model allows the inputs to interact with each other (i.e. attention is computed between each input and all the other inputs).

What is Self attention model?

Self-attention, also called intra-attention, is an attention mechanism relating different positions of a single sequence in order to compute a representation of that sequence. It has been shown to be very useful in machine reading, abstractive summarization, and image description generation.


2 Answers

First, you are correct in your complexity calculations. So, what is the source of confusion?

When attention was first introduced in the original attention paper, it did not require calculating the Q, K and V matrices, as the values were taken directly from the hidden states of the RNNs, and thus the complexity of the attention layer is O(n^2·d).
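As a rough illustrative sketch of that earlier setup (not the exact additive attention of the original paper), the RNN hidden states H stand in directly for the queries, keys and values, so no O(n·d^2) projection term appears:

    import numpy as np

    n, d = 100, 512
    H = np.random.randn(n, d)             # RNN hidden states used directly, no learned Q/K/V projections

    scores = H @ H.T / np.sqrt(d)         # (n, n) score matrix -> O(n^2 d)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)     # row-wise softmax

    context = weights @ H                 # (n, d) context vectors -> O(n^2 d); total stays O(n^2 d)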

Now, to understand what Table 1 contains, keep in mind how most people scan papers: they read the title and abstract, then look at the figures and tables, and only if the results look interesting do they read the paper more thoroughly. The main idea of the Attention Is All You Need paper was to replace the RNN layers in the seq2seq setting completely with an attention mechanism, because RNNs were really slow to train. If you look at Table 1 in this context, you see that it compares RNN, CNN and attention layers and highlights the motivation for the paper: using attention should be beneficial over RNNs and CNNs in three aspects: a constant number of sequential calculation steps, a constant number of operations to relate any two positions, and lower computational complexity in the usual Google setting, where n ~= 100 and d ~= 1000.

But as with any idea, it hit the hard wall of reality. For that great idea to work in practice, they had to add positional encoding, reformulate the attention and add multiple heads to it. The result is the Transformer architecture, which, while it has computational complexity O(n^2·d + n·d^2), is still much faster than an RNN (in terms of wall-clock time) and produces better results.

So the answer to your question is that the attention layer the authors refer to in Table 1 is strictly the attention mechanism; it is not the complexity of the whole Transformer. They are very well aware of the complexity of their model (I quote):

Separable convolutions [6], however, decrease the complexity considerably, to O(k·n·d + n·d^2). Even with k = n, however, the complexity of a separable convolution is equal to the combination of a self-attention layer and a point-wise feed-forward layer, the approach we take in our model.
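For reference, substituting k = n into that expression gives O(n·n·d + n·d^2) = O(n^2·d + n·d^2), which asymptotically matches the total you derived in the question for self-attention plus the projection (point-wise) term.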

igrinis answered Sep 19 '22


Strictly speaking, when considering the complexity of only the self-attention block (Fig. 2 left, Equation 1), the projection of x to q, k and v is not included in the self-attention. The complexities shown in Table 1 are only for the very core of the self-attention layer and are therefore O(n^2 d).

Shai answered Sep 20 '22