How To Fine-Tune Decoder-Only Models For Sequence Classification Using Last Token Pooling?#

Firstly, if you have not read my Generative Pre-trained Transformers (GPT) series, please read it first to establish a basic understanding of what a decoder-only model entails.

Dependencies#

pip install -U omniverse==0.0.63
# %pip install omniverse
from __future__ import annotations

import logging
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Tuple, TypedDict, overload

import numpy as np
import pandas as pd
import psutil
import torch
from datasets import load_dataset
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    brier_score_loss,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm  # Use notebook version for better UI in notebooks
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

from omnivault.transformer.config.decoder import (
    AddNormConfig,
    DecoderBlockConfig,
    DecoderConfig,
    MultiHeadedAttentionConfig,
    PositionwiseFeedForwardConfig,
)
from omnivault.transformer.modules.attention.core import MultiHeadedAttention, ScaledDotProductAttention
from omnivault.transformer.modules.layers.addnorm import AddNorm
from omnivault.transformer.modules.layers.mlp import PositionwiseFeedForward
from omnivault.utils.reproducibility.seed import seed_all
from omnivault.utils.torch_utils.model_utils import total_trainable_parameters

Setting Up#

seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
2024
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 64
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"

Dataset#

dataset = load_dataset('financial_phrasebank', 'sentences_allagree', trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")
dataset
Dataset({
    features: ['sentence', 'labels'],
    num_rows: 2264
})
def count_labels(labels: List[int]) -> Dict[int, int]:
    label_counts = Counter(labels)
    ordered_label_counts = OrderedDict(sorted(label_counts.items()))
    return dict(ordered_label_counts)


sentences_allagree = dataset['sentence']
labels_allagree = dataset['labels']

label_counts = count_labels(labels_allagree)
pprint(label_counts)
{0: 303, 1: 1391, 2: 570}
train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column='labels')
train_dataset = train_valid_split['train']
valid_dataset = train_valid_split['test']
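
As a quick sanity check (this cell is not part of the original run, so no output is shown), we can reuse count_labels to confirm that the stratified split roughly preserves the label distribution:

pprint(count_labels(train_dataset["labels"]))
pprint(count_labels(valid_dataset["labels"]))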

We create our own Dataset just for understanding!

train_df = train_dataset.to_pandas()
valid_df = valid_dataset.to_pandas()
class FinancialDataset(Dataset):
    def __init__(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizer, **tokenizer_kwargs: Any) -> None:
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs
        self.inputs = df["sentence"].tolist()
        self.labels = df["labels"].tolist()

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        input_ids = self.tokenizer.encode(text=self.inputs[index], **self.tokenizer_kwargs).long()
        labels = torch.tensor(self.labels[index]).long()
        return {
            "input_ids": input_ids,
            "labels": labels,
        }

We will create the causal mask in the collator.

Tokenizer#

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
pprint(tokenizer.special_tokens_map)

tokenizer.pad_token = tokenizer.eos_token
pprint(tokenizer)
{'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}
GPT2Tokenizer(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
│   │   50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
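
To see what setting pad_token to eos_token means in practice, here is a small illustration (the two example sentences are made up and this cell is not part of the original run). Batch encoding pads the shorter sentence on the right with <|endoftext|> (token id 50256) and marks those positions with 0 in the attention mask:

encoded = tokenizer(["Profit rose", "Sales fell sharply in the quarter"], padding="longest", return_tensors="pt")
pprint(encoded["input_ids"])
pprint(encoded["attention_mask"])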

Data Collator And DataLoader#

train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
def construct_dummy_batch_causal_masks(batch_size: int, seq_len: int) -> torch.BoolTensor:
    """Broadcast future mask from shape (L, L) to (B, L, L) then (B, 1, L, L)."""
    # Create a lower triangular mask for a single sequence
    future_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=0).to(torch.bool)
    future_mask = future_mask.contiguous()
    # broadcast future mask from shape (L, L) to (B, L, L)
    causal_masks = future_mask.unsqueeze(0).expand(batch_size, -1, -1)
    # broadcast future mask from shape (B, L, L) to (B, 1, L, L)
    causal_masks = causal_masks.unsqueeze(1)
    return torch.BoolTensor(causal_masks)

def collate_for_unidirectional(
    batch: List[Dict[str, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    max_length = max(item["input_ids"].size(1) for item in batch)  # longest sequence length in this batch
    input_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
    labels = torch.zeros(len(batch), dtype=torch.long)

    # do padding manually
    for index, item in enumerate(batch):
        seq_len = item["input_ids"].size(1)
        input_ids[index, :seq_len] = item["input_ids"]
        labels[index] = item["labels"]

    batch_size, seq_len = input_ids.size()

    causal_masks = construct_dummy_batch_causal_masks(batch_size, seq_len)
    return input_ids, labels, causal_masks
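
Note that this collator pads input_ids with zeros and builds a causal mask that treats every position as a real token, which is why the helper above is called a dummy mask. For completeness, here is a sketch of a padding-aware variant (the function name is mine, and I assume, as in the lower-triangular mask above, that True means "may attend"):

def construct_padded_batch_causal_masks(lengths: List[int], seq_len: int) -> torch.BoolTensor:
    """Combine the (L, L) causal mask with a key-side padding mask so no position
    can attend to padded slots. Output shape: (B, 1, L, L)."""
    causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))  # (L, L)
    key_is_real = torch.zeros((len(lengths), seq_len), dtype=torch.bool)   # (B, L)
    for index, length in enumerate(lengths):
        key_is_real[index, :length] = True                                 # True marks real tokens
    masks = causal.unsqueeze(0) & key_is_real.unsqueeze(1)                 # (B, L, L)
    return masks.unsqueeze(1)                                              # (B, 1, L, L)
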
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_for_unidirectional)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, collate_fn=collate_for_unidirectional)

for batch in train_dataloader:
    input_ids, labels, causal_masks = batch
    pprint(input_ids)
    pprint(labels)
    pprint(causal_masks)
    break
tensor([[47117,   351,   262],
│   │   [   37,  3732,   680]])
tensor([0, 0])
tensor([[[[ True, False, False],
│   │     [ True,  True, False],
│   │     [ True,  True,  True]]],
│   │   
│   │   
│   │   [[[ True, False, False],
│   │     [ True,  True, False],
│   │     [ True,  True,  True]]]])

Model Architecture#

class DecoderForSequenceClassificationConfig(DecoderConfig):
    num_labels: int
    head_bias: bool = False
    pre_head_pooling: bool = True


class GPTPretrainedModel(nn.Module):
    def _init_weights(self, module: nn.Module) -> None:
        normal_init_modules = (nn.Linear, nn.Embedding)
        if isinstance(module, normal_init_modules):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if hasattr(module, "bias") and module.bias is not None:
                torch.nn.init.zeros_(module.bias)


class GPTDecoderBlock(nn.Module):
    """GPTDecoderBlock focuses on masked self-attention and feed-forward layers.

    The architecture follows the GPT-style decoder, which only has masked
    self-attention and position-wise feed-forward layers, omitting the
    encoder-decoder cross-attention.
    """

    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()
        self.masked_self_attention_mha = MultiHeadedAttention(
            **config.decoder_block.masked_self_attention_mha.model_dump(mode="python")
        )
        self.feed_forward = PositionwiseFeedForward(**config.decoder_block.feed_forward.model_dump(mode="python"))
        self.add_norm_1 = AddNorm(**config.decoder_block.add_norm_1.model_dump(mode="python"))
        self.add_norm_2 = AddNorm(**config.decoder_block.add_norm_2.model_dump(mode="python"))

    def forward(self, z: torch.Tensor, causal_masks: torch.BoolTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        z:              Input sequence.
                        type:  torch.Tensor
                        shape: (B, S or T, D)

        Returns
        -------
        z:              Output tensor after masked self-attention and feed-forward layers.
                        type:  torch.Tensor
                        shape: (B, S or T, D)
        """
        z = self.add_norm_1(
            z,
            lambda z: self.masked_self_attention_mha(query=z, key=z, value=z, mask=causal_masks),
        )
        z = self.add_norm_2(z, self.feed_forward)
        return z


class GPTBackbone(GPTPretrainedModel):
    def __init__(self, config: DecoderConfig) -> None:
        super().__init__()
        self.d_model: int = config.d_model
        self.tok_embed: nn.Embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed: nn.Parameter = nn.Parameter(torch.zeros(1, config.context_length, config.d_model))
        self.decoder_blocks: nn.ModuleList = nn.ModuleList(
            [GPTDecoderBlock(config) for _ in range(config.num_decoder_blocks)]
        )  # PyTorch did not make ModuleList a proper container, maybe open a PR to make it inherit Generic[T]???

        self.dropout: nn.Dropout = nn.Dropout(config.dropout)
        self.layer_norm: nn.LayerNorm = nn.LayerNorm(config.d_model)

        self.apply(self._init_weights)

        context_projections = ("context_projection.weight", "W_O.weight")
        # apply special scaled init to the residual projections, per GPT-2 paper
        for parameter_name, parameter in self.named_parameters():
            # NOTE: W_O is also projection but I did not have foresight to name it as such.
            if parameter_name.endswith(context_projections):
                mean = 0.0
                std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
                torch.nn.init.normal_(parameter, mean=mean, std=std_dev)

    def forward(
        self, input_tokens: torch.LongTensor, *, causal_masks: torch.BoolTensor
    ) -> torch.FloatTensor:
        seq_len: int = input_tokens.size(1)  # note seq_len <= context_length in decoder
        causal_masks = causal_masks.to(input_tokens.device)  # type: ignore[assignment]

        z = self.tok_embed(input_tokens)  # TODO: * math.sqrt(self.d_model) for better optimization landscape
        z = z + self.pos_embed[:, :seq_len, :]
        z = self.dropout(z)

        for decoder_block in self.decoder_blocks:
            z = decoder_block(z, causal_masks=causal_masks)

        z = self.layer_norm(z)
        return z

class LastTokenPooling(nn.Module):
    def __init__(self, pre_head_pooling: bool = True) -> None:
        super().__init__()
        self.pre_head_pooling = pre_head_pooling

    @overload
    def forward(self, last_hidden_state: torch.Tensor, logits: None = None) -> torch.Tensor: ...

    @overload
    def forward(self, last_hidden_state: None, logits: torch.Tensor) -> torch.Tensor: ...

    def forward(self,  last_hidden_state: torch.Tensor | None = None, logits: torch.Tensor | None = None) -> torch.Tensor:
        """Forward pass for the pooling layer.

        Parameters
        ----------
        last_hidden_state:  Hidden state of the last layer.
                            type:  torch.Tensor
                            shape: (B, T, D)
        logits:             Logits from the last layer.
                            type:  torch.Tensor
                            shape: (B, T, C)

        Notes
        -----
        In both cases, we will slice the `T` dimension to get the last token's
        hidden state or logits. For example, if `last_hidden_state` is provided,
        then we have `[B, T, D] -> [B, D]` and if `logits` is provided, then we
        have `[B, T, C] -> [B, C]`.
        """
        if self.pre_head_pooling:
            assert last_hidden_state is not None, "last_hidden_state must be provided when pre_head_pooling is True"
            pooled_hidden_state = last_hidden_state[:, -1, :]
            return pooled_hidden_state
        else:
            assert logits is not None, "logits must be provided when pre_head_pooling is False"
            pooled_logits = logits[:, -1, :]
            return pooled_logits

class GPTForSequenceClassification(GPTPretrainedModel):
    def __init__(self, config: DecoderForSequenceClassificationConfig) -> None:
        super().__init__()
        self.config = config

        self.backbone = GPTBackbone(config)
        self.pooler = LastTokenPooling(pre_head_pooling=config.pre_head_pooling)
        self.head = nn.Linear(config.d_model, config.num_labels, bias=config.head_bias)

        self.apply(self._init_weights)

        # apply special scaled init to the residual projections, per GPT-2 paper
        for parameter_name, parameter in self.named_parameters():
            if parameter_name.endswith("context_projection.weight"):
                mean = 0.0
                std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
                torch.nn.init.normal_(parameter, mean=mean, std=std_dev)

    def forward(
        self,
        input_tokens: torch.LongTensor,
        *,  # force keyword only arguments to prevent errors
        causal_masks: torch.BoolTensor,
    ) -> torch.FloatTensor:
        """
        Notations
        ---------
        B:      Batch size
        S or L: Source sequence length
        T or L: Target sequence length
        D:      Embedding dimension
        C:      Vocabulary size (Class size)

        Parameters
        ----------
        input_tokens:           Input sequence.
                                type:  torch.Tensor
                                shape: (B, T)
        causal_masks:           Future mask.
                                type:  torch.BoolTensor
                                shape: (B, 1, T, T)

        Variables
        ---------
        z:                      Input sequence after token and position embedding.
                                type:  torch.Tensor
                                shape: (B, T, D)
        causal_masks:           Target mask.
                                type:  torch.BoolTensor
                                shape: (B, 1, T, T)
        logits:                 Output logits.
                                type:  torch.FloatTensor
                                shape: (B, T, C)
        pooled_logits:          Pooled logits.
                                type:  torch.FloatTensor
                                shape: (B, C)
        """

        backbone_last_layer_hidden_state = self.backbone(input_tokens, causal_masks=causal_masks)

        if self.config.pre_head_pooling:
            pooled_hidden_state = self.pooler(backbone_last_layer_hidden_state)
            pooled_logits = self.head(pooled_hidden_state)
        else:
            logits = self.head(backbone_last_layer_hidden_state)
            pooled_logits = self.pooler(logits)
        return pooled_logits
model_config = DecoderForSequenceClassificationConfig(
    d_model=32,
    vocab_size=tokenizer.vocab_size,
    context_length=MAX_LENGTH,
    num_decoder_blocks=1,
    dropout=0.0,
    decoder_block=DecoderBlockConfig(
        masked_self_attention_mha=MultiHeadedAttentionConfig(
            attention=ScaledDotProductAttention(), d_model=32, H=1, dropout=0.0
        ),
        feed_forward=PositionwiseFeedForwardConfig(
            d_model=32, d_ff=32 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
        ),
        add_norm_1=AddNormConfig(feature_dim=32, dropout=0.0),
        add_norm_2=AddNormConfig(feature_dim=32, dropout=0.0),
    ),
    num_labels=3,
)
model = GPTForSequenceClassification(model_config).to(DEVICE)
pprint(model)
GPTForSequenceClassification(
  (backbone): GPTBackbone(
(tok_embed): Embedding(50257, 32)
(decoder_blocks): ModuleList(
(0): GPTDecoderBlock(
│   │   (masked_self_attention_mha): MultiHeadedAttention(
│   │     (W_Q): Linear(in_features=32, out_features=32, bias=False)
│   │     (W_K): Linear(in_features=32, out_features=32, bias=False)
│   │     (W_V): Linear(in_features=32, out_features=32, bias=False)
│   │     (W_O): Linear(in_features=32, out_features=32, bias=False)
│   │     (attention): ScaledDotProductAttention(
│   │   │   (dropout): Dropout(p=0.0, inplace=False)
│   │     )
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │   )
│   │   (feed_forward): PositionwiseFeedForward(
│   │     (ffn): ModuleDict(
│   │   │   (context_fc): Linear(in_features=32, out_features=64, bias=True)
│   │   │   (activation): GELU(approximate='tanh')
│   │   │   (context_projection): Linear(in_features=64, out_features=32, bias=True)
│   │   │   (dropout): Dropout(p=0.0, inplace=False)
│   │     )
│   │   )
│   │   (add_norm_1): AddNorm(
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │     (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
│   │   )
│   │   (add_norm_2): AddNorm(
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │     (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
│   │   )
)
)
(dropout): Dropout(p=0.0, inplace=False)
(layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (pooler): LastTokenPooling()
  (head): Linear(in_features=32, out_features=3, bias=False)
)
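
Out of curiosity (this call is not part of the original run, so its output is omitted), we can also count this small model's trainable parameters in millions with the helper imported earlier:

total_trainable_parameters(model) / 1e6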

Dry Run#

In our dry run, we make the following assumptions:

  • batch_size = 2 which means we have \(2\) samples in a batch.

  • max_length = 3 in the tokenizer, which means each sequence is truncated to \(T = 3\) tokens (the model's context_length remains MAX_LENGTH = 64).

  • d_model = 4 which means the model dimension is \(4\) for hidden layers.

  • Consequently the final output dimension of the backbone is \(\mathcal{B} \times T \times D \rightarrow 2\times 3 \times 4\).

seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

dry_run_model_config = DecoderForSequenceClassificationConfig(
    d_model=4,
    vocab_size=tokenizer.vocab_size,
    context_length=MAX_LENGTH,
    num_decoder_blocks=1,
    dropout=0.0,
    decoder_block=DecoderBlockConfig(
        masked_self_attention_mha=MultiHeadedAttentionConfig(
            attention=ScaledDotProductAttention(), d_model=4, H=1, dropout=0.0
        ),
        feed_forward=PositionwiseFeedForwardConfig(
            d_model=4, d_ff=4 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
        ),
        add_norm_1=AddNormConfig(feature_dim=4, dropout=0.0),
        add_norm_2=AddNormConfig(feature_dim=4, dropout=0.0),
    ),
    num_labels=3,
)

dry_run_model = GPTForSequenceClassification(dry_run_model_config).to(DEVICE)
pprint(dry_run_model)
GPTForSequenceClassification(
  (backbone): GPTBackbone(
(tok_embed): Embedding(50257, 4)
(decoder_blocks): ModuleList(
(0): GPTDecoderBlock(
│   │   (masked_self_attention_mha): MultiHeadedAttention(
│   │     (W_Q): Linear(in_features=4, out_features=4, bias=False)
│   │     (W_K): Linear(in_features=4, out_features=4, bias=False)
│   │     (W_V): Linear(in_features=4, out_features=4, bias=False)
│   │     (W_O): Linear(in_features=4, out_features=4, bias=False)
│   │     (attention): ScaledDotProductAttention(
│   │   │   (dropout): Dropout(p=0.0, inplace=False)
│   │     )
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │   )
│   │   (feed_forward): PositionwiseFeedForward(
│   │     (ffn): ModuleDict(
│   │   │   (context_fc): Linear(in_features=4, out_features=8, bias=True)
│   │   │   (activation): GELU(approximate='tanh')
│   │   │   (context_projection): Linear(in_features=8, out_features=4, bias=True)
│   │   │   (dropout): Dropout(p=0.0, inplace=False)
│   │     )
│   │   )
│   │   (add_norm_1): AddNorm(
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │     (layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
│   │   )
│   │   (add_norm_2): AddNorm(
│   │     (dropout): Dropout(p=0.0, inplace=False)
│   │     (layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
│   │   )
)
)
(dropout): Dropout(p=0.0, inplace=False)
(layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  )
  (pooler): LastTokenPooling()
  (head): Linear(in_features=4, out_features=3, bias=False)
)
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

batch = next(iter(train_dataloader))
input_ids, labels, causal_masks = batch

First, let's look at the input ids, labels and causal masks, which have the following form.

pprint(input_ids)
pprint(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
pprint(labels)
pprint(causal_masks)
tensor([[47117,   351,   262],
│   │   [   37,  3732,   680]])
'Relations with the'
tensor([0, 0])
tensor([[[[ True, False, False],
│   │     [ True,  True, False],
│   │     [ True,  True,  True]]],
│   │   
│   │   
│   │   [[[ True, False, False],
│   │     [ True,  True, False],
│   │     [ True,  True,  True]]]])
dry_run_backbone = dry_run_model.backbone
dry_run_backbone_last_layer_hidden_state = dry_run_backbone(input_ids, causal_masks=causal_masks)
dry_run_backbone_last_layer_hidden_state = dry_run_backbone_last_layer_hidden_state.detach().cpu()
pprint(dry_run_backbone_last_layer_hidden_state)
pprint(dry_run_backbone_last_layer_hidden_state.shape)
tensor([[[ 1.0732,  0.3017, -1.6063,  0.2314],
│   │    [-0.6302,  1.6116, -0.0369, -0.9445],
│   │    [-0.0674,  1.2399, -1.4665,  0.2940]],
│   │   
│   │   [[ 0.7063,  0.4734, -1.6868,  0.5071],
│   │    [-0.5821,  1.1785,  0.6935, -1.2899],
│   │    [ 0.0365, -0.7257, -0.8955,  1.5847]]])
torch.Size([2, 3, 4])

Indeed, the output of the backbone has shape [2, 3, 4]. More concretely, the first sequence/example in the batch is [47117,   351,   262], whose underlying text is 'Relations with the' and whose label is 0. Having 3 hidden states for the 3 tokens is expected: in autoregressive modelling we predict the next token given the previous tokens, so we need one output per position. For sequence-level classification, however, we want a single label for the entire sequence, not a prediction of the next token at every position. The backbone is not designed for this task out of the box; it simply outputs a hidden state for every token in the sequence.

For example,

[ 1.0732,  0.3017, -1.6063,  0.2314] # -> token embedding for `Relations`
[-0.6302,  1.6116, -0.0369, -0.9445] # -> token embedding for `with`
[-0.0674,  1.2399, -1.4665,  0.2940] # -> token embedding for `the`

We introduce the idea of pooling the hidden states to get a single representation for the entire sequence. You can think of it as transforming the hidden states of all 3 tokens in the sequence into a single sentence/sequence representation.

[-0.0674,  1.2399, -1.4665,  0.2940] # -> pooled embedding for `Relations with the`

In decoder-only models, however, we do not have a [CLS] token to pool the hidden states from. Recall, though, the causal mask for the first sequence.

[ True, False, False]
[ True,  True, False]
[ True,  True,  True]

Notice that the last token is the only position whose mask row is fully unmasked: it attends to every token in the sequence, so its hidden state carries information about the whole input. We can therefore simply pool the last token to get the sequence representation.

dry_run_pooler = dry_run_model.pooler
dry_run_pooler_output = dry_run_pooler(dry_run_backbone_last_layer_hidden_state)
dry_run_pooler_output = dry_run_pooler_output.detach().cpu()
pprint(dry_run_pooler_output)
pprint(dry_run_pooler_output.shape)
tensor([[-0.0674,  1.2399, -1.4665,  0.2940],
│   │   [ 0.0365, -0.7257, -0.8955,  1.5847]])
torch.Size([2, 4])

And we get the pooled representation for the first sequence. Earlier I described this as the pooled embedding for the sequence Relations with the. To be more pedantic, it is merely the embedding of the last token in the sequence; but because the last token attends to all tokens in the sequence, it can be treated as the pooled embedding of the entire sequence.

[-0.0674,  1.2399, -1.4665,  0.2940] 

So we went from [3, 4] to [1, 4] by pooling the last token. This is a lossy compression, but good enough in practice. If you have pooled encoder outputs before, you will know there are many ways to "better" pool the hidden states, for example mean pooling or max pooling. In decoder-only models, however, only the last token has seen the whole sequence, so our pooling is essentially limited to it, unless you swap the causal attention for bidirectional (full) attention, which some people do to benefit from the large number of parameters in the decoder.
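
For comparison, here is a minimal sketch of mask-aware mean pooling, the kind typically used with encoder models (the mean_pool function and its padding_mask argument are illustrative and not used elsewhere in this post). Keep in mind that under a causal mask every position except the last has only seen a prefix of the sequence, which is why last token pooling remains the natural choice here.

def mean_pool(last_hidden_state: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """Average hidden states over real (non-pad) tokens: (B, T, D) -> (B, D).
    `padding_mask` is a boolean tensor of shape (B, T), True for real tokens."""
    mask = padding_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, T, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                 # (B, D)
    counts = mask.sum(dim=1).clamp(min=1.0)                        # (B, 1)
    return summed / counts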

Another thing to note is that HuggingFace applies last token pooling after the head layer by default. We offer the option to pool before the head layer as well, and because the head is applied position-wise, the two give equivalent results.
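
The reason the two orders agree is that the head is a position-wise linear map, so slicing out the last token before or after applying it gives the same result. A minimal sketch with a toy linear layer (toy_head and hidden are made up purely for this check):

toy_head = nn.Linear(4, 3, bias=False)
hidden = torch.randn(2, 3, 4)                  # (B, T, D)
pre_head_pooled = toy_head(hidden[:, -1, :])   # pool then project: (B, C)
post_head_pooled = toy_head(hidden)[:, -1, :]  # project then pool: (B, C)
assert torch.allclose(pre_head_pooled, post_head_pooled)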

Lastly, we pass the pooled embeddings through the linear head to get the logits for the classification task. For the first example in our batch, the logits are:

[0.0287,  0.0123,  0.0312]
head = dry_run_model.head
logits = head(dry_run_pooler_output)
logits = logits.detach().cpu()
pprint(logits)
tensor([[ 0.0287,  0.0123,  0.0312],
│   │   [ 0.0237, -0.0079,  0.0117]])
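
To turn the pooled logits into class probabilities and predicted labels, we can apply a softmax followed by an argmax (illustrative only; the untrained dry-run model will not give meaningful predictions):

probs = torch.softmax(logits, dim=-1)  # (B, C) class probabilities
preds = probs.argmax(dim=-1)           # (B,) predicted class indices
pprint(probs)
pprint(preds)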

Training#

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0048)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0.0)
num_epochs = 6
train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_for_unidirectional, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_for_unidirectional, pin_memory=True)
def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float]:
    model.train()
    total_loss = 0.0
    correct_predictions = 0

    for batch in tqdm(dataloader, desc="Training", leave=True):
        input_ids, labels, causal_masks = (x.to(device) for x in batch)
        optimizer.zero_grad()
        outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = torch.max(outputs, dim=1)
        correct_predictions += torch.sum(preds == labels).item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / len(dataloader.dataset)
    return avg_loss, accuracy


def validate_one_epoch(
    model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device
) -> Tuple[float, float]:
    model.eval()
    val_loss = 0.0
    val_correct_predictions = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation", leave=True):
            input_ids, labels, causal_masks = (x.to(device) for x in batch)
            outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            val_correct_predictions += torch.sum(preds == labels).item()

    avg_val_loss = val_loss / len(dataloader)
    val_accuracy = val_correct_predictions / len(dataloader.dataset)
    return avg_val_loss, val_accuracy


def train_model(
    model: nn.Module,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    num_epochs: int,
    device: torch.device,
) -> None:
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_dataloader, optimizer, criterion, device)
        scheduler.step()
        val_loss, val_accuracy = validate_one_epoch(model, valid_dataloader, criterion, device)

        LOGGER.info(
            "Epoch %d/%d - Training loss: %.4f, Training accuracy: %.4f - Validation loss: %.4f, Validation accuracy: %.4f",
            epoch + 1, num_epochs, train_loss, train_accuracy, val_loss, val_accuracy
        )
train_model(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, num_epochs, DEVICE)
2024-07-20 21:12:02,603 - __main__ - INFO - Epoch 1/6 - Training loss: 0.7196, Training accuracy: 0.6966 - Validation loss: 0.5611, Validation accuracy: 0.7841
2024-07-20 21:12:03,551 - __main__ - INFO - Epoch 2/6 - Training loss: 0.4483, Training accuracy: 0.7973 - Validation loss: 0.5083, Validation accuracy: 0.7974
2024-07-20 21:12:04,381 - __main__ - INFO - Epoch 3/6 - Training loss: 0.2279, Training accuracy: 0.9116 - Validation loss: 0.6285, Validation accuracy: 0.7930
2024-07-20 21:12:05,294 - __main__ - INFO - Epoch 4/6 - Training loss: 0.1285, Training accuracy: 0.9607 - Validation loss: 0.4113, Validation accuracy: 0.8722
2024-07-20 21:12:06,251 - __main__ - INFO - Epoch 5/6 - Training loss: 0.0344, Training accuracy: 0.9936 - Validation loss: 0.4243, Validation accuracy: 0.8943
2024-07-20 21:12:07,058 - __main__ - INFO - Epoch 6/6 - Training loss: 0.0220, Training accuracy: 0.9961 - Validation loss: 0.4257, Validation accuracy: 0.9031

The results are just decent, but you can see the model is learning. Tuning decoder-only models needs some experimentation.

Using HuggingFace#

class Batch(TypedDict):
    sentence: List[str]
    labels: List[int]


class TokenizedBatch(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]
    labels: List[int]


def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
    return tokenizer(batch["sentence"], **kwargs)
dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")

train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")

train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

tokenized_train_dataset = train_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

tokenized_valid_dataset = valid_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
  warnings.warn(
/opt/conda/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
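
As a quick illustration of what DataCollatorWithPadding does (this cell is not part of the original run): it takes a list of tokenized examples and pads input_ids and attention_mask to the longest sequence in the batch, returning a dictionary of tensors.

example_features = [tokenized_train_dataset[i] for i in range(2)]  # two tokenized examples
collated = data_collator(example_features)
pprint({key: value.shape for key, value in collated.items()})
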
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}
num_labels = len(id2label)

base_model = GPT2ForSequenceClassification.from_pretrained(
    "gpt2",
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
    problem_type="single_label_classification",
)
base_model.config.pad_token_id = tokenizer.pad_token_id
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
total_trainable_parameters(base_model) / 1e6
124.442112
training_args = TrainingArguments(
    do_eval=True,
    do_predict=False,
    do_train=True,
    warmup_ratio=0.1,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    report_to="none",
    output_dir = './artifacts',
    overwrite_output_dir=True,
    gradient_accumulation_steps=1,
    logging_steps=25,
    evaluation_strategy='epoch',
    eval_steps=25,
    save_strategy="epoch",
    save_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    lr_scheduler_type='cosine',
    weight_decay=0.01,
    save_total_limit=2,
    seed=42,
    data_seed=42,
    half_precision_backend="auto",
    optim="adamw_torch",
    label_smoothing_factor=0.0,
    max_grad_norm = 1.0,
)
/opt/conda/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
def compute_metrics_for_single_label_classification(eval_prediction: EvalPrediction) -> Dict[str, float | List[float]]:
    logits, labels = eval_prediction.predictions, eval_prediction.label_ids
    probs = softmax(logits, axis=-1)

    num_classes = logits.shape[1]
    preds = np.argmax(probs, axis=1)

    metrics = {
        "eval_log_loss": log_loss(labels, probs),
        "eval_accuracy": accuracy_score(labels, preds),
        "eval_precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "eval_recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "eval_f1_score_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "eval_precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
        "eval_recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
        "eval_f1_score_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "eval_confusion_matrix": confusion_matrix(labels, preds).tolist(),
        "eval_roc_auc": roc_auc_score(labels, probs, multi_class="ovr"),
        # "eval_pr_auc": average_precision_score(labels, probs, average="macro"),
    }

    if num_classes == 2:
        metrics["eval_brier_score"] = brier_score_loss(labels, probs[:, 1], pos_label=1)
    else:
        brier_scores = [brier_score_loss(labels == i, probs[:, i]) for i in range(num_classes)]
        metrics["eval_brier_score"] = np.mean(brier_scores)

    if num_classes > 2:
        for class_index in range(num_classes):
            fpr, tpr, _ = roc_curve(labels == class_index, probs[:, class_index])
            roc_auc = auc(fpr, tpr)
            precision, recall, _ = precision_recall_curve(labels == class_index, probs[:, class_index])
            pr_auc = auc(recall, precision)
            metrics[f"eval_roc_auc_class_{class_index}"] = roc_auc
            metrics[f"eval_pr_auc_class_{class_index}"] = pr_auc

    return metrics
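
Before handing the metric function to the Trainer, we can smoke-test it with made-up logits and labels (the numbers below are arbitrary and purely illustrative):

dummy_eval_prediction = EvalPrediction(
    predictions=np.array([[2.0, 0.1, -1.0], [0.2, 1.5, 0.3], [-0.5, 0.0, 2.2], [1.0, 0.2, 0.1]]),
    label_ids=np.array([0, 1, 2, 0]),
)
pprint(compute_metrics_for_single_label_classification(dummy_eval_prediction))
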
trainer = Trainer(
    model=base_model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_for_single_label_classification,
)

trainer.train()
[1275/1275 02:05, Epoch 5/5]
| Epoch | Training Loss | Validation Loss | Log Loss | Accuracy | Precision Macro | Recall Macro | F1 Score Macro | Precision Micro | Recall Micro | F1 Score Micro | Confusion Matrix | Roc Auc | Brier Score | Roc Auc Class 0 | Pr Auc Class 0 | Roc Auc Class 1 | Pr Auc Class 1 | Roc Auc Class 2 | Pr Auc Class 2 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 0.654800 | 0.575435 | 0.575435 | 0.740088 | 0.612724 | 0.596449 | 0.599999 | 0.740088 | 0.740088 | 0.740088 | [[9, 5, 16], [0, 125, 15], [12, 11, 34]] | 0.876099 | 0.113814 | 0.860237 | 0.379456 | 0.942775 | 0.963399 | 0.825284 | 0.608643 |
| 2 | 0.478300 | 0.219012 | 0.219012 | 0.929515 | 0.909526 | 0.901044 | 0.903625 | 0.929515 | 0.929515 | 0.929515 | [[27, 3, 0], [0, 137, 3], [5, 5, 47]] | 0.979735 | 0.038522 | 0.991540 | 0.935628 | 0.987603 | 0.991196 | 0.960062 | 0.933540 |
| 3 | 0.158000 | 0.221089 | 0.221089 | 0.938326 | 0.906788 | 0.942398 | 0.921905 | 0.938326 | 0.938326 | 0.938326 | [[30, 0, 0], [2, 133, 5], [4, 3, 50]] | 0.986326 | 0.031093 | 0.993739 | 0.949336 | 0.993103 | 0.995271 | 0.972136 | 0.952987 |
| 4 | 0.212800 | 0.227551 | 0.227551 | 0.942731 | 0.914779 | 0.944779 | 0.927814 | 0.942731 | 0.942731 | 0.942731 | [[30, 0, 0], [1, 134, 5], [4, 3, 50]] | 0.989035 | 0.033271 | 0.994755 | 0.961793 | 0.994745 | 0.996582 | 0.977606 | 0.959318 |
| 5 | 0.101200 | 0.212625 | 0.212625 | 0.942731 | 0.915999 | 0.930785 | 0.923062 | 0.942731 | 0.942731 | 0.942731 | [[28, 0, 2], [1, 135, 4], [3, 3, 51]] | 0.990026 | 0.029624 | 0.994755 | 0.962396 | 0.994828 | 0.996647 | 0.980495 | 0.961750 |

TrainOutput(global_step=1275, training_loss=0.3979669969222125, metrics={'train_runtime': 125.2594, 'train_samples_per_second': 81.311, 'train_steps_per_second': 10.179, 'total_flos': 332666429276160.0, 'train_loss': 0.3979669969222125, 'epoch': 5.0})
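
Finally, a sketch of how one could run inference with the fine-tuned model via trainer.predict (not part of the original run); the raw logits are converted back to label names with id2label:

prediction_output = trainer.predict(tokenized_valid_dataset)
predicted_ids = np.argmax(prediction_output.predictions, axis=-1)
predicted_labels = [id2label[int(index)] for index in predicted_ids]
pprint(predicted_labels[:5])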