If you are even vaguely familiar with the term "ChatGPT", chances are that you have come across a term called attention.
Attention is what powers "transformers" - the seemingly complex architecture behind large language models (LLMs) like ChatGPT.
This blog attempts to take you through an informal approach to answering the question, "what does attention even mean?"
Before going deeper into the concept of attention, let me quickly tell you what the transformer architecture does in short. Don't worry, this will (maybe) feel like a breeze.
So... a transformer has two main parts: an encoder and a decoder.
Now, given some input words making an input sentence, the encoder is responsible for converting the plain-text input words into tokens where each token has a unique id associated with it AND is "represented" by a high-dimensional vector.
Wait, high-dimensional vector? Why?
This is because neural networks and hence machines do not understand text as us humans do, so we need to convert our text into something which neural networks understand very well that is... YES! A vector!
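These vector "representations" capture a lot of information about the input words such as:
1. Semantic information (what a word means)
2. Positional information (where a word sits in the sentence)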
Condensing all this information into a matrix composed of high-dimensional token vectors is what the encoder does.
For example, a token representing the word "cat" will be encoded as a vector in some n-dimensional space.
Semantic information? Positional information? Ahhh, I don't understand!
Hey don't worry...
Here I'm shamelessly skipping the fine-details of how the plain-text words are converted into these vector "embeddings" that capture the semantic (1st point) and positional (2nd point) information, since our focus is mainly on attention today. Word embeddings can be a whole topic in itself. But, for now, imagine using magic we convert words into some vector "embeddings".
To be honest, a decoder is kind of similar to an encoder.
During the model training phase, it also takes a sequence of words, similar to an encoder, which are the targets and converts them into vector representations or embeddings similar to the output of an encoder.
BUT...
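In the decoder, attention is calculated in a slightly different way than in the encoder.
Essentially, the encoder calculates self-attention, while the decoder calculates masked self-attention and encoder-decoder attention.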
Bruh, what are those words surrounding attention?
I know, I know you are lost and honestly I was too but self-attention, masked self-attention, and encoder-decoder attention are attention mechanisms which we will go through today.
For simplicity, just know that the encoder provides the decoder with some embeddings, and the decoder uses those embeddings (along with its own embeddings) to generate output tokens, one at a time.
A diagram of the original transformer architecture will be helpful here:
Have a good look at this... it is taken from the paper "Attention is all you need" which first introduced the concept of transformers.
The diagram is a bit cryptic but the block on the left with two sub-parts is the encoder and the block on the right with three sub-parts is the decoder.
The arrows flowing from the encoder to the decoder are just the "representations" of the inputs, whose attention is calculated with the partial "representations" of the decoder (2nd sub-part from the bottom inside the decoder).
While reading this paragraph you are involuntarily focussing and "attending" to some particular words MORE than other words.
This helps your human (if you are one) brain to form complex relationships between the words even when they are far apart within a paragraph. Since you are attentive, you implicitly know which words are more "useful" and which words are not that "useful" in this sentence.
The intuition behind attention in transformers is analogous to this.
What attention aims to do is this: it calculates a "score", which we can call the usefulness score, between each token and the other tokens, either across two different sentences OR within the same sentence (self-attention).
After this, the initial vector representations of the tokens are transformed into final vector representations based on the usefulness scores, in the hope that they capture the usefulness information i.e. which tokens to "attend" to and which tokens are more "useful".
Can I use a "word" instead of a "token" in order to think in a simple manner?
Yes, for simplicity you can think of a token as a word. In a more general sense, though, a single word may be split into two or more tokens.
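To make the word-versus-token distinction a bit more concrete, here is a toy sketch of a greedy longest-match subword splitter. Both the vocabulary `TOY_VOCAB` and the function `toy_tokenize` are hypothetical and made up purely for illustration; real LLM tokenizers learn their vocabularies from data and will split words differently.

```python
# Hypothetical toy vocabulary: NOT the vocabulary of any real model.
TOY_VOCAB = {"cat", "play", "ing", "un", "happi", "ness"}


def toy_tokenize(word, vocab=TOY_VOCAB):
    """Greedily split `word` into the longest pieces found in `vocab`.

    Characters not covered by the vocabulary become single-character tokens.
    """
    tokens = []
    start = 0
    while start < len(word):
        # Try the longest possible piece first, then shrink.
        for end in range(len(word), start, -1):
            piece = word[start:end]
            if piece in vocab:
                tokens.append(piece)
                start = end
                break
        else:
            # No vocabulary entry matched: fall back to one character.
            tokens.append(word[start])
            start += 1
    return tokens


print(toy_tokenize("cat"))          # one word, one token
print(toy_tokenize("playing"))      # one word, two tokens
print(toy_tokenize("unhappiness"))  # one word, three tokens
```

So a "word" and a "token" line up for short, common words, but a longer word can end up as several tokens, each with its own vector.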
Using the above intuition of the usefulness score of tokens, we can think about what self-attention is and what encoder-decoder attention is.
In self-attention, the usefulness scores are calculated between the tokens of one sentence (representations) and the tokens of the SAME sentence (same representations).
In encoder-decoder attention, the usefulness scores are calculated between the encoder representations and the decoder representations i.e. the two sentences will be different here.
OK, now that I have loosely explained the intuition behind attention, let's think of how this thing is calculated.
Imagine you had two vector "representations" or vectors in short:
a = [a₁, a₂, ..., aₙ]
and,
b = [b₁, b₂, ..., bₙ]
in some n-dimensional vector space. We can then calculate the similarity between the two vectors using their dot product, which is:
a · b = a₁b₁ + a₂b₂ + ... + aₙbₙ
Since we represent tokens (or words) using their vector representations, the value of the dot product will tell us how similar those tokens are in their embeddings vector space.
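As a quick sanity check of this idea, here is a minimal sketch of the dot product in plain Python. The three "embeddings" are made-up numbers chosen only for illustration, not real model outputs.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    return sum(x * y for x, y in zip(a, b))


# Made-up 3-dimensional "embeddings", chosen only for illustration.
cat = [1.0, 2.0, 0.5]
kitten = [0.9, 1.8, 0.7]
car = [-1.0, 0.2, 2.0]

print(dot(cat, kitten))  # larger value: the vectors point in a similar direction
print(dot(cat, car))     # smaller value: the vectors are less aligned
```

The larger the dot product, the more the two vectors agree, which is exactly the kind of "score" we need.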
The original paper "Attention is all you need" converts each token's vector representation into three more vectors called Query, Key, and Value.
Each token has its query vector, key vector, and value vector
But why do they add three more vectors? Why query, key, and value? Why not something else?
Good! You are asking the right questions. Essentially, the usage of Query, Key, and Value and hence attention, is hugely analogous to retrieval systems. And the calculation of attention is done using these three vectors.
Okay...
Consider we have a database having some keys and their corresponding values. Imagine the database to represent some knowledge where the keys are "topics" and the values are the "information" we need about the corresponding "topic".
{K₁, V₁
K₂, V₂
K₃, V₃
..., ...}
Now you, as a user, want to query the database to fetch some important "information" (V) based on your "query" (Q).
In order to do that, we need to find which "topics" (K) in the database are similar to your "query" (Q)
Once we find the "topics" (K) that are similar to your "query" (Q) we can sort them according to their relevance and then process their corresponding "information" (V).
Oh! I see a pattern here... so if I have two sentences and I want to find the "usefulness" or "attention" of one word in the first sentence with respect to the words of the second sentence, we can create a "query" out of that one word. We can also create "keys" and "values" for every other word in the second sentence, in order to find the usefulness, which is very similar to retrieval systems! Damn.
Bingo! I presume that's why the original transformers paper introduces the three vectors Query, Key, and Value for each token.
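The retrieval analogy can be sketched in a few lines. This is only a toy under stated assumptions: the `database`, its key vectors, and the `retrieve` helper are all hypothetical names and numbers invented for this example.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical toy database: each entry is (key vector, value).
database = [
    ([1.0, 0.0], "information about cats"),
    ([0.0, 1.0], "information about cars"),
    ([0.7, 0.7], "information about cat-shaped cars"),
]


def retrieve(query, database):
    """Return (score, value) pairs sorted from most to least relevant key."""
    scored = [(dot(query, key), value) for key, value in database]
    return sorted(scored, key=lambda pair: pair[0], reverse=True)


query = [0.9, 0.1]  # a made-up query that is mostly "about cats"
for score, value in retrieve(query, database):
    print(f"{score:.2f}  {value}")
```

Attention softens this idea: instead of picking only the top entries, it blends all the values together, using weights derived from the scores.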
Consider we have a sentence made up of tokens:
S = t₁, t₂, t₃..., tₙ
tᵢ = [u₁, u₂, ..., uₘ] i ∈ [1, 2, ..., n]
Where each token is represented using a vector in some m-dimensional vector space. From the vector representation of each token, we create the query, key, and value vector:
Qᵢ = [q₁, q₂, ..., qₖ], Kᵢ = [k₁, k₂, ..., kₖ], Vᵢ = [v₁, v₂, ..., vᵥ]
Note: here the dimension of the query and the key vectors is equal to some value k, while the dimension of the value vector is equal to some value v.
The imaginary tokens database for the second sentence (having p tokens) now looks like:
{t₁' → [k₁₁, k₁₂, ..., k₁ₖ], [v₁₁, v₁₂, ..., v₁ᵥ]
t₂' → [k₂₁, k₂₂, ..., k₂ₖ], [v₂₁, v₂₂, ..., v₂ᵥ]
..., ...
tₚ' → [kₚ₁, kₚ₂, ..., kₚₖ], [vₚ₁, vₚ₂, ..., vₚᵥ]}
If we want to calculate the attention of the first token of the first sentence with respect to all the tokens of the second sentence, we create the query Q for the first token as:
Q₁ = [q₁, q₂, ..., qₖ]
Now, from the above intuition of retrieval systems, we take the dot product of the query with every key, which gives one score per token of the second sentence:
{Q₁ · [k₁₁, k₁₂, ..., k₁ₖ] = x₁, [v₁₁, v₁₂, ..., v₁ᵥ]
Q₁ · [k₂₁, k₂₂, ..., k₂ₖ] = x₂, [v₂₁, v₂₂, ..., v₂ᵥ]
..., ...
Q₁ · [kₚ₁, kₚ₂, ..., kₚₖ] = xₚ, [vₚ₁, vₚ₂, ..., vₚᵥ]}
In order to "sort" them i.e. to find the weightage of which keys are the most useful, we can pass the vector of scores x = [x₁, x₂, ..., xₚ] to a softmax function, which turns the scores into positive weights that sum to 1:
softmax(x) = [y₁, y₂, ..., yₚ]
where, in general, the score between the i-th query and the j-th key is xⱼ = Qᵢ · Kⱼ, i ∈ [1, 2, ..., n], j ∈ [1, 2, ..., p]
Thus:
attention₁ = softmax(x)V = [y₁, y₂, ..., yₚ] V → shape(1, v)
where V is the matrix of shape (p, v) whose j-th row is the value vector [vⱼ₁, vⱼ₂, ..., vⱼᵥ]. In other words, the output is a weighted sum of the value vectors:
attention₁ = y₁[v₁₁, v₁₂, ..., v₁ᵥ] + y₂[v₂₁, v₂₂, ..., v₂ᵥ] + ... + yₚ[vₚ₁, vₚ₂, ..., vₚᵥ]
Here the vector [y₁, y₂, ..., yₚ] is the result of the softmax function over the vector x.
This gives us the "attention" vector representation for the query (first token) with respect to the tokens in the second sentence!
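Here is a minimal sketch of that single-query calculation in plain Python: dot products, then softmax, then a weighted sum of the values. The query, keys, and values are made-up numbers, and the helper names (`softmax`, `attend`) are my own choices for this example.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    largest = max(scores)  # subtracting the max avoids overflow in exp
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values):
    """Attention output for one query: a softmax-weighted sum of the values."""
    scores = [dot(query, key) for key in keys]  # the vector x
    weights = softmax(scores)                   # the vector y
    dim_v = len(values[0])
    output = [
        sum(w * value[d] for w, value in zip(weights, values))
        for d in range(dim_v)
    ]
    return weights, output


# Made-up numbers: one query (dimension 2) against three tokens
# with keys of dimension 2 and values of dimension 3.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

weights, output = attend(query, keys, values)
print(weights)  # one weight per token of the second sentence
print(output)   # a single vector of length 3, i.e. shape (1, v)
```

Because the toy values form an identity matrix, the output here is simply the weights themselves, which makes it easy to see which tokens the query "attends" to the most.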
Now that we have an understanding of attention for one token with respect to a different sentence or even the same sentence, let me give the formal definition of attention from the original paper "Attention is all you need":
An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.
Think about it... so much information condensed in this one single definition!
If we want to calculate attention of all tokens of one sentence with respect to all tokens of the second sentence, we can replace the query, key, and value vectors with matrices:
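Stacking the vectors of the tokens t₁, t₂, ..., tₙ as rows gives:
Q = [[q₁₁, q₁₂, ..., q₁ₖ], [q₂₁, q₂₂, ..., q₂ₖ], ..., [qₙ₁, qₙ₂, ..., qₙₖ]] → shape(n, k)
K = [[k₁₁, k₁₂, ..., k₁ₖ], [k₂₁, k₂₂, ..., k₂ₖ], ..., [kₙ₁, kₙ₂, ..., kₙₖ]] → shape(n, k)
V = [[v₁₁, v₁₂, ..., v₁ᵥ], [v₂₁, v₂₂, ..., v₂ᵥ], ..., [vₙ₁, vₙ₂, ..., vₙᵥ]] → shape(n, v)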
We can now express the dot products between the queries and the keys as a matrix multiplication: the product of the query matrix with the transpose of the key matrix, QKᵀ.
The calculation of attention then becomes:
attention(Q, K, V) = softmax(QKᵀ)V
With appropriate matrices for query, keys, and values, and with the softmax applied to each row of QKᵀ.
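However, the paper also introduces a "scaling term". For large values of k, the dot products can grow large in magnitude, which pushes the softmax function into regions where its gradients are extremely small (the vanishing gradients problem during backpropagation). To counteract this, the dot products are divided by the square root of the dimension of the query and key vectors i.e. the square root of the value k.
Thus our final equation for "self-attention" becomes:
attention(Q, K, V) = softmax(QKᵀ / √k)V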
Phew! That was a lot to uncover.
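If the scaling term feels a bit arbitrary, this small sketch may help. It compares the softmax of raw dot products against the softmax of dot products divided by the square root of k. The dimension k = 64, the helper `make_key`, and all the vectors are assumptions made up for the demo.

```python
import math


def softmax(scores):
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def make_key(num_plus, k):
    """Hypothetical key: `num_plus` entries of +1 followed by -1 entries."""
    return [1.0] * num_plus + [-1.0] * (k - num_plus)


k = 64  # example dimension of the query and key vectors (an assumption)

# Every entry has size 1, so each dot product is a sum of k terms
# and its magnitude tends to grow with k.
query = [1.0] * k
keys = [make_key(40, k), make_key(36, k), make_key(32, k)]

scores = [sum(q * x for q, x in zip(query, key)) for key in keys]
scaled = [s / math.sqrt(k) for s in scores]

unscaled_weights = softmax(scores)
scaled_weights = softmax(scaled)
print(unscaled_weights)  # nearly all of the weight lands on a single key
print(scaled_weights)    # the weights are spread much more evenly
```

When nearly all of the weight sits on one key, the softmax is saturated and its gradients are tiny, which is the situation the scaling term is meant to avoid.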
Here's how the shapes of the matrices fit in the above equation:
# consider self-attention, ignore scaling term
# initial matrix U of shape (n, m)
Q = (n, k)
K = (n, k)
V = (n, v)
Q * K.T = (n, k) * (k, n) = (n, n)
softmax(Q * K.T) = (n, n)
softmax(Q * K.T) * V = (n, n) * (n, v) = (n, v)
# final representations have the shape (n, v)
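To double-check the shape bookkeeping, here is a small sketch of scaled dot-product attention on nested Python lists. For self-attention with n tokens, the score matrix comes out as (n, n) and the final output as (n, v). The matrices and helper names are made up for this example; a real implementation would use a tensor library instead.

```python
import math


def matmul(a, b):
    """Multiply an (r, s) matrix by an (s, t) matrix, both as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax_rows(a):
    """Apply softmax to every row of a matrix."""
    out = []
    for row in a:
        largest = max(row)
        exps = [math.exp(x - largest) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def shape(a):
    return (len(a), len(a[0]))


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(k)) V for nested-list matrices."""
    k = len(Q[0])
    scores = matmul(Q, transpose(K))
    scaled = [[s / math.sqrt(k) for s in row] for row in scores]
    weights = softmax_rows(scaled)
    return weights, matmul(weights, V)


# Made-up example with n = 3 tokens, k = 2, v = 4.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0]]

weights, output = attention(Q, K, V)
print(shape(weights))  # (n, n)
print(shape(output))   # (n, v)
```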
Finally, the below figure will help you visualize how attention is calculated:
Once you understand what self-attention is, the concept of masked self-attention will feel like a walk in the park.
Basically...
Self-attention calculates the attention of a particular token with respect to all the other tokens (including itself) in the same sentence.
Masked self-attention, on the other hand, calculates the attention of a particular token with respect to all the tokens preceding it (including itself) in the same sentence. Which means, the tokens that come after that one particular token are "masked" i.e. the attention with them is not calculated.
If I remember correctly, the encoder calculates self-attention and the decoder calculates masked self-attention right?
Yes, this is because the decoder is responsible for "predicting" which token could come next after encountering a particular token. To stop the decoder from peeking at future tokens, it calculates masked self-attention rather than self-attention.
We discussed the query, key, and value vectors, but you might have a question about how we convert our initial vector representations into these vectors.
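Here is a rough sketch of the masking idea. The original paper masks future positions by setting their scores to negative infinity before the softmax; this toy version gets the same weights by running the softmax over the visible tokens only and giving the future tokens a weight of zero. The score matrix is made up, and `masked_softmax_rows` is a name I picked for this example.

```python
import math


def masked_softmax_rows(scores):
    """Row-wise softmax where token i may only look at tokens 0..i."""
    weights = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]               # current and preceding tokens
        largest = max(visible)
        exps = [math.exp(x - largest) for x in visible]
        total = sum(exps)
        masked = [0.0] * (len(row) - i - 1)  # future tokens get zero weight
        weights.append([e / total for e in exps] + masked)
    return weights


# Made-up (n, n) matrix of query-key scores for a 4-token sentence.
scores = [
    [2.0, 1.0, 0.5, 0.1],
    [0.3, 1.5, 0.2, 0.9],
    [1.0, 1.0, 1.0, 1.0],
    [0.2, 0.4, 0.6, 0.8],
]

masked_weights = masked_softmax_rows(scores)
for row in masked_weights:
    print([round(w, 2) for w in row])
```

The printed weights form a lower-triangular pattern: the first token can only attend to itself, while the last token can attend to every token in the sentence.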
It's simple...
We multiply the initial vector representations with weight matrices to get the query, key, and value vectors respectively. You can think of these weights as the weights of a simple feed-forward neural network layer (without adding any bias). So, in terms of PyTorch:
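from torch import nn

# define the dimensions n, m, k, and v (example values, chosen only for illustration)
n, m, k, v = 4, 8, 6, 6
# the initial token matrix has shape (n, m); each layer below is a bias-free projection
WQ = nn.Linear(m, k, bias=False)  # queries: (n, m) -> (n, k)
WK = nn.Linear(m, k, bias=False)  # keys:    (n, m) -> (n, k)
WV = nn.Linear(m, v, bias=False)  # values:  (n, m) -> (n, v)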
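Under the hood, a linear layer without a bias is just a matrix multiplication. The sketch below shows this with plain Python lists. The sizes and the weight values are hypothetical (in a real model the weights are learned), and the (m, k) matrices here correspond to the transposed weights of the PyTorch layers, since `nn.Linear` stores its weight as (out_features, in_features).

```python
def matmul(a, b):
    """Multiply an (r, s) matrix by an (s, t) matrix, both as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Made-up sizes: n = 2 tokens, m = 3, k = 2, v = 2.
U = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 1.0]]  # initial token vectors, shape (n, m)

# Hypothetical weight matrices; in a real model these are learned.
W_Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # shape (m, k)
W_K = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.5]]  # shape (m, k)
W_V = [[1.0, 1.0], [0.0, 2.0], [1.0, 0.0]]  # shape (m, v)

Q = matmul(U, W_Q)  # shape (n, k)
K = matmul(U, W_K)  # shape (n, k)
V = matmul(U, W_V)  # shape (n, v)

print(Q)
print(K)
print(V)
```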
In the previous sections, we calculated the attention of tokens with only one Query and only one Key-Value pair.
The authors of the original transformers paper proposed that we can have different heads of attention, where each head has its own Query and Key-Value pair, derived from the original query, key, and value matrices.
So in this method, the authors multiply the original query, key, and value matrices again with h different sets of weight matrices WQ', WK', and WV' (with appropriate dimensions) to get h different Query and Key-Value pairs.
Adding more heads can be beneficial if we want to find which tokens to attend in different representation subspaces. As described in the paper:
Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.
The attention (as explained above) is calculated for these h different queries, keys, and values, and the outputs are then concatenated to form a matrix of shape (n, h * v).
The output of the concatenation is also passed through a weight matrix WO of shape (h * v, m) to get the final output vector representations.
So the calculation for multi-head attention is:
multihead(Q, K, V) = Concat(head₁, ..., headₕ)W^O
where,
headᵢ = attention(QW^Q_i, KW^K_i, VW^V_i)
Here the term "attention" is the scaled dot-product attention which I explained above, and the weight matrices W^Q_i, W^K_i, and W^V_i (one set per head), along with W^O, are the learnable parameters.
The resulting output vector representations capture the multi-head attention information, and are then processed further.
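Putting the pieces together, here is a rough sketch of multi-head attention on nested Python lists. It follows the formula above, but several things are assumptions of mine: the weights are random stand-ins for learned parameters, the sizes are made up, and the per-head matrices are chosen as (k, k) and (v, v) so that each head keeps the dimensions k and v.

```python
import math
import random


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def attention(Q, K, V):
    """Scaled dot-product attention for nested-list matrices."""
    k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    weights = []
    for row in matmul(Q, K_T):
        scaled = [s / math.sqrt(k) for s in row]
        largest = max(scaled)
        exps = [math.exp(s - largest) for s in scaled]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return matmul(weights, V)


def random_matrix(rows, cols, rng):
    """Stand-in for a learned weight matrix (random, for illustration only)."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def multihead(Q, K, V, heads, W_O):
    """Concat(head_1, ..., head_h) W_O, with head_i = attention(Q W_Q, K W_K, V W_V)."""
    head_outputs = [
        attention(matmul(Q, W_Q), matmul(K, W_K), matmul(V, W_V))
        for W_Q, W_K, W_V in heads
    ]
    # Concatenate the heads along the feature axis: shape (n, h * v).
    concat = [sum((head[i] for head in head_outputs), []) for i in range(len(Q))]
    return matmul(concat, W_O)


rng = random.Random(0)
n, m, k, v, h = 3, 4, 2, 2, 2  # made-up sizes

Q = random_matrix(n, k, rng)
K = random_matrix(n, k, rng)
V = random_matrix(n, v, rng)
heads = [
    (random_matrix(k, k, rng), random_matrix(k, k, rng), random_matrix(v, v, rng))
    for _ in range(h)
]
W_O = random_matrix(h * v, m, rng)

output = multihead(Q, K, V, heads, W_O)
print(len(output), len(output[0]))  # n rows, m columns
```

Each head runs the same attention calculation on its own projected queries, keys, and values, so different heads are free to focus on different relationships between the tokens.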
Finally, the below figure visualizes what multi-head attention does:
exhales
Okay... that was a lot to take in, but hopefully you now have some intuition behind the concept of attention mechanisms in transformers, which power large language models (LLMs) like ChatGPT.
Do note that there can be several forms of attention which might provide "better" results, such as sliding window attention, but my intention behind this blog post was to approach the concept of transformers attention in a very informal yet intuitive manner.
Thank you for reading this, anon. See you very soon!
Links of the materials that helped me A LOT in order to write this blog post: