Notations#
Dimensions and Indexing#
This section outlines the common dimensions and indexing conventions utilized in the Transformer model.
\(\mathcal{B}\): The minibatch size (several shapes below write it as \(B\)).
\(D\): Embedding dimension. In the original Transformer paper, this is represented as \(d_{\text{model}}\).
\(d\): Index within the embedding vector, where \(0 \leq d < D\).
\(T\): Sequence length.
\(t\): Positional index of a token within the sequence, where \(0 \leq t < T\).
\(V\): Size of the vocabulary.
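\(j\): Index of a word in the vocabulary, where \(0 \leq j < V\).
Note: the ranges above use zero-based indexing, as in Python. Several sections below index positions and vocabulary entries from 1 instead (e.g. \(1 \leq t \leq T\) and \(1 \leq j \leq V\)); the meaning is the same up to an offset of one.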
General Notations#
Elementwise and Vectorwise Operations#
Element-wise operations like dropout or activation functions are applied to each element of a tensor independently. For example, applying the ReLU activation function to a tensor \(\mathbf{X} \in \mathbb{R}^{B \times T \times D}\) results in a tensor of the same shape, where the ReLU function is applied to each element of \(\mathbf{X}\) independently (i.e. you can think of it as applying the ReLU a total of \(B \times T \times D\) times).
For vector-wise operations, the operation is applied to each vector along a specific dimension or axis of the tensor. For example, applying layer normalization to a tensor \(\mathbf{X} \in \mathbb{R}^{B \times T \times D}\) will apply the normalization operation to each vector along the feature dimension \(D\) independently. This means that the normalization operation is applied to each vector of size \(D\) independently across all batches and sequence positions. You can then think of the normalization operation as being applied a total of \(B \times T\) times.
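A minimal sketch of the two cases in plain Python, assuming a small nested-list tensor of shape \(B \times T \times D\). Mean-centering (the first step of layer normalization) stands in for a generic vector-wise operation, and the call counters are only there to make the \(B \times T \times D\) versus \(B \times T\) counts visible.

```python
B, T, D = 2, 3, 4
# A small (B, T, D) tensor stored as nested lists.
X = [[[float(b + t - d) for d in range(D)] for t in range(T)] for b in range(B)]

calls = {"relu": 0, "center": 0}

def relu(x: float) -> float:
    """Element-wise: sees one scalar at a time."""
    calls["relu"] += 1
    return max(0.0, x)

def center(vec: list[float]) -> list[float]:
    """Vector-wise: sees one length-D feature vector at a time."""
    calls["center"] += 1
    mean = sum(vec) / len(vec)
    return [v - mean for v in vec]

relu_out = [[[relu(v) for v in row] for row in seq] for seq in X]
center_out = [[center(row) for row in seq] for seq in X]

print(calls)  # {'relu': 24, 'center': 6}
```

Both outputs keep the \(B \times T \times D\) shape; only the granularity at which the function is called differs.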
Vocabulary#
\(\mathcal{V}\): The set of all unique words in the vocabulary, defined as:
\[ \mathcal{V} = \{w_1, w_2, \ldots, w_V\} \]
where
\(V\) (denoted as \(|\mathcal{V}|\)): The size of the vocabulary.
\(w_j\): A unique word in the vocabulary \(\mathcal{V}\), where \(w_j \in \mathcal{V}\).
\(j\): The index of a word in \(\mathcal{V}\), explicitly defined as \(1 \leq j \leq V\).
For example, consider the following sentences in the training set:
“cat eat mouse”
“dog chase cat”
“mouse eat cheese”
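The resulting vocabulary \(\mathcal{V}\) is:
\[ \mathcal{V} = \{\text{cat}, \text{eat}, \text{mouse}, \text{dog}, \text{chase}, \text{cheese}\} \]
where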
\(V = 6\).
\(w_1 = \text{cat}, w_2 = \text{eat}, w_3 = \text{mouse}, w_4 = \text{dog}, w_5 = \text{chase}, w_6 = \text{cheese}\).
\(j = 1, 2, \ldots, 6\).
Note: Depending on the transformer model, special tokens (e.g., [PAD], [CLS], [BOS], [EOS], [UNK], etc.) may also be included in \(\mathcal{V}\).
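The example vocabulary can be reproduced in a few lines of Python. This sketch assumes whitespace tokenization and assigns indices in order of first appearance, which is one common convention (others sort alphabetically or by frequency); `build_vocab` is a hypothetical helper, not a library function.

```python
def build_vocab(sentences: list[str]) -> list[str]:
    """Collect unique whitespace-separated words in order of first appearance."""
    vocab: list[str] = []
    seen: set[str] = set()
    for sentence in sentences:
        for word in sentence.split():
            if word not in seen:
                seen.add(word)
                vocab.append(word)
    return vocab

corpus = ["cat eat mouse", "dog chase cat", "mouse eat cheese"]
vocab = build_vocab(corpus)
print(vocab)       # ['cat', 'eat', 'mouse', 'dog', 'chase', 'cheese']
print(len(vocab))  # 6, i.e. V = 6
```

Python's `vocab[0]` corresponds to \(w_1\), `vocab[1]` to \(w_2\), and so on.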
Input Sequence#
The input sequence \(\mathbf{x}\) for a GPT model is defined as a sequence of \(T\) tokens. Each token in this sequence is typically represented as an integer that corresponds to a position in the vocabulary set \(\mathcal{V}\). The sequence is represented as:
\[ \mathbf{x} = (x_1, x_2, \ldots, x_T) \]
where
\(T\): Total length of the sequence. It denotes the number of tokens in the sequence \(\mathbf{x}\).
\(x_t\): Represents a token at position \(t\) within the sequence. Each token \(x_t\) is an integer where \(0 \leq x_t < V\). Here, \(V\) is the size of the vocabulary, and each integer corresponds to a unique word or symbol in \(\mathcal{V}\).
\(t\): The index of a token within the sequence \(\mathbf{x}\), where \(1 \leq t \leq T\).
Batched Input Sequences#
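In practice, GPT models are often trained on batches of sequences to improve computational efficiency. A batched input is represented as \(\mathbf{x}^{\mathcal{B}}\), where \(\mathcal{B}\) denotes the batch size. The batched input \(\mathbf{x}^{\mathcal{B}}\) can be visualized as a matrix:
\[ \mathbf{x}^{\mathcal{B}} = \begin{bmatrix} x_{1,1} & x_{1,2} & \cdots & x_{1,T} \\ x_{2,1} & x_{2,2} & \cdots & x_{2,T} \\ \vdots & \vdots & \ddots & \vdots \\ x_{\mathcal{B},1} & x_{\mathcal{B},2} & \cdots & x_{\mathcal{B},T} \end{bmatrix} \]
In this matrix: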
Each row corresponds to a sequence in the batch.
Each column corresponds to a token position across all sequences in the batch.
\(x_{b,t}\) refers to the token at position \(t\) in sequence \(b\), with \(1 \leq b \leq \mathcal{B}\) and \(1 \leq t \leq T\).
\(T\) and \(V\) are as defined previously.
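A sketch of how such a batch matrix could be built, assuming the six-word vocabulary from the earlier example with zero-based indices, and sequences that already share the same length \(T\) (real pipelines pad or truncate to achieve this). Python's zero-based `batch[b][t]` corresponds to \(x_{b+1,t+1}\) in the one-based notation above.

```python
vocab = ["cat", "eat", "mouse", "dog", "chase", "cheese"]
token_to_index = {word: i for i, word in enumerate(vocab)}

sentences = ["cat eat mouse", "dog chase cat", "mouse eat cheese"]
# Each row is one sequence; each column is one token position.
batch = [[token_to_index[word] for word in s.split()] for s in sentences]

num_sequences, seq_len = len(batch), len(batch[0])
assert all(len(row) == seq_len for row in batch), "sequences must share length T"
print(batch)                     # [[0, 1, 2], [3, 4, 0], [2, 1, 5]]
print((num_sequences, seq_len))  # (3, 3)
```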
Token to Index, and Index to Token Mappings#
String-to-Index Mapping#
Function: \(f_{\text{stoi}}\)
Domain: \(\mathcal{V}\), the set of all tokens in the vocabulary.
Codomain: \(\{0, 1, \ldots, V-1\}\), where \(V\) is the size of the vocabulary.
Purpose: This function maps each token (word) from the vocabulary to a unique index. For a token \(w \in \mathcal{V}\), the value \(f_{\text{stoi}}(w) = j\) indicates that the token \(w\) corresponds to the \(j\)-th position in the vocabulary \(\mathcal{V}\).
Example: If \(\mathcal{V} = \{\text{cat}, \text{dog}, \text{mouse}\}\) and \(V = 3\), then \(f_{\text{stoi}}(\text{cat}) = 0\), \(f_{\text{stoi}}(\text{dog}) = 1\), and \(f_{\text{stoi}}(\text{mouse}) = 2\).
Index-to-String Mapping#
Function: \(f_{\text{itos}}\)
Domain: \(\{0, 1, \ldots, V-1\}\)
Codomain: \(\mathcal{V}\), the set of all tokens in the vocabulary.
Purpose: This function maps each index back to its corresponding token (word) in the vocabulary. For an index \(j\), the value \(f_{\text{itos}}(j) = w\) indicates that the index \(j\) corresponds to the token \(w\) in the vocabulary \(\mathcal{V}\).
Example: Continuing the previous example, \(f_{\text{itos}}(0) = \text{cat}\), \(f_{\text{itos}}(1) = \text{dog}\), and \(f_{\text{itos}}(2) = \text{mouse}\).
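In Python, both mappings are naturally dictionaries. The sketch below uses the three-word example; the names `stoi` and `itos` mirror the notation, and `encode`/`decode` are hypothetical helpers. Because the two mappings are inverses, decoding an encoded sequence returns the original tokens.

```python
vocab = ["cat", "dog", "mouse"]

# f_stoi: token -> index, and its inverse f_itos: index -> token.
stoi = {token: index for index, token in enumerate(vocab)}
itos = {index: token for token, index in stoi.items()}

def encode(tokens: list[str]) -> list[int]:
    return [stoi[token] for token in tokens]

def decode(indices: list[int]) -> list[str]:
    return [itos[index] for index in indices]

print(stoi)                      # {'cat': 0, 'dog': 1, 'mouse': 2}
print(itos[2])                   # mouse
print(encode(["mouse", "dog"]))  # [2, 1]
```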
One-Hot Representation of Input Sequence \(\mathbf{x}\)#
The one-hot representation of the input sequence \(\mathbf{x}\) is denoted as \(\mathbf{X}^{\text{ohe}}\). This representation converts each token in the sequence to a one-hot encoded vector, where each vector has a length equal to the size of the vocabulary \(V\).
Definition#
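The one-hot encoded matrix \(\mathbf{X}^{\text{ohe}}\) is defined as:
\[ \mathbf{X}^{\text{ohe}} = \begin{bmatrix} o_{1,1} & o_{1,2} & \cdots & o_{1,V} \\ o_{2,1} & o_{2,2} & \cdots & o_{2,V} \\ \vdots & \vdots & \ddots & \vdots \\ o_{T,1} & o_{T,2} & \cdots & o_{T,V} \end{bmatrix} \]
where: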
\(T\): Total length of the sequence \(\mathbf{x}\).
\(V\): Size of the vocabulary \(\mathcal{V}\).
\(o_{t,j}\): Element of the one-hot encoded matrix \(\mathbf{X}^{\text{ohe}}\) at row \(t\) and column \(j\).
In addition, we have:
\(\mathbf{X}^{\text{ohe}}\) is a \(T \times V\) matrix.
Elements of \(\mathbf{X}^{\text{ohe}}\) are binary, i.e., they belong to \(\{0, 1\}\).
The row vector \(\mathbf{o}_{t, :}\) represents the one-hot encoded vector for the token at position \(t\) in the sequence \(\mathbf{x}\).
One-Hot Encoding Process#
For each token \(x_t\) at position \(t\) in the sequence \(\mathbf{x}\) (\(1 \leq t \leq T\)), the corresponding row vector \(\mathbf{o}_{t, :}\) in \(\mathbf{X}^{\text{ohe}}\) is defined as:
\[ o_{t,j} = \begin{cases} 1 & \text{if } f_{\text{stoi}}(x_t) = j - 1 \\ 0 & \text{otherwise} \end{cases} \]
for \(j = 1, 2, \ldots, V\).
Here, \(f_{\text{stoi}}(x_t)\) maps the token \(x_t\) to its index \(j-1\) in the vocabulary \(\mathcal{V}\). The offset of one arises because \(f_{\text{stoi}}\) uses zero-based indexing, as in Python (\(0 \leq j-1 < V\)), while the columns of \(\mathbf{X}^{\text{ohe}}\) are numbered from 1. Each row \(\mathbf{o}_{t, :}\) in \(\mathbf{X}^{\text{ohe}}\) contains a single ‘1’ at the column \(j\) corresponding to the vocabulary index of \(x_t\), and ‘0’s elsewhere.
(Example)
For example, if the vocabulary \(\mathcal{V} = \{\text{cat}, \text{dog}, \text{mouse}\}\) and the sequence \(\mathbf{x} = (\text{mouse}, \text{dog})\), then the one-hot encoded matrix \(\mathbf{X}^{\text{ohe}}\) will be:
\[ \mathbf{X}^{\text{ohe}} = \begin{bmatrix} 0 & 0 & 1 \\ 0 & 1 & 0 \end{bmatrix} \]
In this example:
The sequence length \(T = 2\).
The vocabulary size \(V = 3\).
“mouse” corresponds to the third position in the vocabulary, and “dog” to the second, which is seen in their respective one-hot vectors.
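The encoding rule translates directly into code. This is a minimal pure-Python sketch (a framework would typically use a built-in such as PyTorch's `torch.nn.functional.one_hot`); `one_hot_encode` is a hypothetical helper that reproduces the example matrix.

```python
def one_hot_encode(sequence: list[str], stoi: dict[str, int]) -> list[list[int]]:
    """Return the T x V one-hot matrix for a token sequence."""
    V = len(stoi)
    X_ohe = []
    for token in sequence:
        row = [0] * V
        row[stoi[token]] = 1  # zero-based column j - 1 holds the single 1
        X_ohe.append(row)
    return X_ohe

stoi = {"cat": 0, "dog": 1, "mouse": 2}
X_ohe = one_hot_encode(["mouse", "dog"], stoi)
print(X_ohe)  # [[0, 0, 1], [0, 1, 0]]
```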
Batched#
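The batched one-hot encoded matrix \(\mathbf{X}^{\text{ohe, }\mathcal{B}}\) for a batch of size \(\mathcal{B}\) is defined as a three-dimensional tensor, where each “slice” (or matrix) along the first dimension corresponds to the one-hot encoded representation of a sequence in the batch:
\[ \mathbf{X}^{\text{ohe, }\mathcal{B}} = \left[\mathbf{X}^{\text{ohe}}_1, \mathbf{X}^{\text{ohe}}_2, \ldots, \mathbf{X}^{\text{ohe}}_{\mathcal{B}}\right] \in \{0, 1\}^{\mathcal{B} \times T \times V} \]
Here: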
\(\mathcal{B}\): Batch size, the number of sequences processed together.
\(T\): Length of each sequence, assumed uniform across the batch.
\(V\): Size of the vocabulary.
\(\mathbf{X}^{\text{ohe}}_b\): One-hot encoded matrix of the \(b^{th}\) sequence in the batch.
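Extending the single-sequence case to a batch just stacks one \(T \times V\) matrix per sequence. The sketch below starts from a hypothetical \(\mathcal{B} \times T\) matrix of token indices over the three-word vocabulary; `one_hot_batch` is an illustrative helper name.

```python
def one_hot_batch(batch_indices: list[list[int]], V: int) -> list[list[list[int]]]:
    """Map a B x T matrix of token indices to a B x T x V one-hot tensor."""
    return [
        [[1 if j == index else 0 for j in range(V)] for index in sequence]
        for sequence in batch_indices
    ]

# Two sequences of length T = 2 over the vocabulary {cat: 0, dog: 1, mouse: 2}.
batch_indices = [[2, 1], [0, 0]]
X_ohe_B = one_hot_batch(batch_indices, V=3)
shape = (len(X_ohe_B), len(X_ohe_B[0]), len(X_ohe_B[0][0]))
print(shape)       # (2, 2, 3)
print(X_ohe_B[0])  # [[0, 0, 1], [0, 1, 0]], the "mouse dog" example
```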
Weights And Embeddings#
Matrix Multiplication Primer#
See source.
If \(A=(a_{ij})\in M_{mn}(\Bbb F), B=(b_{ij})\in M_{np}(\Bbb F)\) then \(C=A\times B=(c_{ij})\in M_{mp}(\Bbb F)\). \(c_{ij}=\sum_{k=1}^{n} a_{ik}b_{kj}\) where \(i=1,...m, j=1,...p\)
Let’s take a look at one specific element in the product \(C=AB\), namely the element on position \((i,j)\), i.e. in the \(i\)th row and \(j\)th column.
To obtain this element, you:
first multiply all elements of the \(i\)th row of the matrix \(A\) pairwise with all the elements of the \(j\)th column of the matrix \(B\);
and then you add these \(n\) products.
You have to repeat this procedure for every element of \(C\), but let’s zoom in on that one specific (but arbitrary) element on position \((i,j)\) for now:
\[ \begin{bmatrix} a_{i1} & a_{i2} & \cdots & a_{in} \end{bmatrix} \begin{bmatrix} b_{1j} \\ b_{2j} \\ \vdots \\ b_{nj} \end{bmatrix} = \color{purple}{\mathbf{c_{ij}}} \]
with element \(\color{purple}{\mathbf{c_{ij}}}\) equal to:
\[ \color{purple}{\mathbf{c_{ij}}} = a_{i1}b_{1j} + a_{i2}b_{2j} + \cdots + a_{in}b_{nj} \]
Now notice that in the sum above, the left outer index is always \(i\) (\(i\)th row of \(A\)) and the right outer index is always \(j\) (\(j\)th column of \(B\)). The inner indices run from \(1\) to \(n\), so you can introduce a summation index \(k\) and write this sum compactly using summation notation:
\[ c_{ij} = \sum_{k=1}^{n} a_{ik}b_{kj} \]
The formula above thus gives you the element on position \((i,j)\) in the product matrix \(C=AB\) and therefore completely defines \(C\) by letting \(i=1,\ldots,m\) and \(j=1,\ldots,p\).
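The formula maps onto three nested loops: two over the output position \((i, j)\) and one (here a generator inside `sum`) over the shared index \(k\). This is a plain-Python sketch with zero-based indices; real code would call an optimized routine such as NumPy's `@` operator instead.

```python
def matmul(A: list[list[float]], B: list[list[float]]) -> list[list[float]]:
    """Compute C = AB with c_ij = sum_k a_ik * b_kj."""
    m, n = len(A), len(A[0])
    if len(B) != n:
        raise ValueError(f"inner dimensions differ: {n} vs {len(B)}")
    p = len(B[0])
    C = [[0] * p for _ in range(m)]
    for i in range(m):          # i-th row of A
        for j in range(p):      # j-th column of B
            C[i][j] = sum(A[i][k] * B[k][j] for k in range(n))
    return C

A = [[1, 2, 3], [4, 5, 6]]          # 2 x 3
B = [[7, 8], [9, 10], [11, 12]]     # 3 x 2
print(matmul(A, B))  # [[58, 64], [139, 154]]
```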
\(\mathbf{X}\): Output of the Embedding Layer#
Once the one-hot representation \(\mathbf{X}^{\text{ohe}}\) is defined, we can pass it as input through our GPT model, whose first layer is an embedding lookup table. In the GPT model architecture, this first layer maps the one-hot encoded input vectors into a lower-dimensional, dense embedding space using the embedding matrix \(\mathbf{W}_e\).
| Matrix Description | Symbol | Dimensions | Description |
|---|---|---|---|
| One-Hot Encoded Input Matrix | \(\mathbf{X}^{\text{ohe}}\) | \(T \times V\) | Each row corresponds to a one-hot encoded vector representing a token in the sequence. |
| Embedding Matrix | \(\mathbf{W}_e\) | \(V \times D\) | Each row is the embedding vector of the corresponding token in the vocabulary. |
| Embedded Input Matrix | \(\mathbf{X}\) | \(T \times D\) | Each row is the embedding vector of the corresponding token in the input sequence. |
| Embedding Vector for Token \(t\) | \(\mathbf{X}_t\) | \(1 \times D\) | The embedding vector for the token at position \(t\) in the input sequence. |
| Batched Input Tensor | \(\mathbf{X}^{\mathcal{B}}\) | \(B \times T \times D\) | A batched tensor containing \(B\) input sequences, each of shape \(T \times D\). |
More concretely, we create an embedding matrix \(\mathbf{W}_{e}\) of size \(V \times D\), where \(V\) is the vocabulary size and \(D\) is the embedding dimension. Multiplying \(\mathbf{X}^{\text{ohe}}\) by \(\mathbf{W}_{e}\) then gives the output matrix \(\mathbf{X} = \mathbf{X}^{\text{ohe}} \mathbf{W}_{e} \in \mathbb{R}^{T \times D}\).
Definition#
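The embedding matrix \(\mathbf{W}_{e}\) is structured as follows:
\[ \mathbf{W}_e = \begin{bmatrix} \mathbf{w}_1 \\ \mathbf{w}_2 \\ \vdots \\ \mathbf{w}_V \end{bmatrix} = \begin{bmatrix} w_{1,1} & w_{1,2} & \cdots & w_{1,D} \\ w_{2,1} & w_{2,2} & \cdots & w_{2,D} \\ \vdots & \vdots & \ddots & \vdots \\ w_{V,1} & w_{V,2} & \cdots & w_{V,D} \end{bmatrix} \]
where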
\(\mathbf{w}_j = (w_{j,1}, w_{j,2}, \ldots, w_{j,D}) \in \mathbb{R}^{1 \times D}\):
Each row vector \(\mathbf{w}_j\) of the matrix \(\mathbf{W}_e\) represents the \(D\)-dimensional embedding vector for the \(j\)-th token in the vocabulary \(\mathcal{V}\).
The subscript \(j\) ranges from 1 to \(V\), indexing the tokens.
\(V\) is the vocabulary size.
\(D\) is the hidden embedding dimension.
Here is how each embedding vector is selected through matrix multiplication:
\[ \mathbf{X} = \mathbf{X}^{\text{ohe}} \mathbf{W}_e, \qquad \mathbf{X}_{t,:} = \mathbf{o}_{t,:} \mathbf{W}_e = \sum_{j=1}^{V} o_{t,j} \, \mathbf{w}_j = \mathbf{w}_{f_{\text{stoi}}(x_t) + 1} \]
Each row in the resulting matrix \(\mathbf{X}\) is the embedding of the corresponding token in the input sequence, picked directly from \(\mathbf{W}_e\) by the one-hot vectors. In other words, \(\mathbf{W}_e\) can be viewed as a lookup table in which row \(j\) holds the embedding vector \(\mathbf{w}_j\) of the token \(w_j\).
Lookup#
When the one-hot encoded input matrix \(\mathbf{X}^{\text{ohe}}\) multiplies with the embedding matrix \(\mathbf{W}_e\), each row of \(\mathbf{X}^{\text{ohe}}\) effectively selects a corresponding row from \(\mathbf{W}_e\). This operation simplifies to row selection because each row of \(\mathbf{X}^{\text{ohe}}\) contains exactly one ‘1’ and the rest are ‘0’s.
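A small numerical check of this equivalence, with made-up embedding values: multiplying the one-hot rows by \(\mathbf{W}_e\) gives exactly the same matrix as indexing \(\mathbf{W}_e\) by the token indices. This is why frameworks implement the embedding layer as an index lookup (for example, PyTorch's `torch.nn.Embedding`) rather than an actual matrix multiplication, which would spend most of its work multiplying by zeros.

```python
def matmul(A, B):
    """Row-by-column product; zip(*B) iterates over the columns of B."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

V, D = 3, 2
# Hypothetical embedding matrix W_e (V x D); zero-based row j - 1 holds w_j.
W_e = [[0.1, 0.2],
       [0.3, 0.4],
       [0.5, 0.6]]

token_indices = [2, 1]  # "mouse dog" under {cat: 0, dog: 1, mouse: 2}
X_ohe = [[1 if j == idx else 0 for j in range(V)] for idx in token_indices]

X_via_matmul = matmul(X_ohe, W_e)                   # X = X_ohe W_e
X_via_lookup = [W_e[idx] for idx in token_indices]  # direct row selection
print(X_via_lookup)  # [[0.5, 0.6], [0.3, 0.4]]
```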
Semantic Representation#
Now each row of the output matrix, \(\mathbf{X}_{t, :}\), is the \(D\)-dimensional embedding vector for the token \(x_t\) at the \(t\)-th position in the sequence. The output \(\mathbf{X}\) is therefore a dense representation of the sequence: each token is replaced by its corresponding embedding vector from the embedding matrix \(\mathbf{W}_{e}\). Because \(\mathbf{W}_e\) is learned, these vectors can come to carry semantic information about the tokens; after training, tokens with related meanings often end up close to each other in this embedding space.
\(PE\): Positional Encoding Layer#
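For a given input matrix \(\mathbf{X} \in \mathbb{R}^{T \times D}\), where \(T\) is the sequence length and \(D\) is the embedding dimension (denoted as \(d_{\text{model}}\) in typical Transformer literature), the positional encoding \(\operatorname{PE}\) is applied to integrate sequence positional information into the embeddings. The sinusoidal encoding produces a matrix \(\mathbf{W}_p = \operatorname{PE}(\mathbf{X}) \in \mathbb{R}^{T \times D}\) that depends only on the shape of \(\mathbf{X}\), not on its values. Each element of \(\mathbf{W}_p\), denoted as \(p_{i, j}\), is calculated with the sinusoidal function:
\[ p_{i, j} = \begin{cases} \sin\left(\dfrac{i - 1}{10000^{(j - 1)/D}}\right) & \text{if } j \text{ is odd} \\[2ex] \cos\left(\dfrac{i - 1}{10000^{(j - 2)/D}}\right) & \text{if } j \text{ is even} \end{cases} \]
for \(i = 1, \ldots, T\) and \(j = 1, \ldots, D\). With zero-based position \(i-1\) and dimension \(j-1\), this is the familiar form from the original Transformer: sine on even dimensions, cosine on odd dimensions, with each sine/cosine pair sharing one frequency.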
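A plain-Python sketch of the sinusoidal encoding in zero-based form, using the base of 10000 from the original Transformer. Because \(\mathbf{W}_p\) depends only on \(T\) and \(D\), it can be computed once and added to any embedding matrix of that shape; `sinusoidal_positional_encoding` and the constant embeddings `X` are hypothetical, for illustration only.

```python
import math

def sinusoidal_positional_encoding(T: int, D: int) -> list[list[float]]:
    """Return the T x D sinusoidal positional encoding matrix W_p.

    Zero-based: row pos is the position, column k is the dimension.
    Even k uses sin, odd k uses cos; each (k, k + 1) pair shares a frequency.
    """
    W_p = [[0.0] * D for _ in range(T)]
    for pos in range(T):
        for k in range(0, D, 2):
            angle = pos / (10000 ** (k / D))
            W_p[pos][k] = math.sin(angle)
            if k + 1 < D:  # guard for odd D
                W_p[pos][k + 1] = math.cos(angle)
    return W_p

W_p = sinusoidal_positional_encoding(T=4, D=6)
print(W_p[0])  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]: sin(0) = 0 and cos(0) = 1

# Adding the encodings to (hypothetical) token embeddings of the same shape.
X = [[0.5] * 6 for _ in range(4)]
X_tilde = [[x + p for x, p in zip(x_row, p_row)] for x_row, p_row in zip(X, W_p)]
```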
\(\tilde{\mathbf{X}}\): Output of the Positional Encoding Layer#
We can update our original embeddings tensor \(\mathbf{X}\) to include positional information:
\[ \tilde{\mathbf{X}} = \mathbf{X} + \mathbf{W}_p \]
This operation adds the positional encodings to the original embeddings, giving the final embeddings that are passed to subsequent layers in the Transformer model.
| Matrix Description | Symbol | Dimensions | Description |
|---|---|---|---|
| Positional Encoding Matrix | \(\mathbf{W}_{p}\) | \(T \times D\) | Matrix with positional encoding vectors for each position in the sequence, computed using sinusoidal functions. |
| Output of Positional Encoding Layer | \(\tilde{\mathbf{X}}\) | \(T \times D\) | The resultant embeddings matrix after adding positional encoding \(\mathbf{W}_{p}\) to the embedded input matrix \(\mathbf{X}\). Each row now includes positional information. |
| Embedding Vector for Token \(t\) | \(\tilde{\mathbf{X}}_t\) | \(1 \times D\) | The token and positional embedding vector for the token at position \(t\) in the input sequence. |
| Batched Input Tensor | \(\tilde{\mathbf{X}}^{\mathcal{B}}\) | \(B \times T \times D\) | A batched tensor containing \(B\) input sequences, each of shape \(T \times D\). |
Layer Normalization#
Layer normalization modifies the activations within each layer to have zero mean and unit variance across the features for each data point in a batch independently, which helps in stabilizing the learning process. It then applies a learnable affine transformation to each normalized activation, allowing the network to scale and shift these values where beneficial.
For a given layer with inputs \(\mathbf{Z} \in \mathbb{R}^{B \times T \times D}\) (where \(B\) is the batch size, \(T\) is the sequence length, and \(D\) is the feature dimension or hidden dimension size), the layer normalization of \(\mathbf{Z}\) is computed as follows:
Mean and Variance Calculation: Calculate the mean \(\mu_t\) and variance \(\sigma_t^2\) for each feature vector across the feature dimension \(D\):
\[ \mu_t = \frac{1}{D} \sum_{d=1}^D \mathbf{Z}_{t, d}, \quad \sigma_t^2 = \frac{1}{D} \sum_{d=1}^D (\mathbf{Z}_{t, d} - \mu_t)^2 \]
\(\mu_t\) and \(\sigma_t^2\) are scalars computed separately for every token, that is, for each of the \(B \times T\) (batch, position) pairs; the batch index is omitted from the notation for readability.
Normalization: Normalize the activations for each feature dimension:
\[ \hat{\mathbf{Z}}_{t, d} = \frac{\mathbf{Z}_{t, d} - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}} \]
where \(\epsilon\) is a small constant (e.g., \(10^{-5}\)) added for numerical stability.
Affine Transformation: Apply a learnable affine transformation to each normalized feature:
\[ \overline{\mathbf{Z}}_{t, d} = \hat{\mathbf{Z}}_{t, d} \cdot \gamma_d + \beta_d \]
\(\gamma_d\) and \(\beta_d\) are learnable parameters that scale and shift the normalized feature, respectively. They are the \(d\)-th entries of two vectors \(\gamma, \beta \in \mathbb{R}^{D}\), which are shared across all tokens and batches.
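In practice, these operations are implemented vector-wise across the feature dimension \(D\), and can be compactly expressed as:
\[ \overline{\mathbf{Z}}_t = \operatorname{LayerNorm}(\mathbf{Z}_t) = \gamma \odot \frac{\mathbf{Z}_t - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}} + \beta \]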
Here, \(\odot\) denotes element-wise multiplication, emphasizing that \(\gamma\) and \(\beta\) scale and shift each normalized feature dimension identically across all tokens and batches.
For intuition, we can compute each row \(\overline{\mathbf{Z}}_t\) in a loop over all \(T\) positions (and over every sequence in the batch), then stack the rows together to get the final output tensor \(\overline{\mathbf{Z}}\).
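The loop described above, as a minimal pure-Python sketch for a single \(T \times D\) matrix (a batch would repeat it for every sequence). It follows the formulas in this section, including the biased variance that divides by \(D\); `layer_norm` is a hypothetical helper rather than a library function, and \(\gamma = \mathbf{1}\), \(\beta = \mathbf{0}\) is the usual initialization.

```python
import math

def layer_norm(Z, gamma, beta, eps=1e-5):
    """Normalize each row (token) of a T x D matrix over its D features."""
    out = []
    for z_t in Z:                                   # loop over the T rows
        D = len(z_t)
        mu = sum(z_t) / D                           # mean over features
        var = sum((z - mu) ** 2 for z in z_t) / D   # biased variance (divide by D)
        denom = math.sqrt(var + eps)
        out.append([(z - mu) / denom * g + b for z, g, b in zip(z_t, gamma, beta)])
    return out

Z = [[1.0, 2.0, 3.0, 4.0],
     [2.0, 2.0, 2.0, 2.0]]
gamma, beta = [1.0] * 4, [0.0] * 4   # identity affine transform
Z_bar = layer_norm(Z, gamma, beta)
```

After normalization, each row has mean close to 0 and variance close to 1 (slightly below 1 because of \(\epsilon\)); a constant row maps to all zeros.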
| Input/Output | Shape | Description |
|---|---|---|
| \(\mathbf{Z}\) | \(\mathcal{B} \times T \times D\) | The input tensor to the layer normalization operation. |
| \(\overline{\mathbf{Z}}\) | \(\mathcal{B} \times T \times D\) | The output tensor after applying layer normalization. The dimensionality remains the same as the input; only the scale and shift of the activations within each feature vector are adjusted. |
| \(\operatorname{LayerNorm}\) | \(\mathbb{R}^{B \times T \times D} \to \mathbb{R}^{B \times T \times D}\) | The layer normalization function that takes an input tensor \(\mathbf{Z}\) and returns the normalized tensor \(\overline{\mathbf{Z}}\) with the same shape. |
Attention Notations#
Dimensions#
| Symbol | Description |
|---|---|
| \(H\) | Number of attention heads. |
| \(h\) | Index of the attention head. |
| \(d_k = D/H\) | Dimension of the keys. In the multi-head attention case, this would typically be \(D/H\) where \(D\) is the dimensionality of input embeddings and \(H\) is the number of attention heads. |
| \(d_q = D/H\) | Dimension of the queries. Also usually set equal to \(d_k\). |
| \(d_v = D/H\) | Dimension of the values. Usually set equal to \(d_k\). |
| \(L\) | Total number of decoder blocks in the GPT architecture. |
| \(\ell\) | Index of the decoder block, ranging from \(1\) to \(L\). |
Query, Key and Values#
| Matrix Description | Symbol | Dimensions | Description |
|---|---|---|---|
| Generic Query Matrix for All Heads | \(\mathbf{Q}\) | \(T \times D\) | Contains the query representations for all tokens in the sequence using combined weights of all heads. |
| Generic Key Matrix for All Heads | \(\mathbf{K}\) | \(T \times D\) | Contains the key representations for all tokens in the sequence using combined weights of all heads. |
| Generic Value Matrix for All Heads | \(\mathbf{V}\) | \(T \times D\) | Contains the value representations for all tokens in the sequence using combined weights of all heads. |
| Query Matrix for All Heads in Layer \(\ell\) | \(\mathbf{Q}^{(\ell)}\) | \(T \times D\) | Contains the query representations for all tokens in the sequence using combined weights of all heads in the \(\ell\)-th layer. |
| Key Matrix for All Heads in Layer \(\ell\) | \(\mathbf{K}^{(\ell)}\) | \(T \times D\) | Contains the key representations for all tokens in the sequence using combined weights of all heads in the \(\ell\)-th layer. |
| Value Matrix for All Heads in Layer \(\ell\) | \(\mathbf{V}^{(\ell)}\) | \(T \times D\) | Contains the value representations for all tokens in the sequence using combined weights of all heads in the \(\ell\)-th layer. |
| Query Weight Matrix for All Heads in Layer \(\ell\) | \(\mathbf{W}^{\mathbf{Q}, (\ell)}\) | \(D \times D\) | The query transformation matrix applicable to all heads in the \(\ell\)-th layer. It transforms the embeddings \(\tilde{\mathbf{X}}\) into query representations. |
| Key Weight Matrix for All Heads in Layer \(\ell\) | \(\mathbf{W}^{\mathbf{K}, (\ell)}\) | \(D \times D\) | The key transformation matrix applicable to all heads in the \(\ell\)-th layer. It transforms the embeddings \(\tilde{\mathbf{X}}\) into key representations. |
| Value Weight Matrix for All Heads in Layer \(\ell\) | \(\mathbf{W}^{\mathbf{V}, (\ell)}\) | \(D \times D\) | The value transformation matrix applicable to all heads in the \(\ell\)-th layer. It transforms the embeddings \(\tilde{\mathbf{X}}\) into value representations. |
| Query Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{Q}_h^{(\ell)}\) | \(T \times d_q\) | Contains the query representations for all tokens in the sequence specific to head \(h\) in the \(\ell\)-th layer. |
| Key Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{K}_h^{(\ell)}\) | \(T \times d_k\) | Contains the key representations for all tokens in the sequence specific to head \(h\) in the \(\ell\)-th layer. |
| Value Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{V}_h^{(\ell)}\) | \(T \times d_v\) | Contains the value representations for all tokens in the sequence specific to head \(h\) in the \(\ell\)-th layer. |
| Query Weight Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{W}_h^{\mathbf{Q}, (\ell)}\) | \(D \times d_q\) | Linear transformation matrix for queries in masked attention head \(h \in \{1, \ldots, H\}\) of decoder block \(\ell \in \{1, \ldots, L\}\). |
| Key Weight Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{W}_h^{\mathbf{K}, (\ell)}\) | \(D \times d_k\) | Linear transformation matrix for keys in masked attention head \(h \in \{1, \ldots, H\}\) of decoder block \(\ell \in \{1, \ldots, L\}\). |
| Value Weight Matrix for Head \(h\) in Layer \(\ell\) | \(\mathbf{W}_h^{\mathbf{V}, (\ell)}\) | \(D \times d_v\) | Linear transformation matrix for values in masked attention head \(h \in \{1, \ldots, H\}\) of decoder block \(\ell \in \{1, \ldots, L\}\). |
| Projection Weight Matrix for Layer \(\ell\) | \(\mathbf{W}^{O, (\ell)}\) | \(D \times D\) | The projection matrix used to combine and transform the concatenated outputs from all heads in the \(\ell\)-th layer back to the original dimension \(D\). |
We will talk about the relevant shapes in the last section.
General Attention Mechanism#
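To calculate the embeddings after attention in the GPT model, each head applies scaled dot-product attention:
\[ \operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\right)\mathbf{V} \]
where the softmax is applied row-wise, so row \(t\) of the output is a weighted average of the value vectors, weighted by how strongly query \(t\) matches each key.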
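A minimal sketch of scaled dot-product attention for one head, in plain Python with hypothetical toy inputs. Each row of `weights` is a probability distribution over the \(T\) positions, so each output row is a convex combination of the value rows.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def softmax(row):
    m = max(row)                      # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for T x d_k queries/keys and T x d_v values."""
    d_k = len(K[0])
    scores = matmul(Q, transpose(K))
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V), weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = attention(Q, K, V)
```

If all keys are identical, every query scores them equally and each output row becomes the plain average of the value rows.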
Multi-Head Attention for Layer \(\ell\)#
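For the multi-head attention in layer \(\ell\) of the Transformer (applicable to both encoder and decoder in architectures that have both components, but here tailored for GPT, which uses only decoder blocks):
\[ \begin{aligned} \operatorname{head}_h^{(\ell)} &= \operatorname{Attention}\left(\tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{Q}, (\ell)}, \; \tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{K}, (\ell)}, \; \tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{V}, (\ell)}\right), \quad h = 1, \ldots, H \\ \operatorname{MultiHead}^{(\ell)}(\tilde{\mathbf{X}}) &= \operatorname{Concat}\left(\operatorname{head}_1^{(\ell)}, \ldots, \operatorname{head}_H^{(\ell)}\right) \mathbf{W}^{O, (\ell)} \end{aligned} \]
(for \(\ell > 1\), the output of decoder block \(\ell - 1\) takes the place of \(\tilde{\mathbf{X}}\)), where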
\(\mathbf{W}_{h}^{\mathbf{Q}, (\ell)}, \mathbf{W}_{h}^{\mathbf{K}, (\ell)}, \mathbf{W}_{h}^{\mathbf{V}, (\ell)}\) are the weight matrices for queries, keys, and values for the \(h\)-th head in the \(\ell\)-th layer, respectively.
\(\mathbf{W}^{O, (\ell)}\) is the output transformation matrix for the \(\ell\)-th layer.
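A sketch of the multi-head computation in plain Python, following the usual definition: each head projects the input with its own \(D \times d_k\) matrices, runs attention, and the \(H\) head outputs are concatenated and multiplied by \(\mathbf{W}^{O,(\ell)}\). The toy weights are hypothetical: each head's projection simply selects a disjoint slice of the features and \(\mathbf{W}^{O,(\ell)}\) is the identity, which makes the result easy to check.

```python
import math

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    d_k = len(K[0])
    scores = matmul(Q, [list(c) for c in zip(*K)])   # Q K^T
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, V)

def multi_head(X, W_Q, W_K, W_V, W_O):
    """W_Q, W_K, W_V: lists of H per-head D x d_k matrices; W_O: D x D."""
    heads = [attention(matmul(X, wq), matmul(X, wk), matmul(X, wv))
             for wq, wk, wv in zip(W_Q, W_K, W_V)]
    # Concatenate head outputs along the feature axis: T x (H * d_v) = T x D.
    concat = [sum((head[t] for head in heads), []) for t in range(len(X))]
    return matmul(concat, W_O)

D, H = 4, 2
d = D // H

def selector(h):
    """Hypothetical D x d projection that picks feature columns h*d .. h*d + d - 1."""
    return [[1.0 if r == h * d + c else 0.0 for c in range(d)] for r in range(D)]

X = [[1.0, 0.0, 2.0, 1.0],
     [0.0, 1.0, 1.0, 0.0],
     [1.0, 1.0, 0.0, 2.0]]
W_Q = W_K = W_V = [selector(h) for h in range(H)]
W_O = [[1.0 if i == j else 0.0 for j in range(D)] for i in range(D)]
out = multi_head(X, W_Q, W_K, W_V, W_O)
```

With these weights, the first \(d\) output columns equal single-head attention run on the first \(d\) input features alone.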
Masked Multi-Head Attention for Decoder Layer \(\ell\)#
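Masked multi-head attention is used in the decoder to ensure that the output at position \(t\) can only depend on tokens at positions up to and including \(t\), never on later ones (the auto-regressive property):
\[ \begin{aligned} \operatorname{head}_h^{M, (\ell)} &= \operatorname{softmax}\left(\frac{\left(\tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{Q}, M, (\ell)}\right)\left(\tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{K}, M, (\ell)}\right)^{\top}}{\sqrt{d_k}} + M\right) \tilde{\mathbf{X}} \mathbf{W}_{h}^{\mathbf{V}, M, (\ell)} \\ \operatorname{MaskedMultiHead}^{(\ell)}(\tilde{\mathbf{X}}) &= \operatorname{Concat}\left(\operatorname{head}_1^{M, (\ell)}, \ldots, \operatorname{head}_H^{M, (\ell)}\right) \mathbf{W}^{O, M, (\ell)} \end{aligned} \]
where
\(M \in \mathbb{R}^{T \times T}\) denotes the causal mask, with \(M_{t,s} = 0\) if \(s \leq t\) and \(M_{t,s} = -\infty\) if \(s > t\). Adding it to the scores before the softmax gives every future position exactly zero attention weight.
\(\mathbf{W}_{h}^{\mathbf{Q}, M, (\ell)}, \mathbf{W}_{h}^{\mathbf{K}, M, (\ell)}, \mathbf{W}_{h}^{\mathbf{V}, M, (\ell)}\) are the weight matrices for queries, keys, and values for the \(h\)-th head of the masked attention sublayer in the \(\ell\)-th layer. They are ordinary learned weights; the superscript \(M\) only marks the sublayer they belong to.
\(\mathbf{W}^{O, M, (\ell)}\) is the output transformation matrix of the masked attention sublayer in the \(\ell\)-th layer. Future tokens are prevented from influencing the current token by the mask \(M\) applied to the attention scores, not by the weight matrices themselves.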
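A sketch of causal masking in plain Python with toy inputs. `-math.inf` plays the role of \(-\infty\); since `math.exp(-math.inf)` is `0.0`, masked positions receive exactly zero weight, and the first token can attend only to itself. `causal_mask` and `masked_attention` are hypothetical helper names.

```python
import math

def causal_mask(T):
    """M[t][s] = 0 if s <= t else -inf (future positions are blocked)."""
    return [[0.0 if s <= t else -math.inf for s in range(T)] for t in range(T)]

def softmax(row):
    m = max(row)                             # finite: the diagonal is never masked
    exps = [math.exp(v - m) for v in row]    # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def masked_attention(Q, K, V):
    T, d_k = len(Q), len(K[0])
    M = causal_mask(T)
    scores = [[sum(q * k for q, k in zip(Q[t], K[s])) / math.sqrt(d_k) + M[t][s]
               for s in range(T)] for t in range(T)]
    weights = [softmax(row) for row in scores]
    out = [[sum(weights[t][s] * V[s][c] for s in range(T)) for c in range(len(V[0]))]
           for t in range(T)]
    return out, weights

Q = K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = masked_attention(Q, K, V)
print(weights[0])  # [1.0, 0.0, 0.0]
```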
Updated Matrix Description Table with Batch and Head Dimensions#
To accurately reflect the practical shapes of the Query (Q), Key (K), and Value (V) matrices in implementations like GPT, where batch processing and multi-head attention are used, we should adjust the notation to include batch size \(B\), number of heads \(H\), and the dimensions \(d_k, d_q, d_v\) corresponding to each head.
| Matrix Description | Symbol | Dimensions | Description |
|---|---|---|---|
| Query Matrix for All Heads in Layer \(\ell\) | \(\mathbf{Q}^{(\ell)}\) | \(\mathcal{B} \times T \times D \rightarrow B \times T \times \underset{D}{\underbrace{H \times d_q}} \xrightarrow[]{\text{transpose 1-2}} B \times H \times T \times d_q\) | Contains the query representations for all tokens in all sequences of a batch, separated by heads in the \(\ell\)-th layer. |
| Key Matrix for All Heads in Layer \(\ell\) | \(\mathbf{K}^{(\ell)}\) | \(\mathcal{B} \times T \times D \rightarrow B \times T \times \underset{D}{\underbrace{H \times d_k}} \xrightarrow[]{\text{transpose 1-2}} B \times H \times T \times d_k\) | Contains the key representations for all tokens in all sequences of a batch, separated by heads in the \(\ell\)-th layer. |
| Value Matrix for All Heads in Layer \(\ell\) | \(\mathbf{V}^{(\ell)}\) | \(\mathcal{B} \times T \times D \rightarrow B \times T \times \underset{D}{\underbrace{H \times d_v}} \xrightarrow[]{\text{transpose 1-2}} B \times H \times T \times d_v\) | Contains the value representations for all tokens in all sequences of a batch, separated by heads in the \(\ell\)-th layer. |
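The reshape-then-transpose in this table can be sketched with nested lists; in PyTorch it is commonly written as `q.view(B, T, H, D // H).transpose(1, 2)`. The helpers below are hypothetical and only illustrate the index bookkeeping: head \(h\) (zero-based) receives feature columns \(h \cdot d_q\) through \((h+1) \cdot d_q - 1\) of every token.

```python
def split_heads(Q, H):
    """Reshape a B x T x D tensor into B x H x T x (D // H).

    Equivalent to viewing the last axis as (H, d_q) and then swapping
    the T and H axes ("transpose 1-2").
    """
    B, T, D = len(Q), len(Q[0]), len(Q[0][0])
    if D % H != 0:
        raise ValueError("D must be divisible by H")
    d = D // H
    return [[[Q[b][t][h * d:(h + 1) * d] for t in range(T)]
             for h in range(H)]
            for b in range(B)]

def merge_heads(Q_heads):
    """Inverse of split_heads: B x H x T x d -> B x T x (H * d)."""
    return [[sum((head[t] for head in seq), []) for t in range(len(seq[0]))]
            for seq in Q_heads]

Q = [[[0, 1, 2, 3], [4, 5, 6, 7]]]   # B = 1, T = 2, D = 4
Q_heads = split_heads(Q, H=2)
print(Q_heads)  # [[[[0, 1], [4, 5]], [[2, 3], [6, 7]]]]
```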
Positionwise Feed-Forward Networks#
The term “positionwise feed-forward network” (FFN) in the context of Transformer models refers to a dense neural network (otherwise known as multilayer perceptron) that operates on the output of the Multi-Head Attention mechanism. This component is called “positionwise” because it applies the same feed-forward neural network (FFN) independently and identically to each position \(t\) in the sequence of length \(T\).
Independent Processing#
In the Transformer architecture, after the Multi-Head Attention mechanism aggregates information from different positions in the sequence based on attention scores, each element (or position) \(t\) in the sequence has an updated representation. The positionwise FFN then processes each of these updated representations. However, rather than considering the sequence as a whole or how elements relate to each other at this stage, the FFN operates on each position separately. This means that for a sequence of length \(T\), the same FFN is applied \(T\) times independently, and by extension, given a batch of sequences, the FFN is applied \(T \times \mathcal{B}\) times, where \(\mathcal{B}\) is the batch size.
Identical Application#
The term “using the same FFN” signifies that the same set of parameters (weights and biases) of the feed-forward neural network is used for each position in the sequence. The rationale is that the transformation is consistent across all sequence positions, so each element is transformed by the same learned function. This means the weight matrices and bias vectors of the FFN are shared across all positions in the sequence. In other words, if a sequence has \(T=3\) positions/tokens, the weight matrices and bias vectors of the FFN are the same for all three positions.
Definition#
Typically, a positionwise FFN consists of two linear transformations with a non-linear activation function in between. The general form can be represented as follows.
(Position-wise Feedforward Networks)
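Given an input matrix \(\mathbf{Z} \in \mathbb{R}^{T \times D}\), the position-wise feedforward network computes the output matrix \(\mathbf{Z}^{\prime} \in \mathbb{R}^{T \times D}\) via the following operations:
\[ \mathbf{Z}^{\prime} = \sigma_Z\left(\mathbf{Z} \mathbf{W}^{\text{FF}}_1 + \mathbf{b}^{\text{FF}}_1\right) \mathbf{W}^{\text{FF}}_2 + \mathbf{b}^{\text{FF}}_2 \]
where the bias vectors are added to every row: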
\(\mathbf{W}^{\text{FF}}_1 \in \mathbb{R}^{D \times d_{\text{ff}}}\) and \(\mathbf{W}^{\text{FF}}_2 \in \mathbb{R}^{d_{\text{ff}} \times D}\) are learnable weight matrices.
\(\mathbf{b}^{\text{FF}}_1 \in \mathbb{R}^{d_{\text{ff}}}\) and \(\mathbf{b}^{\text{FF}}_2 \in \mathbb{R}^{D}\) are learnable bias vectors.
\(\sigma_Z\) is a non-linear activation function, such as the Gaussian Error Linear Unit (GELU) or the Rectified Linear Unit (ReLU).
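A plain-Python sketch of this definition, assuming GELU as \(\sigma_Z\) and small hypothetical weights. The list comprehension in `ffn` makes the "positionwise" property explicit: each row of \(\mathbf{Z}\) is processed on its own with the same \(\mathbf{W}^{\text{FF}}_1, \mathbf{b}^{\text{FF}}_1, \mathbf{W}^{\text{FF}}_2, \mathbf{b}^{\text{FF}}_2\), so two identical rows always produce identical outputs.

```python
import math

def gelu(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(row, W, b):
    """row (length n) times W (n x m) plus b (length m)."""
    return [sum(r * W[i][j] for i, r in enumerate(row)) + b[j] for j in range(len(b))]

def ffn(Z, W1, b1, W2, b2, act=gelu):
    """Apply the same two-layer network to every row (position) of Z."""
    return [linear([act(h) for h in linear(z_t, W1, b1)], W2, b2) for z_t in Z]

# Hypothetical sizes: D = 2, d_ff = 4 * D = 8.
D, d_ff = 2, 8
W1 = [[0.1 * (i + j) for j in range(d_ff)] for i in range(D)]   # D x d_ff
b1 = [0.01] * d_ff
W2 = [[0.05 * (i - j) for j in range(D)] for i in range(d_ff)]  # d_ff x D
b2 = [0.0] * D
Z = [[1.0, -1.0], [0.5, 2.0], [1.0, -1.0]]
out = ffn(Z, W1, b1, W2, b2)
```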
Projection to a Higher Dimension Space#
In the Transformer architecture, the dimensionality of the hidden layer in the positionwise FFN, denoted as \(d_{\text{ff}}\), is often chosen to be larger than the dimensionality of the input and output embeddings, \(D\). This means that the FFN projects the input embeddings into a higher-dimensional space before projecting them back to the original dimensionality.
The motivation behind this design choice is to allow the model to learn more complex and expressive representations. Projecting the input embeddings into a higher-dimensional space increases the model's capacity, which can let the FFN capture more intricate patterns and relationships among the features. The second linear layer then projects the higher-dimensional representations back to the original dimensionality \(D\), so that the FFN output has the same shape as its input, as required by the residual connection and the subsequent layers.
In practice, a common choice for the dimensionality of the hidden layer is to set \(d_{\text{ff}}\) to be a multiple of the input and output dimensionality \(D\). For example, in the original Transformer paper [Vaswani et al., 2017], the authors used \(d_{\text{ff}} = 4 \times D\).
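To make the cost of this choice concrete, here is the FFN parameter count per layer for \(D = 512\) and \(d_{\text{ff}} = 2048\) (the base-model sizes in the original Transformer paper). With \(d_{\text{ff}} = 4D\) the count is roughly \(8D^2\), about twice the \(4D^2\) of the four attention projection matrices. `ffn_param_count` is a hypothetical helper.

```python
def ffn_param_count(D: int, d_ff: int) -> int:
    """Weights and biases of one positionwise FFN: W1, b1, W2, b2."""
    return D * d_ff + d_ff + d_ff * D + D

D = 512
d_ff = 4 * D  # 2048
print(ffn_param_count(D, d_ff))  # 2099712
```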
Gaussian Error Linear Unit (GELU)#
The Gaussian Error Linear Unit (GELU) is a non-linear activation function used in neural networks, proposed as a smooth alternative to traditional activation functions like ReLU. The GELU activation function is defined as:
\[ \operatorname{GELU}(x) = x \cdot \Phi(x) \]
where \(x\) is the input to the activation function, and \(\Phi(x)\) represents the cumulative distribution function (CDF) of the standard Gaussian distribution. In effect, GELU weights each input by \(\Phi(x)\), the probability that a standard Gaussian variable is at most \(x\), instead of gating it by its sign as ReLU does.
The cumulative distribution function \(\Phi(x)\) for a standard Gaussian distribution is given by:
\[ \Phi(x) = \frac{1}{2}\left[1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right] \]
where \(\operatorname{erf}(z) = \frac{2}{\sqrt{\pi}} \int_0^z e^{-s^2} \, ds\) denotes the error function. Combining these, the GELU function can be expressed as:
\[ \operatorname{GELU}(x) = \frac{x}{2}\left[1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right] \]
I will not pretend I have gone through the entire paper and the motivation behind GELU, but when new and “better” activation functions are proposed, they usually serve as alternatives to common activation functions such as ReLU, addressing some of their shortcomings. From the formulation, we can see that GELU has the following properties:
Non-linearity: GELU introduces non-linearity to the model, which is the basic requirement for any activation function.
Differentiability: GELU is smooth and differentiable everywhere, which is beneficial for gradient-based optimization methods.
Boundedness: GELU is bounded below, with a minimum of approximately \(-0.17\) near \(x \approx -0.75\), and unbounded above, since \(\operatorname{GELU}(x) \approx x\) for large positive \(x\). If the inputs stay within a bounded range (for example, after normalization), the outputs are bounded as well.
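A quick numerical check of these properties, using the exact form with Python's `math.erf`: scanning a grid locates the minimum near \(x \approx -0.75\) with value about \(-0.17\), and large positive inputs pass through almost unchanged.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi written via the error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Scan a fine grid on [-4, 0] to locate the minimum numerically.
xs = [i / 1000 for i in range(-4000, 1)]
x_min = min(xs, key=gelu)
print(round(x_min, 2), round(gelu(x_min), 3))  # -0.75 -0.17
print(gelu(10.0))                               # 10.0 (Phi(10) rounds to 1)
```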
(Approximation of GELU)
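To further simplify the GELU function and enhance computational efficiency, an approximation of the Gaussian CDF is commonly used in practice (extracted from Mathematical Analysis and Performance Evaluation of the GELU Activation Function in Deep Learning):
\[ \Phi(x) \approx \frac{1}{2}\left[1 + \tanh\left(\beta\left(x + \gamma x^{3}\right)\right)\right] \]
where \(\beta>0\) and \(\gamma \in \mathbb{R}\) are constants, selected to minimize approximation error; the widely used choice is \(\beta = \sqrt{2/\pi}\) and \(\gamma = 0.044715\). Substituting this approximation into the GELU function, we arrive at the final approximate form of the GELU activation function (Figure 1):
\[ \operatorname{GELU}(x) \approx \frac{x}{2}\left[1 + \tanh\left(\beta\left(x + \gamma x^{3}\right)\right)\right] \]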
(GELU Activation Function)
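For a matrix \(\mathbf{Z}\) with elements \(\mathbf{Z}_{t d}\) where \(t\) indexes the sequence (from 1 to \(T\)) and \(d\) indexes the feature dimension (from 1 to \(D\)), the GELU activation is applied element-wise to each element \(\mathbf{Z}_{t d}\) independently:
\[ \operatorname{GELU}(\mathbf{Z})_{t d} = \operatorname{GELU}\left(\mathbf{Z}_{t d}\right), \quad t = 1, \ldots, T, \; d = 1, \ldots, D \]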
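A sketch comparing the tanh approximation with the exact GELU, assuming the commonly used constants \(\beta = \sqrt{2/\pi}\) and \(\gamma = 0.044715\). It prints the largest absolute gap over a grid on \([-6, 6]\), which is small, and then applies GELU element-wise to a hypothetical \(2 \times 3\) matrix, as in the definition above.

```python
import math

def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float, beta: float = math.sqrt(2.0 / math.pi),
              gamma: float = 0.044715) -> float:
    """tanh approximation of GELU with the commonly used constants."""
    return 0.5 * x * (1.0 + math.tanh(beta * (x + gamma * x ** 3)))

grid = [i / 100 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in grid)
print(f"max |exact - approx| on [-6, 6]: {max_err:.1e}")

# Element-wise application to a T x D matrix (T = 2, D = 3).
Z = [[-1.0, 0.0, 1.0], [2.0, -0.5, 3.0]]
Z_act = [[gelu_exact(z) for z in row] for row in Z]
```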
References
| Matrix Description | Symbol | Dimensions | Description |
|---|---|---|---|
| Input to FFN in Layer \(\ell\) | \(\mathbf{Z}^{(\ell)}_4\) | \(\mathcal{B} \times T \times D\) | Layer-normalized output of the residual connection \(\mathbf{Z}^{(\ell)}_3\), which adds the masked self-attention output to the input of the block. |
| First Linear Transformation in FFN | \(\mathbf{Z}^{FF, (\ell)}_1\) | \(\mathcal{B} \times T \times d_{\text{ff}}\) | Applies the first linear transformation to each position’s embedding, projecting it to a higher-dimensional space (\(d_{\text{ff}}\)). |
| Activation (e.g., GELU) Applied to First Linear Output | \(\sigma\left(\mathbf{Z}^{FF, (\ell)}_1\right)\) | \(\mathcal{B} \times T \times d_{\text{ff}}\) | Applies the GELU non-linear activation function element-wise to the output of the first linear transformation. |
| Second Linear Transformation in FFN | \(\mathbf{Z}^{(\ell)}_5\) | \(\mathcal{B} \times T \times D\) | Projects the activated output back down to the original embedding dimension \(D\). |
| Weights and Biases | | | |
| Weights for First Linear Transformation | \(\mathbf{W}^{FF, (\ell)}_1\) | \(D \times d_{\text{ff}}\) | Weights used to transform the input embeddings from dimension \(D\) to \(d_{\text{ff}}\). |
| Biases for First Linear Transformation | \(\mathbf{b}^{FF, (\ell)}_1\) | \(d_{\text{ff}}\) | Biases added to the linearly transformed embeddings in the first FFN layer. |
| Weights for Second Linear Transformation | \(\mathbf{W}^{FF, (\ell)}_2\) | \(d_{\text{ff}} \times D\) | Weights used to project the activated embeddings from dimension \(d_{\text{ff}}\) back to \(D\). |
| Biases for Second Linear Transformation | \(\mathbf{b}^{FF, (\ell)}_2\) | \(D\) | Biases added to the output of the second linear transformation in the FFN. |
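The shapes in the table can be sketched as a small PyTorch module. This is a hypothetical illustration (the class and attribute names are not from the original code), assuming \(d_{\text{ff}} = 4D\), the expansion used in the original Transformer (\(2048 = 4 \times 512\)) and in GPT-2. Note that `nn.Linear` stores its weight transposed, as `[out_features, in_features]`, so `fc1.weight` has shape \(d_{\text{ff}} \times D\) even though it plays the role of \(\mathbf{W}^{FF, (\ell)}_1 \in \mathbb{R}^{D \times d_{\text{ff}}}\).

```python
import torch
from torch import nn


class PositionwiseFFN(nn.Module):
    """Hypothetical FFN sketch: Linear(D -> d_ff) -> GELU -> Linear(d_ff -> D)."""

    def __init__(self, D: int, d_ff: int) -> None:
        super().__init__()
        self.fc1 = nn.Linear(D, d_ff)  # W_1 (D x d_ff) and b_1 (d_ff)
        self.act = nn.GELU()           # sigma(.), applied element-wise
        self.fc2 = nn.Linear(d_ff, D)  # W_2 (d_ff x D) and b_2 (D)

    def forward(self, z4: torch.Tensor) -> torch.Tensor:
        z_ff1 = self.fc1(z4)           # [B, T, D] -> [B, T, d_ff]
        activated = self.act(z_ff1)    # [B, T, d_ff]
        z5 = self.fc2(activated)       # [B, T, d_ff] -> [B, T, D]
        return z5


torch.manual_seed(0)
B, T, D = 2, 5, 16
d_ff = 4 * D
ffn = PositionwiseFFN(D, d_ff)
z4 = torch.randn(B, T, D)
z5 = ffn(z4)
print(z5.shape)

# "Position-wise": running one position on its own gives the same result.
print(torch.allclose(ffn(z4[:, 2:3]), z5[:, 2:3], atol=1e-6))
```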
The Training Phase#
Autoregressive Self-Supervised Learning Paradigm#
Let \(\mathcal{D}\) be the true but unknown distribution of the natural language space. In the context of unsupervised learning with self-supervision, such as language modeling, we consider both the inputs and the implicit labels derived from the same data sequence. Thus, while traditionally we might decompose the distribution \(\mathcal{D}\) of a supervised learning task into input space \(\mathcal{X}\) and label space \(\mathcal{Y}\), in this scenario, \(\mathcal{X}\) and \(\mathcal{Y}\) are intrinsically linked, because \(\mathcal{Y}\) is a shifted version of \(\mathcal{X}\), and so we can consider \(\mathcal{D}\) as a distribution over \(\mathcal{X}\) only.
Since \(\mathcal{D}\) is a distribution over \(\mathcal{X}\), we write it as a probability distribution parameterized by \(\boldsymbol{\Theta}\):
\[\mathcal{D} = \mathbb{P}(\mathcal{X} ; \boldsymbol{\Theta})\]
where \(\boldsymbol{\Theta}\) is the parameter space that defines the distribution \(\mathbb{P}(\mathcal{X} ; \boldsymbol{\Theta})\) and \(\mathbf{x}\) is a sample from \(\mathcal{X}\) generated by the distribution \(\mathcal{D}\). It is common to treat \(\mathbf{x}\) as a sequence of tokens (i.e. a sentence is a sequence of tokens), and we can write \(\mathbf{x} = \left(x_1, x_2, \ldots, x_T\right)\), where \(T\) is the length of the sequence.
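Given such a sequence \(\mathbf{x}\), the joint probability of the sequence can be factorized into the product of the conditional probabilities of each token in the sequence via the chain rule of probability:
\[\mathbb{P}(\mathbf{x} ; \boldsymbol{\Theta}) = \prod_{t=1}^{T} \mathbb{P}\left(x_t \mid x_1, x_2, \ldots, x_{t-1} ; \boldsymbol{\Theta}\right)\]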
The chain rule holds for any ordering of the tokens, but the left-to-right factorization is the natural choice because natural language is inherently sequential. Such a decomposition allows for tractable sampling from and estimation of the distribution \(\mathbb{P}(\mathbf{x} ; \boldsymbol{\Theta})\) as well as any conditionals in the form of \(\mathbb{P}(x_{t-k}, x_{t-k+1}, \ldots, x_{t} \mid x_{1}, x_{2}, \ldots, x_{t-k-1} ; \boldsymbol{\Theta})\) [Radford et al., 2019].
To this end, consider a corpus \(\mathcal{S}\) with \(N\) sequences \(\left\{\mathbf{x}_{1}, \mathbf{x}_{2}, \ldots, \mathbf{x}_{N}\right\}\),
where each sequence \(\mathbf{x}_{n}\) is a sequence of tokens that are sampled \(\text{i.i.d.}\) from the distribution \(\mathcal{D}\).
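Then, we can frame the likelihood function \(\hat{\mathcal{L}}(\cdot)\) as the likelihood of observing the sequences in the corpus \(\mathcal{S}\), which factorizes over the i.i.d. sequences:
\[\hat{\mathcal{L}}\left(\mathcal{S} ; \hat{\boldsymbol{\Theta}}\right) = \prod_{n=1}^{N} \mathbb{P}\left(\mathbf{x}_{n} ; \hat{\boldsymbol{\Theta}}\right)\]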
where \(\hat{\boldsymbol{\Theta}}\) is the estimated parameter space that approximates the true parameter space \(\boldsymbol{\Theta}\).
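The objective function is now well defined: maximize the likelihood of the sequences in the corpus \(\mathcal{S}\),
\[\hat{\boldsymbol{\Theta}}^{*} = \underset{\hat{\boldsymbol{\Theta}}}{\operatorname{argmax}} \prod_{n=1}^{N} \prod_{t=1}^{T_n} \mathbb{P}\left(x_{n, t} \mid x_{n, <t} ; \hat{\boldsymbol{\Theta}}\right)\]
where \(x_{n, t}\) is the \(t\)-th token of \(\mathbf{x}_{n}\), \(x_{n, <t} = \left(x_{n, 1}, \ldots, x_{n, t-1}\right)\), and \(T_n\) is the length of the sequence \(\mathbf{x}_{n}\).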
Because the product of many probabilities can become vanishingly small and underflow in floating-point arithmetic, it is common (and in practice necessary) to use the log-likelihood as the objective function instead. Since the logarithm is strictly increasing, maximizing the log-likelihood is equivalent to maximizing the likelihood itself.
Furthermore, since we treat the loss function as something to be minimized, we simply negate the log-likelihood to obtain the negative log-likelihood as the objective function to be minimized:
\[\hat{\boldsymbol{\Theta}}^{*} = \underset{\hat{\boldsymbol{\Theta}}}{\operatorname{argmin}} \left(-\sum_{n=1}^{N} \sum_{t=1}^{T_n} \log \mathbb{P}\left(x_{n, t} \mid x_{n, <t} ; \hat{\boldsymbol{\Theta}}\right)\right)\]
It is worth noting that the objective is optimized over the parameter space \(\hat{\boldsymbol{\Theta}}\) with the data \(\mathcal{S}\) held fixed, so all analyses such as convergence and consistency are carried out with respect to \(\hat{\boldsymbol{\Theta}}\).
To this end, we denote by \(\mathcal{G}\) the GPT model: an autoregressive, self-supervised model trained to maximize the likelihood of observing all data points \(\mathbf{x} \in \mathcal{S}\) via the objective function \(\hat{\mathcal{L}}\left(\mathcal{S} ; \hat{\boldsymbol{\Theta}}\right)\), by learning the conditional probability distribution \(\mathbb{P}(x_t \mid x_{<t} ; \hat{\boldsymbol{\Theta}})\) over the vocabulary \(\mathcal{V}\) of tokens, conditioned on the preceding context tokens \(x_{<t} = \left(x_1, x_2, \ldots, x_{t-1}\right)\). Note that although the goal is to model the joint probability distribution of the token sequences, we do so by estimating the conditional probability distributions, which the chain rule combines into the joint distribution.
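In practice, this negative log-likelihood is usually computed with next-token targets: the prediction at position \(t\) is scored against the token at position \(t+1\). The sketch below uses random logits as a stand-in for model output (an assumption for illustration) and checks that a manual gather of log-probabilities matches `torch.nn.functional.cross_entropy`. In this common setup the first token has no preceding context and is not scored, and the loss is typically averaged rather than summed, which rescales the objective without changing its minimizer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 6, 11                      # toy batch size, sequence length, vocab size
tokens = torch.randint(0, V, (B, T))    # token ids for each sequence
logits = torch.randn(B, T, V)           # stand-in for the model's outputs

# Position t predicts token t+1: drop the last prediction and the first target.
pred_logits = logits[:, :-1, :]         # [B, T-1, V]
targets = tokens[:, 1:]                 # [B, T-1]

# Manual NLL: minus the sum of log P(x_t | x_<t) over sequences and positions.
log_probs = F.log_softmax(pred_logits, dim=-1)
token_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)  # [B, T-1]
nll_sum = -token_log_probs.sum()

# Same quantity via cross_entropy, after flattening batch and time.
ce_sum = F.cross_entropy(pred_logits.reshape(-1, V), targets.reshape(-1), reduction="sum")

print(nll_sum.item(), ce_sum.item())
```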
Corpus and Tokenization#
Step 1. Corpus
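Consider a corpus \(\mathcal{S}\) consisting of \(N\) sequences, denoted as \(\{\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N\}\), where each sequence \(\mathbf{x} = (x_1, x_2, \ldots, x_T) \in \mathcal{S}\) is a sequence of \(T\) tokens. The sequences are sampled i.i.d. from a true but unknown distribution \(\mathcal{D}\) (the tokens within a single sequence are, of course, not independent):
\[\mathbf{x}_n \overset{\text{i.i.d.}}{\sim} \mathcal{D}, \quad n = 1, 2, \ldots, N\]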
Each sequence \(\mathbf{x} \in \mathcal{S}\) represents a collection of tokenized elements (e.g., words or characters), where each token \(x_t\) comes from a finite vocabulary \(\mathcal{V}\).
Step 2. Vocabulary and Tokenization
Let \(\mathcal{V} = \{w_1, w_2, \ldots, w_V\}\) be the vocabulary set, where \(w_j\) is the \(j\)-th token in the vocabulary and \(V = |\mathcal{V}|\) is the size of the vocabulary. It is worth noting that it is common to train one’s own vocabulary and tokenizer on the corpus \(\mathcal{S}\), but for simplicity, we assume that the vocabulary set \(\mathcal{V}\) is predefined.
Let \(\mathcal{X}\) be the set of all possible sequences that can be formed by concatenating tokens from the vocabulary set \(\mathcal{V}\). Each sequence \(\mathbf{x} \in \mathcal{X}\) is a finite sequence of tokens, and the length of each sequence is denoted by \(\tau\). Formally:
\[\mathcal{X} = \bigcup_{\tau=1}^{T} \mathcal{V}^{\tau}\]
where \(\mathcal{V}^\tau\) represents the set of all sequences of length \(\tau\) formed by concatenating tokens from \(\mathcal{V}\), and \(T\) is the maximum sequence length.
Now, let \(\mathcal{S} = \{\mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_N\} \subset \mathcal{X}\) be a corpus of \(N\) sequences, where each sequence \(\mathbf{x}_n \in \mathcal{X}\) is a finite sequence of tokens from the vocabulary set \(\mathcal{V}\).
The tokenizer algorithm \(\mathcal{T}\) is a function that operates on individual sequences \(\mathbf{x}_n\) from the corpus \(\mathcal{S}\) and maps the tokens to their corresponding integer indices using the vocabulary set \(\mathcal{V}\):
\[\mathcal{T} : \mathcal{X} \rightarrow \mathbb{N}^{\leq T}\]
where \(\mathbb{N}^{\leq T}\) represents the set of all finite sequences of natural numbers (non-negative integers) with lengths up to \(T\). The output of \(\mathcal{T}\) is a tokenized sequence, which is a finite sequence of integer indices corresponding to the tokens in the input sequence.
To map the tokens to their corresponding integer indices, we define a bijective mapping function \(f: \mathcal{V} \rightarrow \{1, 2, \ldots, V\}\) such that:
\[f(w_j) = j, \quad j = 1, 2, \ldots, V\]
where \(f(w_j)\) represents the integer index assigned to the token \(w_j \in \mathcal{V}\).
Given a sequence \(\mathbf{x} = (x_1, x_2, \ldots, x_\tau) \in \mathcal{X}\), where \(\tau \leq T\) is the length of the sequence, the tokenizer algorithm \(\mathcal{T}\) maps each token \(x_t\) to its corresponding integer index using the bijective mapping function \(f\). The tokenized representation of the sequence \(\mathbf{x}\) can be defined as:
\[\mathcal{T}(\mathbf{x}) = \left(f(x_1), f(x_2), \ldots, f(x_\tau)\right)\]
where \(f(x_t)\) is the integer index assigned to the token \(x_t\) based on \(f\).
In the case where a token \(x_t\) is not present in the vocabulary set \(\mathcal{V}\), a special token index, such as \(f(\text{<UNK>})\), can be assigned to represent an unknown token.
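The tokenizer algorithm \(\mathcal{T}\) can be applied to each sequence \(\mathbf{x}_n\) in the corpus \(\mathcal{S}\) to obtain the tokenized corpus \(\mathcal{S}^{\mathcal{T}}\):
\[\mathcal{S}^{\mathcal{T}} = \left\{\mathcal{T}(\mathbf{x}_1), \mathcal{T}(\mathbf{x}_2), \ldots, \mathcal{T}(\mathbf{x}_N)\right\}\]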
where \(\mathcal{T}(\mathbf{x}_n)\) is the tokenized representation of the sequence \(\mathbf{x}_n \in \mathcal{S}\).
The tokenized corpus \(\mathcal{S}^{\mathcal{T}}\) is a set of sequences, where each sequence is a finite sequence of integer indices representing the tokens in the original sequences from the corpus \(\mathcal{S}\).
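A toy sketch of \(f\) and \(\mathcal{T}\), assuming the sequences are already split into word tokens and using a small hypothetical vocabulary. `<UNK>` is included in \(\mathcal{V}\) so that \(f\) stays a bijection on the vocabulary, and indices follow the 1-based convention of \(f\) above. The tokenized corpus is stored as a Python list, which keeps order and duplicates, whereas the text writes it as a set.

```python
from typing import Dict, List, Sequence

# Hypothetical toy vocabulary; "<UNK>" is part of V so f(<UNK>) is defined.
vocab: List[str] = ["<UNK>", "the", "cat", "sat", "on", "mat"]

# Bijective mapping f: V -> {1, ..., V} (1-based, as in the text).
f: Dict[str, int] = {w: j for j, w in enumerate(vocab, start=1)}


def tokenize(x: Sequence[str], max_len: int) -> List[int]:
    """T(x) = (f(x_1), ..., f(x_tau)); out-of-vocabulary tokens map to f('<UNK>')."""
    if len(x) > max_len:
        raise ValueError(f"sequence length {len(x)} exceeds T={max_len}")
    return [f.get(token, f["<UNK>"]) for token in x]


corpus = [
    ["the", "cat", "sat"],
    ["the", "dog", "sat", "on", "the", "mat"],  # "dog" is not in the vocabulary
]
tokenized_corpus = [tokenize(x, max_len=8) for x in corpus]
print(tokenized_corpus)
```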
Token Embedding and Positional Encoding#
Step 3. One Hot Encoding
For each sequence \(\mathbf{x} \in \mathcal{S}^{\mathcal{T}}\) in the corpus, we would apply one hot encoding so that each sample/sequence is transformed to \(\mathbf{X}^{\text{ohe}} \in \{0, 1\}^{T \times V}\) where \(V\) is the vocabulary size and \(T\) the pre-defined context window size.
Each row \(\mathbf{X}^{\text{ohe}}_{t} \in \mathbb{R}^{1 \times V}\) represents the one-hot encoded representation of the token at position \(t\) in the sequence.
Step 4. Token Embedding
Given the one-hot encoded input \(\mathbf{X}^{\text{ohe}} \in \{0, 1\}^{T \times |\mathcal{V}|}\), where \(T\) is the sequence length and \(V = |\mathcal{V}|\) is the vocabulary size, we obtain the token embedding matrix \(\mathbf{X} \in \mathbb{R}^{T \times D}\) by matrix multiplying \(\mathbf{X}^{\text{ohe}}\) with the token embedding weight matrix \(\mathbf{W}_e \in \mathbb{R}^{V \times D}\), where \(D\) is the embedding dimension:
\[\mathbf{X} = \mathbf{X}^{\text{ohe}} \mathbf{W}_e \in \mathbb{R}^{T \times D}\]
Weight Sharing
Note carefully that with the addition of the batch dimension \(\mathcal{B}\), the one-hot input becomes a tensor of shape \(\mathcal{B} \times T \times V\), and the matrix multiplication is still well defined in PyTorch: `torch.matmul` broadcasts over the leading batch dimension, so we are essentially performing a \((T \times V) \cdot (V \times D)\) matrix multiplication for each sequence \(\mathbf{X}^{(b)}\) in the batch \(\mathbf{X}^{\mathcal{B}}\) with the same weight matrix \(\mathbf{W}_{e}\).
The token embedding weight matrix \(\mathbf{W}_e\) with dimensions \(V \times D\) is shared across all sequences in the batch. Each sequence \(\mathbf{X}^{(b)}\) in the batched input tensor \(\mathbf{X}^{\mathcal{B}}\) undergoes the same matrix multiplication with \(\mathbf{W}_e\) to obtain the corresponding embedded sequence representation.
The idea of weight sharing is that the same set of parameters (in this case, the embedding weights) is used for processing multiple instances of the input (sequences in the batch). Instead of having separate embedding weights for each sequence, the same embedding matrix is applied to all sequences. This parameter sharing allows the model to learn a common representation for the tokens across different sequences.
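The batched multiplication and the shared weights can be checked directly. The sketch below (toy sizes, random weights) shows that multiplying a one-hot row by \(\mathbf{W}_e\) simply selects a row of \(\mathbf{W}_e\), which is why embedding layers such as `torch.nn.Embedding` implement this step as an index lookup. PyTorch indices are 0-based, unlike the 1-based \(f\) in Step 2.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V, D = 2, 4, 10, 8
token_ids = torch.randint(0, V, (B, T))  # tokenized batch, 0-based ids
W_e = torch.randn(V, D)                  # one embedding matrix shared by all sequences

X_ohe = F.one_hot(token_ids, num_classes=V).float()  # [B, T, V]
X = torch.matmul(X_ohe, W_e)                         # [B, T, V] @ [V, D] -> [B, T, D]

# A one-hot row times W_e selects one row of W_e, i.e. a table lookup.
X_lookup = W_e[token_ids]                            # [B, T, D]
print(X.shape, torch.allclose(X, X_lookup))

# Weight sharing: each sequence in the batch uses the same W_e.
print(torch.allclose(X[1], X_ohe[1] @ W_e))
```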
Step 5. Positional Embedding
In addition to the token embeddings, we incorporate positional information into the input representation to capture the sequential nature of the input sequences. Let \(\operatorname{PE}(\cdot)\) denote the positional encoding function that maps the token positions to their corresponding positional embeddings.
Given the token embedding matrix \(\mathbf{X} \in \mathbb{R}^{T \times D}\), where \(T\) is the sequence length and \(D\) is the embedding dimension, we add the positional embeddings to obtain the position-aware input representation \(\tilde{\mathbf{X}} \in \mathbb{R}^{T \times D}\):
\[\tilde{\mathbf{X}} = \mathbf{X} + \operatorname{PE}(\mathbf{X})\]
The positional encoding function \(\operatorname{PE}(\cdot)\) can be implemented in various ways, such as using fixed sinusoidal functions or learned positional embeddings. For the latter, we can easily replace \(\operatorname{PE}(\cdot)\) with a learnable positional embedding layer in the model architecture (\(\mathbf{W}_{p}\)).
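Both options can be sketched in a few lines. The fixed variant below follows the sinusoidal scheme of Vaswani et al. (2017), \(\operatorname{PE}_{t, 2i} = \sin\left(t / 10000^{2i/D}\right)\) and \(\operatorname{PE}_{t, 2i+1} = \cos\left(t / 10000^{2i/D}\right)\), assuming \(D\) is even; the learned variant uses an `nn.Embedding` table with one row per position as a stand-in for \(\mathbf{W}_p\). The function name and sizes are illustrative. In both cases the \(T \times D\) encoding broadcasts over the batch dimension.

```python
import math

import torch
from torch import nn


def sinusoidal_pe(T: int, D: int) -> torch.Tensor:
    """Fixed sinusoidal encodings of shape [T, D]; assumes D is even."""
    position = torch.arange(T, dtype=torch.float32).unsqueeze(1)  # [T, 1]
    # 10000^(-2i/D) for i = 0, ..., D/2 - 1
    div_term = torch.exp(torch.arange(0, D, 2, dtype=torch.float32) * (-math.log(10000.0) / D))
    pe = torch.zeros(T, D)
    pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
    pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
    return pe


torch.manual_seed(0)
B, T, D = 2, 6, 8
X = torch.randn(B, T, D)                  # token embeddings from Step 4
X_tilde = X + sinusoidal_pe(T, D)         # [T, D] broadcasts over the batch

# Learned alternative: a trainable table W_p with one row per position.
W_p = nn.Embedding(T, D)
X_tilde_learned = X + W_p(torch.arange(T))
print(X_tilde.shape, X_tilde_learned.shape)
```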
Dropout And Elementwise Operation
At this stage, it is common practice to apply a dropout layer \(\operatorname{Dropout}(\cdot)\) to the position-aware input representation \(\tilde{\mathbf{X}}\) (or \(\tilde{\mathbf{X}}_{\text{batch}}\) in the case of a batch). Dropout is a regularization technique that randomly sets a fraction of the elements in the input tensor to zero during training and is an element-wise operation that acts independently on each element in the tensor. This means that each element has a fixed probability (usually denoted as \(p\)) of being set to zero, regardless of its position or the values of other elements in the tensor.
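Mathematically, for an input tensor \(\mathbf{X} \in \mathbb{R}^{T \times D}\), elementwise dropout can be expressed as:
\[\operatorname{Dropout}(\mathbf{X}) = \mathbf{X} \odot \mathbf{M}\]
where \(\odot\) denotes the elementwise (Hadamard) product, and \(\mathbf{M} \in \{0, 1\}^{T \times D}\) is a binary mask tensor of the same shape as \(\mathbf{X}\). Each element in \(\mathbf{M}\) is independently sampled from a Bernoulli distribution with probability \(p\) of being 0 (i.e., dropped) and probability \(1-p\) of being 1 (i.e., retained). In practice, PyTorch implements inverted dropout: during training, the retained elements are additionally scaled by \(\frac{1}{1-p}\) so that the expected value of each element is unchanged, and at evaluation time dropout is the identity.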
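A minimal from-scratch sketch of element-wise (inverted) dropout, sampling \(\mathbf{M}\) independently for every element. The function name is illustrative; in a real model one would use `torch.nn.Dropout`, which behaves this way in training mode.

```python
import torch


def dropout(x: torch.Tensor, p: float, training: bool = True) -> torch.Tensor:
    """Element-wise inverted dropout: keep each entry with prob. 1 - p, rescale by 1/(1 - p)."""
    if not training or p == 0.0:
        return x                      # identity at evaluation time
    if p >= 1.0:
        return torch.zeros_like(x)    # everything dropped
    # M ~ Bernoulli(1 - p), sampled independently for each element of x.
    mask = torch.bernoulli(torch.full_like(x, 1.0 - p))
    return x * mask / (1.0 - p)


torch.manual_seed(0)
x = torch.ones(4, 5)      # [T, D]
y = dropout(x, p=0.5)
print(y)                   # entries are 0 (dropped) or 2.0 (kept and rescaled)
```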
Backbone Architecture#
Step 6. Pre-Layer Normalization For Masked Multi-Head Attention
Before passing the input through the Multi-Head Attention (MHA) layer, we apply Layer Normalization to the positionally encoded embeddings \(\tilde{\mathbf{X}}\). This is known as pre-layer Normalization in the more modern GPT architecture (as opposed to post-layer Normalization, which is applied after the MHA layer).
The Layer Normalization function \(\operatorname{LayerNorm}(\cdot)\) is a vectorwise operation that operates on the feature dimension \(D\) of the input tensor. It normalizes the activations to have zero mean and unit variance across the features for each token independently. The vectorwise nature of Layer Normalization arises from the fact that it computes the mean and standard deviation along the feature dimension, requiring aggregation of information across the entire feature vector for each token.
Mathematically, for an input tensor \(\mathbf{X} \in \mathbb{R}^{T \times D}\), Layer Normalization is applied independently to each row \(\mathbf{x}_t \in \mathbb{R}^{1 \times D}\), where \(t \in \{1, 2, \ldots, T\}\). The normalization is performed using the following formula:
\[\operatorname{LayerNorm}(\mathbf{x}_t) = \gamma \odot \frac{\mathbf{x}_t - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}} + \beta\]
where \(\mu_t \in \mathbb{R}\) and \(\sigma_t^2 \in \mathbb{R}\) are the mean and variance of the features in \(\mathbf{x}_t\) (broadcasted), respectively, \(\epsilon\) is a small constant for numerical stability, \(\gamma \in \mathbb{R}^D\) and \(\beta \in \mathbb{R}^D\) are learnable affine parameters (scale and shift), and \(\odot\) denotes the elementwise product.
Applying Layer Normalization to the input of layer \(\ell\) results in the normalized embeddings \(\mathbf{Z}^{(\ell)}_1\), where the index \(1\) refers to the first sub-layer/sub-step in the decoder block.
For the first layer (\(\ell = 1\)), the input is \(\tilde{\mathbf{X}}\), the output of Step 5 (Positional Embedding). So we have:
\[\mathbf{Z}^{(1)}_1 = \operatorname{LayerNorm}\left(\tilde{\mathbf{X}}\right)\]
In code we have:
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    mean = x.mean(dim=-1, keepdim=True)  # [1]
    std = x.std(dim=-1, keepdim=True, unbiased=False)  # [2]
    if self.elementwise_affine:
        return self.gamma * (x - mean) / (std + self.eps) + self.beta  # [3]
    return (x - mean) / (std + self.eps)  # [4]
```
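| Line | Code | Operation Description | Input Shape | Output Shape | Notes |
|---|---|---|---|---|---|
| [1] | `mean = x.mean(dim=-1, keepdim=True)` | Computes the mean of `x` along the last (feature) dimension. | \(\mathcal{B} \times T \times D\) | \(\mathcal{B} \times T \times 1\) | `keepdim=True` keeps the reduced dimension so that `mean` broadcasts against `x`. |
| [2] | `std = x.std(dim=-1, keepdim=True, unbiased=False)` | Computes the standard deviation along the last dimension. | \(\mathcal{B} \times T \times D\) | \(\mathcal{B} \times T \times 1\) | Similar to the mean, `keepdim=True` preserves the dimension for broadcasting; `unbiased=False` divides by \(D\) rather than \(D-1\). |
| [3] | `return self.gamma * (x - mean) / (std + self.eps) + self.beta` | Applies the normalization formula with learnable parameters gamma (\(\gamma\)) and beta (\(\beta\)). | \(\mathcal{B} \times T \times D\) | \(\mathcal{B} \times T \times D\) | Element-wise operations are used. \(\gamma\) and \(\beta\) have shape \(D\) and are broadcast to match the input shape. This line only executes if `self.elementwise_affine` is `True`. |
| [4] | `return (x - mean) / (std + self.eps)` | Applies the normalization formula without learnable parameters. | \(\mathcal{B} \times T \times D\) | \(\mathcal{B} \times T \times D\) | Each feature vector is normalized by its own mean and standard deviation; this line executes when `self.elementwise_affine` is `False`. Here \(\epsilon\) is added to the standard deviation, whereas the standard formulation (and `torch.nn.LayerNorm`) adds it to the variance inside the square root; the difference is negligible for small \(\epsilon\). |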
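To tie the formula to a reference implementation, the sketch below implements Layer Normalization with \(\epsilon\) inside the square root (the form written above) and compares it with `torch.nn.LayerNorm`, whose \(\gamma\) and \(\beta\) are exposed as `weight` and `bias`. The helper name and sizes are illustrative.

```python
import torch
from torch import nn


def layer_norm(x: torch.Tensor, gamma: torch.Tensor, beta: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """LayerNorm over the last (feature) dimension, with eps inside the square root."""
    mu = x.mean(dim=-1, keepdim=True)                  # [B, T, 1]
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # [B, T, 1], biased variance
    return gamma * (x - mu) / torch.sqrt(var + eps) + beta


torch.manual_seed(0)
B, T, D = 2, 3, 8
x = torch.randn(B, T, D)
ln = nn.LayerNorm(D, eps=1e-5)  # weight (gamma) = 1 and bias (beta) = 0 at initialization
manual = layer_norm(x, ln.weight, ln.bias, eps=1e-5)
print(torch.allclose(manual, ln(x), atol=1e-5))

# Each token vector is normalized on its own: roughly zero mean and unit variance.
print(manual.mean(dim=-1).abs().max().item(), manual.var(dim=-1, unbiased=False).mean().item())
```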
Step 7. Masked Multi-Head Self-Attention
Given the normalized input embeddings \(\mathbf{Z}^{(\ell)}_1 \in \mathbb{R}^{\mathcal{B} \times T \times D}\) from Step 6 (Pre-Layer Normalization), we apply the masked multi-head self-attention mechanism to compute the output embeddings \(\mathbf{Z}^{(\ell)}_2\), where the index \(2\) denotes the second sub-layer within the \(\ell\)-th decoder layer (multi-head attention).
Let \(\operatorname{MaskedMultiHead}^{(\ell)}(\cdot)\) denote the masked multi-head self-attention function at layer \(\ell\). It uses the normalized input embeddings \(\mathbf{Z}^{(\ell)}_1\) as the common source of the queries, keys, and values (hence self-attention), each obtained from \(\mathbf{Z}^{(\ell)}_1\) by its own linear projection, and produces the output embeddings \(\mathbf{Z}^{(\ell)}_2\).
For the first layer (\(\ell = 1\)), the masked multi-head self-attention operation can be expressed as:
\[\mathbf{Z}^{(1)}_2 = \operatorname{MaskedMultiHead}^{(1)}\left(\mathbf{Z}^{(1)}_1\right)\]
Here, \(\mathbf{Z}^{(1)}_2 \in \mathbb{R}^{\mathcal{B} \times T \times D}\) represents the output embeddings of the masked multi-head self-attention operation at layer \(1\), and \(\mathbf{Z}^{(1)}_1 \in \mathbb{R}^{\mathcal{B} \times T \times D}\) represents the normalized input embeddings from Step 6.
The \(\operatorname{MaskedMultiHead}^{(\ell)}(\cdot)\) function internally performs the following steps:
Linearly projects the input embeddings \(\mathbf{Z}^{(\ell)}_1\) into query, key, and value matrices for each attention head.
Computes the scaled dot-product attention scores between the query and key matrices, and applies the attention mask to prevent attending to future tokens.
Applies the softmax function to the masked attention scores to obtain the attention weights.
Multiplies the attention weights with the value matrices to produce the output embeddings for each attention head.
Concatenates the output embeddings from all attention heads and linearly projects them to obtain the final output embeddings \(\mathbf{Z}^{(\ell)}_2\).
The specifics of the scaled dot-product attention mechanism and the multi-head attention operation will be discussed in the next few steps.
Step 7.1. Linear Projections, Query, Key, and Value Matrices
In the masked multi-head self-attention mechanism, the first step is to linearly project the normalized input embeddings \(\mathbf{Z}^{(\ell)}_1\) into query, key, and value matrices for each attention head. This step is performed using learnable weight matrices \(\mathbf{W}^{Q, (\ell)}\), \(\mathbf{W}^{K, (\ell)}\), and \(\mathbf{W}^{V, (\ell)}\).
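Mathematically, the linear projections can be expressed as:
\[\mathbf{Q}^{(\ell)} = \mathbf{Z}^{(\ell)}_1 \mathbf{W}^{Q, (\ell)}, \qquad \mathbf{K}^{(\ell)} = \mathbf{Z}^{(\ell)}_1 \mathbf{W}^{K, (\ell)}, \qquad \mathbf{V}^{(\ell)} = \mathbf{Z}^{(\ell)}_1 \mathbf{W}^{V, (\ell)}\]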
where:
\(\mathbf{Q}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the query matrix for the \(\ell\)-th decoder layer.
\(\mathbf{K}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the key matrix for the \(\ell\)-th decoder layer.
\(\mathbf{V}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the value matrix for the \(\ell\)-th decoder layer.
\(\mathbf{W}^{Q, (\ell)} \in \mathbb{R}^{D \times D}\), \(\mathbf{W}^{K, (\ell)} \in \mathbb{R}^{D \times D}\), and \(\mathbf{W}^{V, (\ell)} \in \mathbb{R}^{D \times D}\) are the learnable weight matrices that transform the normalized embeddings into queries, keys, and values, respectively.
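Notice that a single \(D \times D\) matrix per projection serves all heads at once. This is not parameter sharing across heads: each \(D \times D\) matrix is equivalent to \(H\) per-head matrices of shape \(D \times (D/H)\) concatenated column-wise, and each head uses its own slice of columns once the features are split into heads in Step 7.2.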
The linear projections are performed using matrix multiplication between the normalized input embeddings \(\mathbf{Z}^{(\ell)}_1\) and the corresponding weight matrices. The resulting query, key, and value matrices have the same shape as the input embeddings: \(\mathcal{B} \times T \times D\).
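In code, the linear projections are implemented using the `torch.nn.Linear` modules `self.W_Q`, `self.W_K`, and `self.W_V`. Note that `nn.Linear` stores its weight as an `[out_features, in_features]` matrix and computes `z @ weight.T + bias` (the bias can be disabled with `bias=False`), so the shape comments below describe the mathematical projection rather than the stored layout: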
```python
Q: torch.Tensor = self.W_Q(z).contiguous()  # Z @ W_Q = [B, T, D] @ [D, D] = [B, T, D]
K: torch.Tensor = self.W_K(z).contiguous()  # Z @ W_K = [B, T, D] @ [D, D] = [B, T, D]
V: torch.Tensor = self.W_V(z).contiguous()  # Z @ W_V = [B, T, D] @ [D, D] = [B, T, D]
```
Step 7.2. Reshaping and Transposing Query, Key, and Value Matrices
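Instead of using a for loop to compute each head separately, we can compute all heads in parallel using batched tensor operations. The query, key, and value matrices are split into \(H\) heads, and the attention scores are computed for all heads at once. Our aim is therefore simple: reshape the query, key, and value matrices to include a head dimension, splitting the \(D\) dimension into \(H\) heads of size \(D/H\) each (which requires \(D\) to be divisible by \(H\)). We can denote the reshaping and transposition operation using tensor index notation, which makes explicit how indices are permuted and combined:
\[\mathbf{Q}^{(\ell)}_{b, t, d} \;\xrightarrow{\text{reshape}}\; \mathbf{Q}^{(\ell)}_{b, t, h, e} \;\xrightarrow{\text{transpose}}\; \mathbf{Q}^{(\ell)}_{b, h, t, e}, \qquad d = h \cdot \tfrac{D}{H} + e, \quad 0 \leq h < H, \quad 0 \leq e < \tfrac{D}{H}\]
and likewise for \(\mathbf{K}^{(\ell)}\) and \(\mathbf{V}^{(\ell)}\).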
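To this end, we have reshaped and transposed the query, key, and value matrices as follows:
\[\mathbf{Q}^{(\ell)}, \mathbf{K}^{(\ell)}, \mathbf{V}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times H \times T \times D/H}\]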
In code, the reshaping and transposition operations are performed as follows:
```python
Q = Q.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2)  # [B, T, D] -> [B, T, H, D // H] -> [B, H, T, D//H]
K = K.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2)
V = V.view(B, T, self.H, D // self.H).transpose(dim0=1, dim1=2)
```
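The `view` operation reshapes the matrices to include the head dimension, and the `transpose` operation swaps the sequence and head dimensions to obtain the desired ordering of dimensions. Since `transpose` returns a non-contiguous view of the same memory, a `.contiguous()` call is needed before the heads can be merged back with `view` in Step 7.4.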
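The index relation \(d = h \cdot \frac{D}{H} + e\) can be verified directly: after `view` and `transpose`, head \(h\) holds exactly the contiguous feature slice \([h \cdot D/H, (h+1) \cdot D/H)\) of the original tensor. A small sketch with toy sizes:

```python
import torch

torch.manual_seed(0)
B, T, D, H = 2, 5, 12, 3
d_head = D // H                                     # requires D divisible by H
Q = torch.randn(B, T, D)                            # projected queries, [B, T, D]

Q_heads = Q.view(B, T, H, d_head).transpose(1, 2)   # [B, H, T, D//H]
print(Q_heads.shape, Q_heads.is_contiguous())       # transpose gives a non-contiguous view

# Head h owns the feature slice [h * d_head, (h + 1) * d_head) of D.
for h in range(H):
    assert torch.equal(Q_heads[:, h], Q[:, :, h * d_head:(h + 1) * d_head])

# Undoing the split (transpose back, make contiguous, merge) recovers Q exactly.
Q_merged = Q_heads.transpose(1, 2).contiguous().view(B, T, D)
print(torch.equal(Q_merged, Q))
```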
Step 7.3. Scaled Dot-Product Attention and Masking
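The attention mask matrix \(\mathbf{M} \in \{0, 1\}^{T \times T}\) is initially constructed as a lower triangular matrix of ones and zeros:
\[\mathbf{M} = \begin{bmatrix} 1 & 0 & \cdots & 0 \\ 1 & 1 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 1 & 1 & \cdots & 1 \end{bmatrix}, \qquad \mathbf{M}_{ij} = \begin{cases} 1 & \text{if } j \leq i \\ 0 & \text{if } j > i \end{cases}\]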
In this matrix, “1” (conceptually) allows attention, and “0” blocks it. This mask is broadcastable to the attention scores tensor \(\mathcal{B} \times H \times T \times T\) since this is a typical shape for multi-head self-attention.
However, before applying the mask, the ones in \(\mathbf{M}\) are typically replaced with zeros, and the zeros are replaced with a large negative value (e.g., \(-\infty\)), so that the attention weights of the masked positions become zero after the softmax operation (this mimics the `masked_fill` operation in the code). We define \(\mathbf{M}^{-\infty} \in \{0, -\infty\}^{T \times T}\) elementwise as:
\[\mathbf{M}^{-\infty}_{ij} = \begin{cases} 0 & \text{if } \mathbf{M}_{ij} = 1 \\ -\infty & \text{if } \mathbf{M}_{ij} = 0 \end{cases}\]
Notation Abuse
For the ease of notation, we abuse notation by using \(\mathbf{M}\) to denote the final mask matrix \(\mathbf{M}^{-\infty}\).
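The masking operation will then be applied elementwise to the attention scores tensor \(\mathbf{A}_{s}^{(\ell)}\), which is first obtained by computing the scaled dot product between the query matrix \(\mathbf{Q}^{(\ell)}\) and the key matrix \(\mathbf{K}^{(\ell)}\) (transposed over its last two dimensions):
\[\mathbf{A}_{s}^{(\ell)} = \frac{\mathbf{Q}^{(\ell)} \left(\mathbf{K}^{(\ell)}\right)^{\top}}{\sqrt{D/H}} \in \mathbb{R}^{\mathcal{B} \times H \times T \times T}\]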
The masking operation is performed using the elementwise sum (\(\oplus\)) between the attention scores tensor \(\mathbf{A}_{s}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times H \times T \times T}\) and the broadcasted mask matrix \(\mathbf{M} \in \{0, -\infty\}^{\mathcal{B} \times H \times T \times T}\):
\[\mathbf{A}_{s}^{M, (\ell)} = \mathbf{A}_{s}^{(\ell)} \oplus \mathbf{M}\]
The elementwise sum adds \(-\infty\) to the attention scores of future tokens (positions above the diagonal), effectively blocking attention to them, and adds \(0\) to the scores of all positions that may be attended to.
Alternative Masking
I have some quirks with the above because it “feels weird” to use \(0\) as “allowed” instead of \(1\), even though the formulation above fits the notation neatly. An alternative is to keep the original mask \(\mathbf{M} \in \{0, 1\}^{T \times T}\), with \(1\) as “allowed”, together with \(\mathbf{M}^{-\infty}\), which is \(0\) where \(\mathbf{M}\) is \(1\) and \(-\infty\) where \(\mathbf{M}\) is \(0\), and compute:
\[\mathbf{A}_{s}^{M, (\ell)} = \left(\mathbf{A}_{s}^{(\ell)} \odot \mathbf{M}\right) + \mathbf{M}^{-\infty}\]
The elementwise multiplication \(\mathbf{A}_{s}^{(\ell)} \odot \mathbf{M}\) preserves the attention scores for the allowed positions (where \(\mathbf{M}\) is 1) and sets the scores to zero for the masked positions (where \(\mathbf{M}\) is 0). Then, the elementwise addition with \(\mathbf{M}^{-\infty}\) pushes the scores of the masked positions to negative infinity, while leaving the scores of the allowed positions unchanged.
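The masking variants can be compared numerically on a single \(T \times T\) score matrix (one batch element and one head, random scores for illustration). The sketch builds \(\mathbf{M}\) with `torch.tril`, derives \(\mathbf{M}^{-\infty}\) with `masked_fill`, and checks that the additive form, the multiply-then-add form, and the direct `masked_fill` used in the code all produce the same attention weights after the softmax.

```python
import torch

torch.manual_seed(0)
T = 4
scores = torch.randn(T, T)  # A_s for one batch element and one head

M = torch.tril(torch.ones(T, T))                                    # 1 = allowed, 0 = blocked
M_neg_inf = torch.zeros(T, T).masked_fill(M == 0, float("-inf"))   # 0 / -inf version

masked_add = scores + M_neg_inf                                   # A_s + M^{-inf}
masked_mul = scores * M + M_neg_inf                               # (A_s * M) + M^{-inf}
masked_fill = scores.masked_fill(M == 0, float("-inf"))           # as in the code

weights = masked_add.softmax(dim=-1)
print(weights)  # upper triangle is exactly zero; each row sums to 1
print(torch.allclose(weights, masked_mul.softmax(dim=-1)),
      torch.allclose(weights, masked_fill.softmax(dim=-1)))
```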
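Finally, the masked attention scores \(\mathbf{A}_{s}^{M, (\ell)}\) are passed through the softmax function to obtain the attention weights \(\mathbf{A}_{w}^{(\ell)}\):
\[\left(\mathbf{A}_{w}^{(\ell)}\right)_{b, h, i, j} = \frac{\exp\left(\left(\mathbf{A}_{s}^{M, (\ell)}\right)_{b, h, i, j}\right)}{\sum_{k=1}^{T} \exp\left(\left(\mathbf{A}_{s}^{M, (\ell)}\right)_{b, h, i, k}\right)}\]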
The softmax is applied to the masked attention scores in a vectorwise manner: independently along the last dimension, i.e. to each row of the \(T \times T\) score matrix of every batch element and head. As a result, the attention weights of each query position sum to 1 over the key positions. Since \(\exp(-\infty) = 0\), the masked positions receive attention weights of exactly zero (or close to zero if a large finite negative value is used instead of \(-\infty\)).
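Finally, the context matrix \(\mathbf{C}^{(\ell)}\) is obtained by multiplying the attention weights \(\mathbf{A}_{w}^{(\ell)}\) with the value matrix \(\mathbf{V}^{(\ell)}\):
\[\mathbf{C}^{(\ell)} = \mathbf{A}_{w}^{(\ell)} \mathbf{V}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times H \times T \times D/H}\]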
The resulting context matrix \(\mathbf{C}^{(\ell)}\) contains the attended values for each head in layer \(\ell\).
Note \(\mathbf{C}^{(\ell)}\) is the context matrix which is the output of the self-attention mechanism and it contains \(\operatorname{head}_{\ell, h}^{M}\) for each head \(h\) in the layer \(\ell\).
```python
d_q = query.size(dim=-1)
attention_scores = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / torch.sqrt(torch.tensor(d_q).float())  # [B, H, T, d_q] @ [B, H, d_q, T] = [B, H, T, T]
attention_scores = attention_scores.masked_fill(mask == 0, float("-inf")) if mask is not None else attention_scores  # [B, H, T, T]
attention_weights = attention_scores.softmax(dim=-1)  # [B, H, T, T]
attention_weights = self.dropout(attention_weights)  # [B, H, T, T]
context_vector = torch.matmul(attention_weights, value)  # [B, H, T, T] @ [B, H, T, d_v] = [B, H, T, d_v]
```
| Line | Code | Operation Description | Input Shape | Output Shape | Notes |
|---|---|---|---|---|---|
| [1] | `d_q = query.size(dim=-1)` | Retrieves the dimension of the query vectors. | \(\mathcal{B} \times H \times T \times d_q\) | Scalar value | The last dimension of the query tensor is the per-head query dimension (\(D/H\)). |
| [2] | `attention_scores = torch.matmul(query, key.transpose(dim0=-2, dim1=-1)) / torch.sqrt(torch.tensor(d_q).float())` | Computes the scaled dot-product attention scores by matrix multiplying the query and key matrices and scaling by the square root of the query dimension. | Query: \(\mathcal{B} \times H \times T \times d_q\); Key: \(\mathcal{B} \times H \times T \times d_q\) | \(\mathcal{B} \times H \times T \times T\) | The key matrix is transposed over its last two dimensions to align the shapes for matrix multiplication. Dividing by \(\sqrt{d_q}\) keeps the scores from growing with \(d_q\), which prevents the softmax from saturating and helps stabilize gradients during training. |
| [3] | `attention_scores = attention_scores.masked_fill(mask == 0, float("-inf")) if mask is not None else attention_scores` | Applies the attention mask to the attention scores. Positions where the mask is 0 are filled with \(-\infty\) to block attention to those positions. | \(\mathcal{B} \times H \times T \times T\) | \(\mathcal{B} \times H \times T \times T\) | The mask (1 = allowed, 0 = blocked) broadcasts to \(\mathcal{B} \times H \times T \times T\); if `mask` is `None`, the scores are left unchanged. |
| [4] | `attention_weights = attention_scores.softmax(dim=-1)` | Applies the softmax function to the masked attention scores along the last dimension to obtain the attention weights. | \(\mathcal{B} \times H \times T \times T\) | \(\mathcal{B} \times H \times T \times T\) | The softmax is applied vectorwise along the last dimension, independently for each query position of every batch element and head, so each row of attention weights sums to 1. |
| [5] | `attention_weights = self.dropout(attention_weights)` | Applies dropout regularization to the attention weights to reduce overfitting. | \(\mathcal{B} \times H \times T \times T\) | \(\mathcal{B} \times H \times T \times T\) | Dropout randomly sets a fraction of the attention weights to zero during training (rescaling the rest), which helps improve generalization. This is an element-wise operation. |
| [6] | `context_vector = torch.matmul(attention_weights, value)` | Computes the context vector by matrix multiplying the attention weights with the value matrix. | Attention Weights: \(\mathcal{B} \times H \times T \times T\); Value: \(\mathcal{B} \times H \times T \times d_v\) | \(\mathcal{B} \times H \times T \times d_v\) | The attention weights determine how much each token’s value vector contributes. The resulting context vector captures the attended information from the input sequence. |
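A self-contained version of the six lines above, written as a free function so it can be run in isolation. This is a sketch, not the original module: dropout is omitted (equivalent to evaluation mode), `math.sqrt` replaces the tensor-based square root, and the function returns the weights as well for inspection. The checks illustrate two consequences of the causal mask: the first position can only attend to itself, and changing the last token's key and value leaves all earlier outputs unchanged.

```python
import math
from typing import Optional, Tuple

import torch


def scaled_dot_product_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Returns (context, weights); mask uses 1 = allowed, 0 = blocked, broadcastable to [B, H, T, T]."""
    d_q = query.size(-1)
    scores = query @ key.transpose(-2, -1) / math.sqrt(d_q)  # [B, H, T, T]
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = scores.softmax(dim=-1)                          # each row sums to 1
    return weights @ value, weights                           # [B, H, T, d_v]


torch.manual_seed(0)
B, H, T, d = 2, 3, 5, 4
q, k, v = (torch.randn(B, H, T, d) for _ in range(3))
causal = torch.tril(torch.ones(T, T)).view(1, 1, T, T)

context, weights = scaled_dot_product_attention(q, k, v, causal)
print(context.shape)

# Causality: perturbing the last token's key/value leaves earlier outputs unchanged.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
context2, _ = scaled_dot_product_attention(q, k2, v2, causal)
print(torch.allclose(context[:, :, :-1], context2[:, :, :-1]))
```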
Step 7.4. Concatenation and Projection
Recall that the output from the masked multi-head self-attention operation, denoted as \(\mathbf{C}^{(\ell)}\), has a shape of \(\mathcal{B} \times H \times T \times D//H\), where \(\mathcal{B}\) is the batch size, \(H\) is the number of attention heads, \(T\) is the sequence length, and \(D//H\) is the dimension of each head.
To concatenate the heads and obtain a tensor of shape \(\mathcal{B} \times T \times D\), we first need to transpose the dimensions of \(\mathbf{C}^{(\ell)}\) from \(\mathcal{B} \times H \times T \times D//H\) to \(\mathcal{B} \times T \times H \times D//H\) - necessary to concatenate the heads along the last dimension (feature dimension).
Using tensor index notation (or a semi-einsum notation), we can denote the transposition operation as follows:
\[\mathbf{C}^{(\ell)}_{b, h, t, e} \;\rightarrow\; \mathbf{C}^{(\ell)}_{b, t, h, e}\]
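After transposition, the heads are concatenated along the last dimension (feature dimension) to obtain a tensor of shape \(\mathcal{B} \times T \times D\). The concatenation operation can be expressed using the direct sum notation as follows:
\[\operatorname{Concat}\left(\mathbf{C}^{(\ell)}\right)_{b, t, :} = \bigoplus_{h=0}^{H-1} \mathbf{C}^{(\ell)}_{b, t, h, :} \in \mathbb{R}^{D}\]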
where \(\mathbf{C}^{(\ell)}_{b,t,h,:}\) represents the tensor slice corresponding to the \(h\)-th head at batch index \(b\) and time step \(t\), and \(\oplus\) denotes the concatenation operation along the feature dimension.
Direct Sum Notation Is Concatenation
In the concatenation of attention heads, the direct sum notation (\(\oplus\)) represents the concatenation operation, not elementwise addition (unlike the masking step in Step 7.3, where \(\oplus\) denoted an elementwise sum). The direct sum joins the output tensors from each attention head along a specific dimension (usually the feature dimension). It is a vectorwise operation, and the size of the joined dimension grows accordingly, here from \(D/H\) per head to \(D\).
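To summarize, the transposition and concatenation operations can be represented as:
\[\mathbf{C}^{(\ell)} \in \mathbb{R}^{\mathcal{B} \times H \times T \times D/H} \;\xrightarrow{\text{transpose}}\; \mathbb{R}^{\mathcal{B} \times T \times H \times D/H} \;\xrightarrow{\text{concatenate heads}}\; \mathbb{R}^{\mathcal{B} \times T \times D}\]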
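Finally, the concatenated tensor, of shape \(\mathcal{B} \times T \times D\), is linearly transformed using the projection matrix \(\mathbf{W}^{O, (\ell)} \in \mathbb{R}^{D \times D}\) to obtain the output \(\mathbf{Z}^{(\ell)}_2 \in \mathbb{R}^{\mathcal{B} \times T \times D}\):
\[\mathbf{Z}^{(\ell)}_2 = \operatorname{Concat}\left(\mathbf{C}^{(\ell)}\right) \mathbf{W}^{O, (\ell)}\]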
In code, these operations can be implemented as follows:
```python
self.context_vector = self.context_vector.transpose(dim0=1, dim1=2).contiguous().view(B, T, D)  # merge all heads together
projected_context_vector: torch.Tensor = self.resid_dropout(
    self.context_projection(self.context_vector)  # [B, T, D] @ [D, D] = [B, T, D]
)
```
To this end, for layer \(\ell=1\), we have the output tensor \(\mathbf{Z}^{(1)}_2\), which becomes the input to the next sub-layer within the same block \(\ell\). Optionally, a projection dropout can be applied to the output tensor \(\mathbf{Z}^{(\ell)}_2\) before it is passed to the next sub-layer; in the code above, this is the `resid_dropout` applied to the projected context vector.
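Putting Steps 7.1 to 7.4 together, a compact masked multi-head self-attention module might look as follows. This is a hypothetical sketch, not the original implementation: dropout is omitted, the mask is stored as a buffer sized for a maximum context length `max_T` (an assumed parameter), and only `context_projection` reuses a name from the snippets above. The final check confirms the causal property end to end: perturbing the last position does not change any earlier output.

```python
import math

import torch
from torch import nn


class CausalSelfAttention(nn.Module):
    """Hypothetical minimal masked multi-head self-attention (no dropout), following Steps 7.1-7.4."""

    def __init__(self, D: int, H: int, max_T: int) -> None:
        super().__init__()
        assert D % H == 0, "D must be divisible by H"
        self.H = H
        self.W_Q = nn.Linear(D, D)
        self.W_K = nn.Linear(D, D)
        self.W_V = nn.Linear(D, D)
        self.context_projection = nn.Linear(D, D)  # plays the role of W^O
        # Lower-triangular mask (1 = allowed), stored as a non-trainable buffer.
        self.register_buffer("mask", torch.tril(torch.ones(max_T, max_T)).view(1, 1, max_T, max_T))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        B, T, D = z.size()
        d_head = D // self.H
        # Steps 7.1 + 7.2: project, then split D into H heads of size D // H.
        Q = self.W_Q(z).view(B, T, self.H, d_head).transpose(1, 2)  # [B, H, T, D//H]
        K = self.W_K(z).view(B, T, self.H, d_head).transpose(1, 2)
        V = self.W_V(z).view(B, T, self.H, d_head).transpose(1, 2)
        # Step 7.3: masked scaled dot-product attention.
        scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)        # [B, H, T, T]
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        context = scores.softmax(dim=-1) @ V                        # [B, H, T, D//H]
        # Step 7.4: merge heads and apply the output projection.
        context = context.transpose(1, 2).contiguous().view(B, T, D)  # [B, T, D]
        return self.context_projection(context)


torch.manual_seed(0)
attn = CausalSelfAttention(D=16, H=4, max_T=8)
z = torch.randn(2, 6, 16)
out = attn(z)
print(out.shape)

# Perturbing the last position must not change the outputs at earlier positions.
z2 = z.clone()
z2[:, -1] += 1.0
print(torch.allclose(attn(z2)[:, :-1], out[:, :-1], atol=1e-6))
```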
Step 8. Residual Connection
In the Transformer/GPT architecture, residual connections are used to facilitate the flow of information and gradients throughout the network. The residual connection in the decoder block is added between the input to the block and the output of the Masked Multi-Head Attention layer.
For the first decoder block (\(\ell = 1\)), the residual connection is added between the positionally encoded embeddings \(\tilde{\mathbf{X}}\) and the output of the Masked Multi-Head Attention layer \(\mathbf{Z}^{(1)}_2\):
\[\mathbf{Z}^{(1)}_3 = \tilde{\mathbf{X}} + \mathbf{Z}^{(1)}_2\]
where \(\mathbf{Z}^{(1)}_{3}\) represents the output of the residual connection for the first decoder block, \(\tilde{\mathbf{X}}\) is the positionally encoded embeddings, and \(\mathbf{Z}^{(1)}_{2}\) is the output of the Masked Multi-Head Attention layer.
For subsequent decoder blocks (\(\ell > 1\)), the residual connection is added between the output of the previous decoder block \(\mathbf{Z}^{(\ell-1)}_{\text{out}}\) and the output of the Masked Multi-Head Attention layer \(\mathbf{Z}^{(\ell)}_2\) of the current block:
\[\mathbf{Z}^{(\ell)}_3 = \mathbf{Z}^{(\ell-1)}_{\text{out}} + \mathbf{Z}^{(\ell)}_2\]
where \(\mathbf{Z}^{(\ell)}_3\) represents the output of the residual connection for the \(\ell\)-th decoder block, \(\mathbf{Z}^{(\ell-1)}_{\text{out}}\) is the output of the previous decoder block after the Position-wise Feed-Forward Network and the second residual connection, and \(\mathbf{Z}^{(\ell)}_2\) is the output of the Masked Multi-Head Attention layer of the current block.
The complete set of equations for the \(\ell\)-th decoder block (\(\ell > 1\)) can be summarized as follows:
where \(\mathbf{Z}^{(\ell)}_1\) represents the output of the Layer Normalization step, \(\mathbf{Z}^{(\ell)}_2\) represents the output of the Masked Multi-Head Attention layer, and \(\mathbf{Z}^{(\ell)}_3\) represents the output of the residual connection.
The residual connection allows the model to learn the identity function more easily, enabling the flow of information and gradients across multiple layers. By adding the input of the block to the output of the Masked Multi-Head Attention layer, the model can choose to either learn new information from the attention mechanism or retain the original input information if it is already sufficient.
In code, the whole sequence of operations so far (pre-layer normalization, masked multi-head attention, and the residual connection) reduces to a single line:
z = z + self.attn(z=self.ln_1(z))
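As a self-contained sketch of that line, the hypothetical module below uses PyTorch's nn.MultiheadAttention with a causal boolean mask as a stand-in for the article's own attention module. It also illustrates the identity argument: if the attention output projection is zeroed, the sub-layer returns its input unchanged.

```python
import torch
import torch.nn as nn


class AttentionSubLayer(nn.Module):
    """Hypothetical pre-LN sub-layer computing z + MHA(LayerNorm(z)).

    nn.MultiheadAttention stands in for the article's own attention module.
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        T = z.size(1)
        # True above the diagonal marks future positions that may not be attended to
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=z.device), diagonal=1)
        z1 = self.ln_1(z)  # Z_1
        z2, _ = self.attn(z1, z1, z1, attn_mask=causal_mask, need_weights=False)  # Z_2
        return z + z2  # Z_3 = block input + Z_2


torch.manual_seed(0)
layer = AttentionSubLayer(d_model=16, n_heads=4)
x = torch.randn(2, 5, 16)
z3 = layer(x)

# With a zero output projection the attention branch contributes nothing,
# so the residual path alone passes the input through unchanged
with torch.no_grad():
    layer.attn.out_proj.weight.zero_()
    layer.attn.out_proj.bias.zero_()
    identity_out = layer(x)
```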
Step 9. Pre-Layer Normalization For Position-wise Feed-Forward Network
After the masked multi-head attention sub-layer and its residual connection, the next step is to apply the position-wise feed-forward network (FFN) to \(\mathbf{Z}^{(\ell)}_3\), the output of that residual connection. Before applying the FFN, we perform pre-layer normalization on its input.
Mathematically, the pre-layer normalization step for the FFN can be expressed as:
\[
\mathbf{Z}^{(\ell)}_4 = \operatorname{LayerNorm}\left(\mathbf{Z}^{(\ell)}_3\right)
\]
where \(\mathbf{Z}^{(\ell)}_4\) represents the normalized input to the FFN at layer \(\ell\), and \(\mathbf{Z}^{(\ell)}_3\) is the output of the residual connection from the previous step.
As discussed in Step 6, the Layer Normalization function \(\operatorname{LayerNorm}(\cdot)\) is a vectorwise operation over the feature dimension \(D\): for each token independently, it normalizes the activations to zero mean and unit variance across the features, before a learnable elementwise scale and shift are applied. Concretely, for every sequence in the batch (a \(T \times D\) slice of \(\mathbf{Z}_3^{(\ell)}\)), the normalization is applied separately to each token vector \(\mathbf{Z}_{3, t}^{(\ell)} \in \mathbb{R}^{D}\), so the statistics are computed over the \(D\) features of that token only.
For the first layer (\(\ell = 1\)), the pre-layer normalization step can be written as:
\[
\mathbf{Z}^{(1)}_4 = \operatorname{LayerNorm}\left(\mathbf{Z}^{(1)}_3\right)
\]
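The per-token behaviour of LayerNorm can be checked directly. The sketch below, with illustrative sizes, compares nn.LayerNorm (whose learnable scale and shift start at 1 and 0) against a manual normalization that takes statistics over the last dimension only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D = 2, 5, 16
z3 = torch.randn(B, T, D) * 3.0 + 1.0  # stand-in for Z_3, deliberately not normalized

ln_2 = nn.LayerNorm(D)  # learnable scale initialised to 1, shift to 0
z4 = ln_2(z3)           # Z_4: [B, T, D]

# Manual version: mean and variance over the feature dimension (dim=-1) only,
# so each of the B * T token vectors is normalized independently
mean = z3.mean(dim=-1, keepdim=True)
var = z3.var(dim=-1, unbiased=False, keepdim=True)  # LayerNorm uses the biased variance
z4_manual = (z3 - mean) / torch.sqrt(var + ln_2.eps)
```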
Step 10. Position-wise Feed-Forward Network
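Given the normalized input \(\mathbf{Z}^{(\ell)}_4\) to the Position-wise Feed-Forward Network (FFN) in layer \(\ell\), the FFN applies two linear transformations with a GELU activation function in between. The operations within the FFN can be mathematically represented as follows:
\[
\begin{aligned}
\mathbf{Z}^{FF, (\ell)}_1 &= \operatorname{GELU}\left(\mathbf{Z}^{(\ell)}_4 \mathbf{W}^{FF, (\ell)}_1 + \mathbf{b}^{FF, (\ell)}_1\right) \\
\mathbf{Z}^{(\ell)}_5 &= \mathbf{Z}^{FF, (\ell)}_1 \mathbf{W}^{FF, (\ell)}_2 + \mathbf{b}^{FF, (\ell)}_2
\end{aligned}
\]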
where:
\(\mathbf{W}^{FF, (\ell)}_1 \in \mathbb{R}^{D \times d_{\text{ff}}}\) and \(\mathbf{b}^{FF, (\ell)}_1 \in \mathbb{R}^{d_{\text{ff}}}\) are the weights and biases of the first linear transformation, respectively.
\(\mathbf{W}^{FF, (\ell)}_2 \in \mathbb{R}^{d_{\text{ff}} \times D}\) and \(\mathbf{b}^{FF, (\ell)}_2 \in \mathbb{R}^{D}\) are the weights and biases of the second linear transformation, respectively.
\(d_{\text{ff}}\) is the dimensionality of the hidden layer in the FFN, which is typically larger than the input dimensionality \(D\) (a common choice, used in the original Transformer and in GPT-2, is \(d_{\text{ff}} = 4D\)).
\(\operatorname{GELU}(\cdot)\) denotes the Gaussian Error Linear Unit activation function.
Note the slight abuse of notation where \(\mathbf{Z}^{FF, (\ell)}_1\) is used to denote the intermediate output of the first linear transformation in the FFN. This should not be confused with the earlier notation \(\mathbf{Z}^{(\ell)}_1\).
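For the first layer (\(\ell = 1\)), the FFN operations can be written as:
\[
\begin{aligned}
\mathbf{Z}^{FF, (1)}_1 &= \operatorname{GELU}\left(\mathbf{Z}^{(1)}_4 \mathbf{W}^{FF, (1)}_1 + \mathbf{b}^{FF, (1)}_1\right) \\
\mathbf{Z}^{(1)}_5 &= \mathbf{Z}^{FF, (1)}_1 \mathbf{W}^{FF, (1)}_2 + \mathbf{b}^{FF, (1)}_2
\end{aligned}
\]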
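A minimal sketch of the FFN, assuming \(d_{\text{ff}} = 4D\) and illustrative sizes. Note that nn.Linear stores its weight as [out_features, in_features] and computes x @ weight.T + bias, so the document's \(\mathbf{W}^{FF}_1 \in \mathbb{R}^{D \times d_{\text{ff}}}\) corresponds to fc_1.weight.T. nn.GELU defaults to the exact (erf-based) GELU; GPT-2 uses the tanh approximation, available as nn.GELU(approximate="tanh").

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D = 2, 5, 16
d_ff = 4 * D  # assumption: the common choice d_ff = 4D

fc_1 = nn.Linear(D, d_ff)  # holds W_1^T (shape [d_ff, D]) and b_1
fc_2 = nn.Linear(d_ff, D)  # holds W_2^T (shape [D, d_ff]) and b_2
gelu = nn.GELU()

z4 = torch.randn(B, T, D)   # stand-in for Z_4
z_ff_1 = gelu(fc_1(z4))     # Z_1^{FF}: [B, T, d_ff]
z5 = fc_2(z_ff_1)           # Z_5: [B, T, D]

# Same computation written with the document's (D x d_ff) and (d_ff x D) matrices
W1, b1 = fc_1.weight.T, fc_1.bias
W2, b2 = fc_2.weight.T, fc_2.bias
z5_manual = gelu(z4 @ W1 + b1) @ W2 + b2
```

Because every operation acts on the last dimension only, each of the \(\mathcal{B} \times T\) token vectors passes through the same FFN independently, which is what "position-wise" means.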
Step 11. Residual Connection
After obtaining the output \(\mathbf{Z}^{(\ell)}_5\) from the Position-wise Feed-Forward Network (FFN) in layer \(\ell\), the final step in the decoder block is to apply a residual connection.
Mathematically, this step can be represented as follows:
\[
\mathbf{Z}^{(\ell)}_{\text{out}} = \mathbf{Z}^{(\ell)}_3 + \mathbf{Z}^{(\ell)}_5
\]
where:
\(\mathbf{Z}^{(\ell)}_3 \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the output of the residual connection from the Masked Multi-Head Attention block (Step 8).
\(\mathbf{Z}^{(\ell)}_5 \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the output from the FFN (Step 10).
\(\mathbf{Z}^{(\ell)}_{\text{out}} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the output of the decoder block at layer \(\ell\) and serves as the input to the next decoder block (\(\ell + 1\)).
The residual connection adds \(\mathbf{Z}^{(\ell)}_3\), the output of the first residual connection (the attention sub-layer's output plus its input), to the output of the FFN, \(\mathbf{Z}^{(\ell)}_5\). The addition is element-wise: corresponding elements of the two tensors, which share the shape \(\mathcal{B} \times T \times D\), are added together.
For the first layer (\(\ell = 1\)), the residual connection step can be written as:
\[
\mathbf{Z}^{(1)}_{\text{out}} = \mathbf{Z}^{(1)}_3 + \mathbf{Z}^{(1)}_5
\]
The output of this step, \(\mathbf{Z}^{(\ell)}_{\text{out}}\), becomes the input to the next decoder block.
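Putting Steps 6 to 11 together, the hypothetical DecoderBlock below is a sketch of one pre-LN block, again using nn.MultiheadAttention as a stand-in for the article's attention module and assuming no dropout. The usage at the bottom checks causality: replacing only the last token must leave the outputs at all earlier positions unchanged.

```python
import torch
import torch.nn as nn


class DecoderBlock(nn.Module):
    """Hypothetical pre-LN GPT-style decoder block (Steps 6 to 11); a sketch only."""

    def __init__(self, d_model: int, n_heads: int, d_ff: int) -> None:
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln_2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        T = z.size(1)
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=z.device), diagonal=1)
        z1 = self.ln_1(z)                                                         # Z_1
        z2, _ = self.attn(z1, z1, z1, attn_mask=causal_mask, need_weights=False)  # Z_2
        z3 = z + z2                                                               # Z_3
        z4 = self.ln_2(z3)                                                        # Z_4
        z5 = self.ffn(z4)                                                         # Z_5
        return z3 + z5                                                            # Z_out


torch.manual_seed(0)
block = DecoderBlock(d_model=16, n_heads=4, d_ff=64).eval()
x = torch.randn(1, 5, 16)
with torch.no_grad():
    out = block(x)
    # Replace only the last token; positions 0..T-2 cannot see it through the causal mask
    x_perturbed = x.clone()
    x_perturbed[:, -1] = torch.randn(16)
    out_perturbed = block(x_perturbed)
```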
Iterative Process Through L Decoder Blocks#
Let \(\mathbf{Z}^{(1)}_{\text{out}}\) be the output of the first decoder block. This output becomes the input to the next decoder block, and the process continues iteratively: each decoder block applies the same sequence of transformations, with its own parameters, to the output of the previous block. The notation \(\mathbf{Z}^{(\ell)}_i\) denotes the \(i\)-th intermediate output of the \(\ell\)-th decoder block.
First Decoder Block (\(\ell = 1\))#
Subsequent Decoder Blocks (\(\ell > 1\))#
For each subsequent decoder block \(\ell\), the output of the previous block’s final step \(\mathbf{Z}^{(\ell-1)}_{\text{out}}\) serves as the input for the current block’s operations.
After processing the input through a total of \(L\) decoder blocks, the final output is denoted as \(\mathbf{Z}^{(L)}_{\text{out}}\), which is the output of the last decoder block. The shape of \(\mathbf{Z}^{(L)}_{\text{out}}\) is \(\mathcal{B} \times T \times D\), where \(\mathcal{B}\) is the batch size, \(T\) is the sequence length, and \(D\) is the hidden dimension.
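The iteration over \(L\) blocks is just a loop. As a compact, self-contained sketch, the code below uses PyTorch's nn.TransformerEncoderLayer with norm_first=True as a stand-in for one pre-LN decoder block: despite the "encoder" name, with a causal mask and no cross-attention it has the same structure (masked self-attention, then FFN, each wrapped in a residual connection). The sizes, dropout=0.0, and \(d_{\text{ff}} = 4D\) are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D, H, L = 2, 5, 16, 4, 3  # illustrative sizes; L = number of decoder blocks

# Assumption: a pre-LN TransformerEncoderLayer plus a causal mask stands in for one GPT block
blocks = nn.ModuleList([
    nn.TransformerEncoderLayer(
        d_model=D, nhead=H, dim_feedforward=4 * D, dropout=0.0,
        activation="gelu", batch_first=True, norm_first=True,
    )
    for _ in range(L)
])
causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = blocked

x_tilde = torch.randn(B, T, D)  # stand-in for the positionally encoded embeddings
z = x_tilde
for block in blocks:            # Z_out^{(l)} = block_l(Z_out^{(l-1)}), starting from X~
    z = block(z, src_mask=causal_mask)
z_out_L = z                     # Z_out^{(L)}: [B, T, D]
```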
Layer Normalization Before Projection#
Step 12. Layer Normalization Before Projection
The final output of the last decoder block, \(\mathbf{Z}^{(L)}_{\text{out}}\), undergoes a layer normalization step before being projected to the vocabulary space. This step is commonly referred to as the “pre-projection layer normalization” or “final layer normalization” in the context of the Transformer/GPT architecture.
The pre-projection (pre-head) layer normalization can be represented as follows:
\[
\mathbf{Z}_{\text{pre-head}} = \operatorname{LayerNorm}\left(\mathbf{Z}^{(L)}_{\text{out}}\right)
\]
where:
\(\mathbf{Z}^{(L)}_{\text{out}} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the output of the last decoder block (\(\ell = L\)).
\(\operatorname{LayerNorm}(\cdot)\) denotes the Layer Normalization operation, which normalizes the activations across the feature dimension \(D\) for each token independently.
\(\mathbf{Z}_{\text{pre-head}} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the normalized output, which is ready to be projected to the vocabulary space.
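In this pre-LN design, each block normalizes only the inputs to its sub-layers, while the residual stream itself is never normalized, so nothing else controls the scale of \(\mathbf{Z}^{(L)}_{\text{out}}\). Normalizing it once more before the projection keeps the inputs to the head well scaled and improves training stability.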
Head#
The final step in the GPT architecture is to project the normalized output of the last decoder block, \(\mathbf{Z}_{\text{pre-head}}\), to the vocabulary space. This projection is performed using a linear transformation, where the weights of the projection layer are denoted as \(\mathbf{W}_{s} \in \mathbb{R}^{D \times V}\). The subscript \(s\) indicates that this is the projection layer before the softmax operation, and \(V\) represents the size of the vocabulary.
Mathematically, the projection operation can be expressed as follows:
\[
\mathbf{Z} = \mathbf{Z}_{\text{pre-head}} \mathbf{W}_{s}
\]
where:
\(\mathbf{Z}_{\text{pre-head}} \in \mathbb{R}^{\mathcal{B} \times T \times D}\) is the normalized output from the pre-projection layer normalization step (Step 12).
\(\mathbf{W}_{s} \in \mathbb{R}^{D \times V}\) is the weight matrix of the projection layer, which maps the hidden dimension \(D\) to the vocabulary size \(V\).
\(\mathbf{Z} \in \mathbb{R}^{\mathcal{B} \times T \times V}\) is the resulting logits tensor, representing the unnormalized scores for each token in the vocabulary at each position in the sequence.
The purpose of the projection layer is to map the hidden representations from the decoder to the vocabulary space, allowing the model to generate probability distributions over the vocabulary for each token position. The logits tensor \(\mathbf{Z}\) can be further processed by applying a softmax function to obtain the final probability distribution for token prediction.
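A minimal sketch of the head with illustrative sizes. It uses nn.Linear(D, V, bias=False); omitting the bias is an assumption matching a projection described only by the weight matrix \(\mathbf{W}_s\). PyTorch stores the weight as [V, D], so the document's \(\mathbf{W}_s \in \mathbb{R}^{D \times V}\) is its transpose.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D, V = 2, 5, 16, 50  # V = vocabulary size (illustrative)

z_pre_head = torch.randn(B, T, D)      # stand-in for the final layer-normalized output
lm_head = nn.Linear(D, V, bias=False)  # stores W_s^T with shape [V, D]
logits = lm_head(z_pre_head)           # Z: [B, T, V]

W_s = lm_head.weight.T                 # the document's W_s, shape [D, V]
logits_manual = z_pre_head @ W_s
```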
Softmax Layer#
The softmax function is applied to the logits tensor \(\mathbf{Z} \in \mathbb{R}^{\mathcal{B} \times T \times V}\) to obtain the predicted probability distribution over the vocabulary for each token position in the sequence. The softmax operation is performed vector-wise along the vocabulary dimension (i.e., the last dimension) of \(\mathbf{Z}\) independently for each token position and each instance in the batch.
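The softmax operation can be defined as follows:
\[
\mathbf{P}_{b,t,v} = \frac{\exp\left(\mathbf{Z}_{b,t,v}\right)}{\sum_{v'=1}^{V} \exp\left(\mathbf{Z}_{b,t,v'}\right)}
\]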
where \(\mathbf{P} \in \mathbb{R}^{\mathcal{B} \times T \times V}\) is the resulting probability tensor, and \(\mathbf{P}_{b,t,v}\) represents the predicted probability of token \(v\) at position \(t\) in the sequence for batch instance \(b\). The softmax function ensures that the probabilities sum to 1 along the vocabulary dimension for each token position and each batch instance, i.e., \(\sum_{v=1}^{V} \mathbf{P}_{b,t,v} = 1\) for all \(b \in \{1, \ldots, \mathcal{B}\}\) and \(t \in \{1, \ldots, T\}\).
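import torch
import torch.nn as nn

softmax = nn.Softmax(dim=-1)  # normalize over the vocabulary (last) dimension
logits = torch.randn(2, 3, 4)  # [B=2, T=3, V=4]
probs = softmax(logits)
torch.testing.assert_close(probs.sum(-1), torch.ones(2, 3))  # every (b, t) row sums to 1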
The softmax operation is applied independently to each row of the last two dimensions \((T \times V)\) of the logits tensor \(\mathbf{Z}\), while the batch dimension \(\mathcal{B}\) remains unchanged. This means that for each batch instance \(b\) and each token position \(t\), the softmax function takes the corresponding row vector \(\mathbf{Z}_{b,t,:} \in \mathbb{R}^{1 \times V}\) and computes the predicted probability distribution \(\mathbf{P}_{b,t,:} \in \mathbb{R}^{1 \times V}\) over the vocabulary.
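The formula can also be implemented by hand. The sketch below subtracts the per-row maximum before exponentiating; the shift cancels in the ratio, so the result is unchanged, but it keeps exp() from overflowing for large logits. Sizes are illustrative.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 3, 4) * 10  # [B, T, V], scaled up to make the shift matter

# Subtract each row's maximum over the vocabulary dimension, then normalize
shifted = logits - logits.max(dim=-1, keepdim=True).values
probs_manual = shifted.exp() / shifted.exp().sum(dim=-1, keepdim=True)

probs = torch.softmax(logits, dim=-1)
```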
Cross-Entropy Loss Function#
The cross-entropy loss function \(\mathcal{L}\) measures the dissimilarity between the predicted probability distribution \(\mathbf{P}\) and the true token distribution \(\mathbf{Y}\). The true token distribution is typically written as a one-hot encoded tensor \(\mathbf{Y} \in \{0, 1\}^{\mathcal{B} \times T \times V}\), although in practice implementations such as PyTorch's cross_entropy take the integer index of the true token instead of a one-hot vector.
The one-hot encoded tensor \(\mathbf{Y}\) has the same shape as the predicted probability tensor \(\mathbf{P}\), i.e., \(\mathcal{B} \times T \times V\). For each batch instance \(b\) and each token position \(t\), the corresponding row vector \(\mathbf{Y}_{b,t,:} \in \{0, 1\}^{1 \times V}\) represents the true token distribution, where the element corresponding to the true token is set to 1, and all other elements are set to 0.
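The one-hot encoded tensor \(\mathbf{Y}\) can be defined as follows:
\[
\mathbf{Y}_{b,t,v} =
\begin{cases}
1 & \text{if } v = v^*_{b,t} \\
0 & \text{otherwise}
\end{cases}
\]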
where \(v^*_{b,t}\) denotes the true token at position \(t\) in the sequence for batch instance \(b\).
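The cross-entropy loss function \(\mathcal{L}\) is then defined as:
\[
\mathcal{L} = -\sum_{b=1}^{\mathcal{B}} \sum_{t=1}^{T} \sum_{v=1}^{V} \mathbf{Y}_{b,t,v} \log \mathbf{P}_{b,t,v}
\]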
where \(\mathbf{Y}_{b,t,v}\) is the one-hot encoded true token indicator, and \(\mathbf{P}_{b,t,v}\) is the predicted probability of token \(v\) at position \(t\) in the sequence for batch instance \(b\). Note carefully that the sum is over all batch instances, all token positions, and all tokens in the vocabulary.
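The sketch below, with illustrative sizes, evaluates the summed loss literally with a one-hot \(\mathbf{Y}\) and compares it with torch.nn.functional.cross_entropy, which takes logits and integer class indices after flattening to [B*T, V] and [B*T]. With reduction="sum" the two agree; the default reduction="mean" divides that sum by the number of token positions \(\mathcal{B} \times T\).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 3, 5
logits = torch.randn(B, T, V)          # Z
targets = torch.randint(0, V, (B, T))  # v*_{b,t}: integer token ids, not one-hot

# Literal formula: one-hot Y, probabilities P, summed over b, t and v
Y = F.one_hot(targets, num_classes=V).float()  # [B, T, V]
P = torch.softmax(logits, dim=-1)
loss_sum_manual = -(Y * torch.log(P)).sum()

# PyTorch works on raw logits and class indices
loss_sum = F.cross_entropy(logits.view(B * T, V), targets.view(B * T), reduction="sum")
loss_mean = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))  # default: mean
```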
Table of Notations#
To avoid notational clutter, the table below does not add a superscript \(\mathcal{B}\) to mark batched tensors; all tensors are assumed to be batched unless otherwise stated.
| Name | Symbol | Dimensions | Description |
|---|---|---|---|
| Batched Input Tensor | \(\tilde{\mathbf{X}}^{\mathcal{B}}\) | \(\mathcal{B} \times T \times D\) | A batched tensor containing \(\mathcal{B}\) input sequences, each of shape \(T \times D\). |
| Embedding Matrix of Sequence \(b\) | \(\mathbf{X}^{(b)}\) | \(T \times D\) | The token and positional embedding matrix for the \(b\)-th input sequence of the batch. |
| Embedding Vector of Token \(t\) in Sequence \(b\) | \(\mathbf{X}^{(b)}_t\) | \(1 \times D\) | The token and positional embedding vector for the token at position \(t\) in the \(b\)-th input sequence of the batch. |
| One-Hot Encoded Input Matrix | \(\mathbf{X}^{\text{ohe}}\) | \(\mathcal{B} \times T \times V\) | Each row is a one-hot encoded vector representing a token in the sequence, for a batch of size \(\mathcal{B}\). |
| Token Embedding Matrix Weight | \(\mathbf{W}_e\) | \(V \times D\) | Each row is the embedding vector of the corresponding token in the vocabulary. |
| Token Embedded Input Matrix | \(\mathbf{X}\) | \(\mathcal{B} \times T \times D\) | Each row is the token embedding vector of the corresponding token in the input sequence, for a batch of size \(\mathcal{B}\). |
| Positional Encoding Matrix | \(\operatorname{PE}(\cdot)\) | \(T \times D\) | Matrix with a positional encoding vector for each position in the sequence, computed using sinusoidal functions or learned positional embeddings. |
| Output of Positional Encoding Layer | \(\tilde{\mathbf{X}}\) | \(\mathcal{B} \times T \times D\) | The embeddings obtained by adding positional encodings to the token embedded input matrix \(\mathbf{X}\); each row now carries both token and positional information. |
| First Layer Normalized Input | \(\mathbf{Z}^{(1)}_1\) | \(\mathcal{B} \times T \times D\) | The output of the initial layer normalization applied to \(\tilde{\mathbf{X}}\), serving as the input to the first decoder block’s masked multi-head attention mechanism. |
| Query, Key, Value Matrices (Block \(\ell\)) | \(\mathbf{Q}^{(\ell)}, \mathbf{K}^{(\ell)}, \mathbf{V}^{(\ell)}\) | \(\mathcal{B} \times H \times T \times D//H\) | The query, key, and value matrices obtained by linearly projecting \(\mathbf{Z}^{(\ell)}_1\) using learned weights \(\mathbf{W}^{Q, (\ell)}, \mathbf{W}^{K, (\ell)}, \mathbf{W}^{V, (\ell)}\), and then splitting into \(H\) attention heads. In general, the query and key head dimensions must match for the dot product to be defined, while the value head dimension may differ; here all three use \(D//H\). |
| Attention Scores (Block \(\ell\)) | \(\mathbf{A}^{(\ell)}_s\) | \(\mathcal{B} \times H \times T \times T\) | The scaled dot-product attention scores computed between the query and key matrices, before applying the attention mask. |
| Attention Mask | \(\mathbf{M}\) | \(T \times T \overset{\text{broadcast}}{\rightarrow} \mathcal{B} \times H \times T \times T\) | A causal mask that prevents attending to future tokens. In additive form it is \(0\) on and below the diagonal (allowed positions) and \(-\infty\) above it (future positions). A binary lower-triangular mask with \(1\) for allowed and \(0\) for future positions can be used instead, as long as the code converts it before the softmax. |
| Masked Attention Scores (Block \(\ell\)) | \(\mathbf{A}^{M, (\ell)}_s\) | \(\mathcal{B} \times H \times T \times T\) | The attention scores after applying the attention mask \(\mathbf{M}\), which sets the scores of future tokens to \(-\infty\). |
| Attention Weights (Block \(\ell\)) | \(\mathbf{A}^{(\ell)}_w\) | \(\mathcal{B} \times H \times T \times T\) | The attention weights obtained by applying the softmax function to the masked attention scores \(\mathbf{A}^{M, (\ell)}_s\). |
| Context Matrix (Block \(\ell\)) | \(\mathbf{C}^{(\ell)}\) | \(\mathcal{B} \times H \times T \times D//H\) | The context matrix obtained by multiplying the attention weights \(\mathbf{A}^{(\ell)}_w\) with the value matrix \(\mathbf{V}^{(\ell)}\). |
| Concatenated Context Matrix (Block \(\ell\)) | \(\mathbf{C}^{(\ell)}_{\text{concat}}\) | \(\mathcal{B} \times T \times D\) | The concatenated context matrix obtained by concatenating the context matrices from all attention heads in block \(\ell\). |
| Masked Multi-Head Attention Output (Block \(\ell\)) | \(\mathbf{Z}^{(\ell)}_2\) | \(\mathcal{B} \times T \times D\) | The output of the masked multi-head attention mechanism in block \(\ell\), obtained by linearly projecting the concatenated context matrix \(\mathbf{C}^{(\ell)}_{\text{concat}}\) using learned weights \(\mathbf{W}^{O, (\ell)}\). |
| Output After First Residual Connection (Block \(\ell\)) | \(\mathbf{Z}^{(\ell)}_3\) | \(\mathcal{B} \times T \times D\) | The tensor obtained by adding the masked multi-head attention output \(\mathbf{Z}^{(\ell)}_2\) to the block input through a residual connection: \(\tilde{\mathbf{X}}\) (for \(\ell=1\)) or \(\mathbf{Z}^{(\ell-1)}_{\text{out}}\) (for \(\ell>1\)). |
| Normalized Input to FFN (Block \(\ell\)) | \(\mathbf{Z}^{(\ell)}_4\) | \(\mathcal{B} \times T \times D\) | The output of applying layer normalization to \(\mathbf{Z}^{(\ell)}_3\), serving as the input to the position-wise feed-forward network (FFN) in block \(\ell\). |
| Intermediate FFN Output (Block \(\ell\)) | \(\mathbf{Z}^{FF, (\ell)}_1\) | \(\mathcal{B} \times T \times d_{\text{ff}}\) | The intermediate output of the FFN in block \(\ell\), obtained by applying the first linear transformation and GELU activation to \(\mathbf{Z}^{(\ell)}_4\). |
| Output of FFN (Block \(\ell\)) | \(\mathbf{Z}^{(\ell)}_5\) | \(\mathcal{B} \times T \times D\) | The final output of the FFN in block \(\ell\), obtained by applying the second linear transformation to \(\mathbf{Z}^{FF, (\ell)}_1\). |
| Output of Decoder Block \(\ell\) | \(\mathbf{Z}^{(\ell)}_{\text{out}}\) | \(\mathcal{B} \times T \times D\) | The output of decoder block \(\ell\), obtained by adding the FFN output \(\mathbf{Z}^{(\ell)}_5\) to \(\mathbf{Z}^{(\ell)}_3\) through a residual connection. It serves as the input to the next decoder block (\(\ell+1\)) or as the final output of the decoder. |
| Pre-Projection Layer Normalized Output | \(\mathbf{Z}_{\text{pre-head}}\) | \(\mathcal{B} \times T \times D\) | The output of applying layer normalization to the final decoder block output \(\mathbf{Z}^{(L)}_{\text{out}}\), where \(L\) is the total number of decoder blocks. |
| Logits (Vocabulary Projection) | \(\mathbf{Z}\) | \(\mathcal{B} \times T \times V\) | The logits obtained by linearly projecting \(\mathbf{Z}_{\text{pre-head}}\) to the vocabulary space using learned weights \(\mathbf{W}_s\). They are the unnormalized scores for each token in the vocabulary at each position in the sequence. |
| FFN Layer 1 Weight Matrix (Block \(\ell\)) | \(\mathbf{W}^{FF, (\ell)}_1\) | \(D \times d_{\text{ff}}\) | The weight matrix for the first linear transformation in the FFN of block \(\ell\). |
| FFN Layer 1 Bias Vector (Block \(\ell\)) | \(\mathbf{b}^{FF, (\ell)}_1\) | \(d_{\text{ff}}\) | The bias vector for the first linear transformation in the FFN of block \(\ell\). |
| FFN Layer 2 Weight Matrix (Block \(\ell\)) | \(\mathbf{W}^{FF, (\ell)}_2\) | \(d_{\text{ff}} \times D\) | The weight matrix for the second linear transformation in the FFN of block \(\ell\). |
| FFN Layer 2 Bias Vector (Block \(\ell\)) | \(\mathbf{b}^{FF, (\ell)}_2\) | \(D\) | The bias vector for the second linear transformation in the FFN of block \(\ell\). |
| Projection Layer Weight Matrix | \(\mathbf{W}_s\) | \(D \times V\) | The weight matrix used to linearly project \(\mathbf{Z}_{\text{pre-head}}\) to the vocabulary space, mapping the hidden dimension \(D\) to the vocabulary size \(V\). |