In the simple attention mechanism, the attention weights are computed deterministically from the input context. We call the combination of the context-free embedding (e.g. word2vec) and the positional embedding the input embedding. The problem we would like to address is that, given only the input embedding, the output embedding has no information about the data distribution that the input token is part of outside the current context. Let's look at the grammatical structure of the following sentences:
“I love bears”
Subject: I
Verb: love
Object: bears
“She bears the pain”
Subject: She
Verb: bears
Object: the pain
“Bears won the game”
Subject: Bears
Verb: won
Object: the game
Each sentence above follows a subject-verb-object structure, where the subject performs the action expressed by the verb on the object. Notice that "bears" plays a different function in each sentence: as an object, as a verb, and as a subject. In addition, there are other, more complicated language patterns, such as this one:
“The hiking trail led us through bear country.”
Subject: “The hiking trail”
Verb: “led”
Object: “us”
Prepositional Phrase: “through bear country”
Preposition: “through”
Object of the Preposition: “bear country”
Noun: “country”
Adjective: “bear” (modifying “country”)
Here, “bear” serves as an adjective describing the type of country.

The central question one can now pose is this: since the location of a token seems highly correlated with its function, as captured by the language patterns above, how can we help the embedding mapping become location-aware and function-aware?
This is done by projecting the input embeddings using a projection matrix $W$, obtaining a new coordinate system whose axes can be associated with roles such as object, verb, subject, adjective, etc. The semantic meaning of such a new coordinate system does not need to be explicitly defined, but it helps to understand what this projection may achieve.

We define three such trainable matrices $W_q, W_k, W_v$ and ask each token (or, equivalently, the matrix $X$) to project itself into the three different spaces:

$$Q = X W_q, \qquad K = X W_k, \qquad V = X W_v$$

where $X$ is the input embedding matrix of size $T \times d$, $T$ is the input sequence length and $d$ is the embedding dimension. $Q, K, V$ are the queries, keys and values respectively, and the dimensions of the $W_q, W_k, W_v$ matrices are $d \times d_q$, $d \times d_k$ and $d \times d_v$ respectively. Queries and keys typically occupy the same dimensional space ($d_q = d_k$).

Each token thus takes on three roles; let us focus here on the first two:

In a query role, the current token effectively seeks to find other functions, e.g. 'adjective'.
In a key role, the current token effectively expresses its own function, e.g. 'noun'. For example, as a key a token may declare 'I am a noun', while as a query it may ask 'I am seeking an earlier adjective'. The premise is that, after training, the attention mechanism will be able to reveal the keys of the input context that can best respond to the query.

Let us now recall what we already saw during the word2vec construction: we trained a network that took one-hot vectors of semantically similar tokens, which were orthonormal, and projected them to vectors that are close to each other in the embedding space. So we have seen evidence that a projection with proper weights can produce all sorts of interesting mappings from a high-dimensional space to a lower-dimensional one. By analogy, the multiplication of the matrix $W_q$ with the input token embedding will create a vector (a point) in the $d_k$-dimensional space that represents the query. Similarly, the multiplication of the matrix $W_k$ with each and every input token embedding will create vectors (points) in the $d_k$-dimensional space that represent the keys. After training, the keys that can best respond to the query will end up close to it.
Let's see an example of projecting the input embeddings into the query and key spaces for a given input context.
Given the matrices $X$ and $W^{(q)}$ (the key parameter matrix is analogous), where

$$X = \begin{bmatrix} x_{1,1} & x_{1,2} & \cdots & x_{1,8} \\ x_{2,1} & x_{2,2} & \cdots & x_{2,8} \\ \vdots & \vdots & \ddots & \vdots \\ x_{16,1} & x_{16,2} & \cdots & x_{16,8} \end{bmatrix}, \qquad W^{(q)} = \begin{bmatrix} w_{1,1} & w_{1,2} & w_{1,3} & w_{1,4} \\ w_{2,1} & w_{2,2} & w_{2,3} & w_{2,4} \\ \vdots & \vdots & \vdots & \vdots \\ w_{8,1} & w_{8,2} & w_{8,3} & w_{8,4} \end{bmatrix}$$

the resulting matrix $Q = X W^{(q)}$ will have dimensions $16 \times 4$. Each element $q_{i,j}$ of $Q$ is computed as

$$q_{i,j} = \sum_{k=1}^{8} x_{i,k}\, w_{k,j} = x_{i,1} w_{1,j} + x_{i,2} w_{2,j} + \dots + x_{i,8} w_{8,j}$$

Notice something that was not obvious earlier: the matrix $W^{(q)}$ allows some features of the input embedding (across the $d$ dimensions) to be weighted differently than others when the query and key vectors are calculated.

Let us now proceed with the evolved dot product, which is now computed in the query-key space where $d_q = d_k$. Consider again "The hiking trail led us through bear country.", where $T = 8$. Now that we have projected the tokens into their new space, we can form the generalized dot product

$$(W_q x_i)^T (W_k x_j) = x_i^T W_q^T W_k x_j = x_i^T W x_j$$

Geometrically you can visualize this as shown below. After training, the keys that can contribute the most to changing the query will end up close to it; the actual change of the query is done by the values. The scores are then given in matrix form by

$$S = Q K^T$$
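To make the shapes concrete, here is a small numpy sketch that follows the sizes of the example above ($T = 16$, $d = 8$, $d_q = d_k = 4$). The entries are random placeholders rather than trained weights; the variable names are chosen for illustration only.

```python
import numpy as np

T, d, d_k = 16, 8, 4            # sequence length, embedding dim, query/key dim (as in the example above)

X   = np.random.randn(T, d)     # input embeddings, one row per token
W_q = np.random.randn(d, d_k)   # trainable query projection (random here, learned in practice)
W_k = np.random.randn(d, d_k)   # trainable key projection

Q = X @ W_q                     # (16, 4) queries
K = X @ W_k                     # (16, 4) keys
S = Q @ K.T                     # (16, 16) raw attention scores

print(Q.shape, K.shape, S.shape)
```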
We divide the scores by the square root of the dimension of the key vectors, $\sqrt{d_k}$. This is done to prevent the softmax from saturating on the elements with the highest attention scores and severely attenuating the attention weights that correspond to the lower attention scores. We can run a small experiment to see this behavior of the softmax.
```python
import numpy as np

# Create an 8-element numpy vector with random Gaussian values
vector = np.random.randn(8)

# Softmax function
def softmax(x):
    e_x = np.exp(x - np.max(x))  # stability improvement by subtracting the max
    return e_x / e_x.sum()

# Apply softmax to the vector
softmax_vector = softmax(vector)
softmax_vector
```
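Continuing the experiment, we can compare the softmax of the original vector with the softmax of the same vector scaled by 100. This comparison is an illustrative addition that reuses the `vector` and `softmax` defined in the block above.

```python
# Reuses `vector` and `softmax` from the block above.
# The scaled copy saturates: almost all of the mass goes to the largest element.
print(softmax(vector))
print(softmax(vector * 100))
```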
Let's compare the two results: the first case is the softmax of the original vector and the second case is the softmax of the original vector element-wise multiplied by 100. When you multiply the attention scores by 100 and then pass them through a softmax, you will see that the softmax outputs a vector of values that are either very close to 0 or very close to 1. The division by $\sqrt{d_k}$ prevents this behavior.

The code for the scaled dot-product attention is shown below.
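The following is a minimal sketch of scaled dot-product attention in numpy. The function name, the optional `mask` argument, and the use of numpy rather than a deep-learning framework are illustrative choices; the `mask` argument is used in the decoding discussion that follows.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V, mask=None):
    """Return the updated embeddings (T x d_v) and the attention weights (T x T).

    Rows of Q index the queries; rows of K and V index the keys/values.
    """
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                   # scaled scores, shape (T, T)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)      # masked positions contribute nothing
    # softmax over the keys accessible to each query (along each row in this layout)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights
```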
When we decode, we do not want to use the attention scores of future tokens, since we do not want to train the transformer using ground truth that simply won't be available during inference. To prevent this, we mask the attention scores of the future tokens by setting them to $-\infty$ before passing them through the softmax. This causes the softmax to output values that are effectively 0 for the future tokens: for query $q_i$, only keys $k_1$ through $k_i$ are accessible, and all keys $k_j$ with $j > i$ are masked (i.e., set to $-\infty$).

$$\begin{array}{c|cccccccc}
 & q_1 & q_2 & q_3 & q_4 & q_5 & q_6 & q_7 & q_8 \\ \hline
k_1 & q_1 k_1 & q_2 k_1 & q_3 k_1 & q_4 k_1 & q_5 k_1 & q_6 k_1 & q_7 k_1 & q_8 k_1 \\
k_2 & -\infty & q_2 k_2 & q_3 k_2 & q_4 k_2 & q_5 k_2 & q_6 k_2 & q_7 k_2 & q_8 k_2 \\
k_3 & -\infty & -\infty & q_3 k_3 & q_4 k_3 & q_5 k_3 & q_6 k_3 & q_7 k_3 & q_8 k_3 \\
k_4 & -\infty & -\infty & -\infty & q_4 k_4 & q_5 k_4 & q_6 k_4 & q_7 k_4 & q_8 k_4 \\
k_5 & -\infty & -\infty & -\infty & -\infty & q_5 k_5 & q_6 k_5 & q_7 k_5 & q_8 k_5 \\
k_6 & -\infty & -\infty & -\infty & -\infty & -\infty & q_6 k_6 & q_7 k_6 & q_8 k_6 \\
k_7 & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & q_7 k_7 & q_8 k_7 \\
k_8 & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & -\infty & q_8 k_8
\end{array}$$
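As an illustration of the causal mask, the sketch below reuses the `scaled_dot_product_attention` function defined above on random tensors. Note that in this code the rows index the queries, so the masked entries appear above the diagonal rather than below it as in the table above.

```python
T, d_k, d_v = 8, 4, 4
Q = np.random.randn(T, d_k)
K = np.random.randn(T, d_k)
V = np.random.randn(T, d_v)

# True where query i may attend to key j, i.e. j <= i
causal_mask = np.tril(np.ones((T, T), dtype=bool))

out, weights = scaled_dot_product_attention(Q, K, V, mask=causal_mask)
print(np.round(weights, 2))   # each row sums to 1; entries above the diagonal are 0
```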
The dot-product terms will be positive or negative numbers and, as in the simpler version of the attention mechanism, we pass them through a softmax function, normalizing over the keys accessible to each query (column-wise in the matrix above), to obtain the attention weights $\alpha_{ij}$ for each of the tokens.
We then use the attention weights to create a weighted sum of the values and obtain the output embedding:

$$\hat{v}_i = \sum_{j=1}^{T} \alpha_{ij} v_j$$

where $\alpha_{ij}$ is the attention weight of the $j$-th token of the input sequence for the $i$-th value, and $T$ is the input sequence length.

What purpose do the values serve, though, and why the $W_v$ matrix? The values are the actual information that the input token will use to update its embedding. The $W_v$ matrix is used to project the input tokens to values (points) in a $d_v$-dimensional space. There is no reason for the dimensionality of the value space to be the same as that of the key space, but typically they are the same. We use the value projection ($V$) as a way to decouple the determination of the attention weights from the actual adjustment of the token embeddings. As an example, in the context "The hiking trail led us through bear country", if the key represents the adjective role of an input token that responded to a noun's query, the value represents the specific adjective (bear) that adjusts the specific noun (country) and makes it a "bear country" vector.

Note that masking is not shown in this figure. Also, the vector subspaces maintain the same dimensions throughout.

Closing, the overall equation for scaled self-attention can be formulated as

$$\hat{X} = \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

and the output has dimension $T \times d_v$, where $T$ is the number of tokens in the input sequence and $d_v$ is the dimension of the value vectors. The dimensions of the tensors can also be extended to accommodate the batch dimension.
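As a quick dimension check, the snippet below reuses the `scaled_dot_product_attention` sketch from earlier with a value dimension deliberately different from the key dimension; the specific sizes are arbitrary choices for illustration.

```python
T, d_k, d_v = 8, 4, 6                       # d_v need not equal d_k
Q = np.random.randn(T, d_k)
K = np.random.randn(T, d_k)
V = np.random.randn(T, d_v)

X_hat, alpha = scaled_dot_product_attention(Q, K, V)
print(X_hat.shape)                          # (8, 6): one d_v-dimensional output per token

# Row i of X_hat is the weighted sum over the values: sum_j alpha[i, j] * V[j]
assert np.allclose(X_hat[2], (alpha[2][:, None] * V).sum(axis=0))
```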
It is worth highlighting the difference between a dense layer and the self-attention mechanism. In a dense layer, if the layer learns to apply, say, a very small weight to an input feature, it does so for all data that are fed into the layer. In the self-attention mechanism, if one of the attention weights is very small, this is purely due to the specific data present at that moment.
An example of the output of the scaled dot-product self-attention is shown below using the bertviz library.
```python
from bertviz import head_view
from transformers import AutoModel, AutoTokenizer

# model_ckpt and tokenizer are defined earlier in these notes; a typical choice would be:
model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

model = AutoModel.from_pretrained(model_ckpt, output_attentions=True)

sentence_a = "time flies like an arrow"
sentence_b = "fruit flies like a banana"

viz_inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
attention = model(**viz_inputs).attentions
sentence_b_start = (viz_inputs.token_type_ids == 0).sum(dim=1)
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])

head_view(attention, tokens, sentence_b_start, heads=[8])
```
Links between the tokens in the input sequence. This visualization shows the attention weights as lines connecting the token whose embedding is getting updated (left) with every word that is being attended to (right). The intensity of the lines indicates the strength of the attention weights, with dark lines representing values close to 1 and faint lines representing values close to 0. The end result is that the token 'flies' will receive a context akin to 'soars' in one sentence and a context akin to 'insect' in the other.

Output of bertviz showing the attention weights of the token 'flies' in two different sentences. The left side shows the token being updated, while the right side shows the tokens it attends to.