Transformer Attention Block, Explained Simply

Uri Almog
6 min read · Jun 26, 2024


Two events in recent years were disruptive in the area of large language models, or LLMs for short. The first was the 2017 publication of Google’s paper Attention Is All You Need (Vaswani et al.), which laid the foundations for the transformer architecture and did for natural language processing (NLP) what convolutional networks did for computer vision. The second was the launch of OpenAI’s transformer-based chatbot, ChatGPT (Generative Pre-trained Transformer), in November 2022. This short post will address the structure of the transformer block.

Language Models

Just as computer vision models are tasked with various jobs like object detection, segmentation, and face recognition, common tasks for large language models include question answering, code generation and sentiment analysis. All three tasks in this list require the model to gain some insight into the context through the connections between words. The difference between the first two and the third is that while sentiment analysis is a discriminative task, question answering and code generation are generative tasks, that is, the model is required to generate text.

One way to train a model to generate text (either a reply to a question or code) is to provide it with context (the text up to the token it is required to predict) and train it to predict the next token. This is done iteratively, one token at a time. Some models are also trained to predict masked tokens in the middle of their given context, and not just at the end.
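As a rough illustration of this iterative loop, here is a minimal sketch. The `next_token_probs` function is a hypothetical stand-in for a trained language model that maps a token sequence to a probability distribution over the vocabulary; the greedy `argmax` choice is just one possible decoding strategy.

```python
import numpy as np

def generate(next_token_probs, context, max_new_tokens=20, end_token=0):
    """Autoregressive generation: predict and append one token at a time.

    next_token_probs: hypothetical trained model -- takes a list of token IDs
                      and returns a probability for every token in the vocabulary.
    context:          list of integer token IDs (the given context).
    """
    tokens = list(context)
    for _ in range(max_new_tokens):
        probs = next_token_probs(tokens)   # distribution over the vocabulary
        next_id = int(np.argmax(probs))    # greedy choice (sampling also works)
        tokens.append(next_id)             # the prediction joins the context
        if next_id == end_token:           # stop on an end-of-text token
            break
    return tokens
```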

Transformers — An Overview

A context in natural language can be thought of as a time series. Tokens appear in a certain order, and can relate to each other even if they are very far apart. Consider the following text:

“The woman is the sole survivor, and wanders off in a daze; she hides in an apartment when the occupant leaves with suitcases.”

(from the Wikipedia page on David Lynch’s fantastic film Mulholland Drive). The word ‘she’ in the second part refers to ‘the woman’ in the first part, and ‘leaves’ is understood to refer to the ‘apartment’. The word ‘apartment’ puts the word ‘leaves’ in context, despite their distance in the text, and vice versa. We the readers understand that the occupant is leaving the apartment, and that the apartment is, therefore, empty. We also understand that the suitcases belong to the leaving occupant, and not to ‘the woman’. This is just a simple example; longer texts, such as stories or essays, can easily contain strongly connected tokens that are separated by hundreds or even thousands of tokens. CNNs rely on the assumption (which holds for images) that the most meaningful information for understanding a given piece of data is found in the adjacent pieces of data (e.g. neighboring pixels). They are perfect for detecting relations between neighboring regions of the series, but they have a limited receptive field and therefore a hard time learning relations between distant parts of a text.

The purpose of the attention mechanism at the heart of the transformer is to modify the vector representation of each token based on the aggregated effect of all the other tokens on it, thus capturing some insight into the context, which cannot necessarily be articulated in human terms (this is analogous to the intermediate output of filters in a CNN operating on an image).

This is done by simultaneously computing each token’s connection, or attention (elaborated on shortly), with all the other tokens, and then modifying each token in a way that gives the proper weight to its connection with every other token. The result is summed with the pre-attention representation of the token. In that way, the learned effect of the attention is just the shift from the original representation to the new representation, which is a simpler task than fully learning a new representation.

Following our previous example, the attention between the word ‘She’ and all the other words is computed. The model learns (through a great many examples) that the word ‘woman’ has a stronger connection with ‘She’ than the other words, although the other words, like ‘hides’, may have an effect as well. Now, when the word ‘She’ is processed, its representation is modified due to its relation with the other words, most prominently the word ‘woman’.

With each token undergoing this procedure, every token ends up encoding insights regarding the context. These insights can be utilized in the following manner: if the context contains a question, then the next token is a special token (let’s call it STRT), telling the model that it needs to fill in the blank. The model repeats the process explained above, with each token donating its attention to STRT, which ends up as a vector pointing in a new direction. This new direction is near some real text token, and that token (for example) is then selected to be the next token in the model’s answer to the question.

In the following section we will go over the process at the heart of the transformer — the attention head. For brevity, we will not discuss token embedding mechanisms, positional encoding, and the encoder-decoder structure.

Transformer Structure And Flow

Fig. 1 — A vector representation of a token, and the Q, K and V matrices. Image by the author.

At the heart of the transformer architecture is the attention head. This block receives a vector representing the input token, and learns three matrices: Q (for Query), K (for Key), and V (for Value). See Fig. 1. The names of these matrices will become clearer once we see how they are used:

Let’s assume that the token vector size is N×1, and that the Q, K and V matrix dimensions are k×N, k×N and N×N, respectively.

The attention mechanism consists of these seven steps (with minor omissions); a minimal code sketch follows the list:

  1. The Q matrix operates on the token vector to create a new k×1 vector, called the query.
  2. The K matrix operates on all the other token vectors in the context, creating a k×1 vector (called a key) for each token.
  3. The inner product (which is a scalar) is calculated between the token’s query vector and each of the other tokens’ keys.
  4. The inner products are softmaxed, so that they are all positive and sum up to 1. These values now represent the weight of the attention between each token and the token currently being processed.
  5. The V matrix operates on each token to create an N×1 vector (the value vector). Each value vector is multiplied by its corresponding weight, and the results are summed to create a weighted average value.
  6. The weighted average value is summed with the original (pre-attention) token, normalized, and fed to a feed-forward network (FFN, which is just a dense layer).
  7. The pre-FFN vector is added to the post-FFN vector and the result is normalized.
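Here is a minimal NumPy sketch of a single attention head following the steps above. The variable names and toy dimensions are my own illustration, not a reference implementation; the division of the scores by √k is one of the “minor omissions” above and comes from the original paper, and the per-layer normalization and FFN of steps 6-7 are left out.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(tokens, Q, K, V):
    """Single self-attention head, following the steps in the list above.

    tokens: (T, N) array -- T token vectors of length N.
    Q, K:   learned matrices of shape (k, N).
    V:      learned matrix of shape (N, N).
    Returns the post-attention token vectors, shape (T, N).
    """
    queries = tokens @ Q.T                   # step 1: one k-dim query per token
    keys    = tokens @ K.T                   # step 2: one k-dim key per token
    scores  = queries @ keys.T               # step 3: (T, T) matrix of inner products
    scores  = scores / np.sqrt(Q.shape[0])   # scaling from the original paper
    weights = softmax(scores, axis=-1)       # step 4: rows are positive, sum to 1
    values  = tokens @ V.T                   # step 5a: one N-dim value per token
    weighted = weights @ values              # step 5b: weighted average of the values
    return tokens + weighted                 # step 6: residual sum (normalization
                                             # and the FFN are omitted here)

# Toy usage with made-up sizes: 5 tokens, N = 8, k = 4.
rng = np.random.default_rng(0)
T, N, k = 5, 8, 4
out = attention_head(rng.normal(size=(T, N)),
                     rng.normal(size=(k, N)),
                     rng.normal(size=(k, N)),
                     rng.normal(size=(N, N)))
print(out.shape)  # (5, 8)
```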

Typically, an attention layer will include multi-head attention blocks: these are parallel triplets of Q, K, V matrices that operate simultaneously and give rise to different alterations of the original token, analogous to the different filters in a CNN. The results of the multiple heads are concatenated and passed through an FFN to bring them back to the required vector length for the next layer.
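A hedged sketch of the multi-head variant, reusing the `attention_head` function from the previous snippet; a single learned projection `W_out` stands in here for the FFN mentioned above, and all names and shapes are illustrative assumptions.

```python
import numpy as np

def multi_head_attention(tokens, heads, W_out):
    """Run several independent attention heads and merge their outputs.

    tokens: (T, N) array of token vectors.
    heads:  list of (Q, K, V) triplets, one per head (shapes as above).
    W_out:  (N, H*N) projection standing in for the FFN that brings the
            concatenated head outputs back to length N.
    """
    head_outputs = [attention_head(tokens, Q, K, V) for (Q, K, V) in heads]
    concat = np.concatenate(head_outputs, axis=-1)   # shape (T, H*N)
    return concat @ W_out.T                          # shape (T, N)
```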

Fig. 2 — Full Transformer structure. By Yuening Jia — DOI:10.1088/1742-6596/1314/1/012186, CC BY-SA 3.0, https://commons.wikimedia.org/w/index.php?curid=121340680

Summary

  1. Transformers use a mechanism called self-attention to process tokens and predict the next token in a context.
  2. The query, key and value vectors are generated by multiplying the learned matrices (Q, K and V) by the input token vector.
  3. The role of the query and key vectors is to give the proper weight to the relevance of each token to the token currently being modified.
  4. The role of the value vector is to shift the token’s vector representation in the value vector’s new direction. The amount of shift is determined both by the size of the value vector and by its corresponding weight.
  5. A model can learn to predict the next token by using a special token as an input, passing it through the attention process along with the preceding context, and feeding the output vector into a simple dense layer with a softmax activation that learns the probability of each token being the next one (see the sketch after this list).
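As a small illustration of point 5, here is a hedged sketch of such a prediction head; `W_vocab` and `b_vocab` are hypothetical learned parameters, and in a real model this layer is trained jointly with the attention blocks.

```python
import numpy as np

def next_token_distribution(output_vector, W_vocab, b_vocab):
    """Dense layer followed by a softmax over the vocabulary.

    output_vector:    attention output for the special token, shape (N,).
    W_vocab, b_vocab: hypothetical learned parameters of shapes
                      (vocab_size, N) and (vocab_size,).
    Returns a probability for every token in the vocabulary.
    """
    logits = W_vocab @ output_vector + b_vocab  # the dense layer
    logits = logits - logits.max()              # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum()                  # softmax: next-token probabilities
```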

I hope this explanation was clear! You are welcome to send me your questions and to read some of my other posts:

  1. Predicting Drug Resistance in Mycobacterium Tuberculosis Using a Convolutional Network
  2. Generative Adversarial Networks (GANs), Explained And Demonstrated
  3. YOLOv3 Explained

