How To Fine-Tune Decoder-Only Models For Sequence Classification Using Last Token Pooling?#
First, if you have not read my Generative Pre-trained Transformers (GPT) series, please have a read to establish a basic understanding of what a decoder-only model entails.
Dependencies#
pip install -U omniverse==0.0.63
# %pip install omniverse
from __future__ import annotations
import logging
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Tuple, TypedDict, overload
import numpy as np
import pandas as pd
import psutil
import torch
from datasets import load_dataset
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
accuracy_score,
auc,
average_precision_score,
brier_score_loss,
confusion_matrix,
f1_score,
log_loss,
precision_recall_curve,
precision_score,
recall_score,
roc_auc_score,
roc_curve,
)
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm # Use notebook version for better UI in notebooks
from transformers import (
DataCollatorWithPadding,
EvalPrediction,
GPT2ForSequenceClassification,
GPT2Tokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
)
from omnivault.transformer.config.decoder import (
AddNormConfig,
DecoderBlockConfig,
DecoderConfig,
MultiHeadedAttentionConfig,
PositionwiseFeedForwardConfig,
)
from omnivault.transformer.modules.attention.core import MultiHeadedAttention, ScaledDotProductAttention
from omnivault.transformer.modules.layers.addnorm import AddNorm
from omnivault.transformer.modules.layers.mlp import PositionwiseFeedForward
from omnivault.utils.reproducibility.seed import seed_all
from omnivault.utils.torch_utils.model_utils import total_trainable_parameters
Setting Up#
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
2024
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_LENGTH = 64
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"
Dataset#
dataset = load_dataset('financial_phrasebank', 'sentences_allagree', trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")
dataset
Dataset({
features: ['sentence', 'labels'],
num_rows: 2264
})
def count_labels(labels: List[int]) -> Dict[int, int]:
label_counts = Counter(labels)
ordered_label_counts = OrderedDict(sorted(label_counts.items()))
return dict(ordered_label_counts)
sentences_allagree = dataset['sentence']
labels_allagree = dataset['labels']
label_counts = count_labels(labels_allagree)
pprint(label_counts)
{0: 303, 1: 1391, 2: 570}
train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column='labels')
train_dataset = train_valid_split['train']
valid_dataset = train_valid_split['test']
We create our own Dataset subclass just for understanding!
train_df = train_dataset.to_pandas()
valid_df = valid_dataset.to_pandas()
class FinancialDataset(Dataset):
def __init__(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizer, **tokenizer_kwargs: Any) -> None:
self.tokenizer = tokenizer
self.tokenizer_kwargs = tokenizer_kwargs
self.inputs = df["sentence"].tolist()
self.labels = df["labels"].tolist()
def __len__(self) -> int:
return len(self.labels)
def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
input_ids = self.tokenizer.encode(text=self.inputs[index], **self.tokenizer_kwargs).long()
labels = torch.tensor(self.labels[index]).long()
return {
"input_ids": input_ids,
"labels": labels,
}
We will create the causal mask in the collator.
Tokenizer#
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
pprint(tokenizer.special_tokens_map)
tokenizer.pad_token = tokenizer.eos_token
pprint(tokenizer)
{'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}
GPT2Tokenizer(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={
    50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}
Data Collator And DataLoader#
train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=3, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
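To see what a single item looks like before collation, here is a quick peek (a sketch; the exact token ids and label depend on the train/valid split). Note that max_length=3 is deliberately tiny here so the dry run later stays readable.
# Peek at one raw item (sketch): return_tensors="pt" makes input_ids a (1, T) tensor.
item = train_dataset[0]
print(item["input_ids"].shape)  # typically torch.Size([1, 3]) because of max_length=3
print(item["labels"])           # e.g. tensor(1)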
def construct_dummy_batch_causal_masks(batch_size: int, seq_len: int) -> torch.BoolTensor:
"""Broadcast future mask from shape (L, L) to (B, L, L) then (B, 1, L, L)."""
# Create a lower triangular mask for a single sequence
future_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=0).to(torch.bool)
future_mask = future_mask.contiguous()
# broadcast future mask from shape (L, L) to (B, L, L)
causal_masks = future_mask.unsqueeze(0).expand(batch_size, -1, -1)
# broadcast future mask from shape (B, L, L) to (B, 1, L, L)
causal_masks = causal_masks.unsqueeze(1)
return torch.BoolTensor(causal_masks)
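As a quick sanity check (a small sketch), one sample with sequence length 3 gives a (B, 1, L, L) lower-triangular boolean mask:
masks = construct_dummy_batch_causal_masks(batch_size=1, seq_len=3)
print(masks.shape)  # torch.Size([1, 1, 3, 3])
print(masks[0, 0])
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])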
def collate_for_unidirectional(
batch: List[Dict[str, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor, torch.BoolTensor]:
    max_length = max(item["input_ids"].size(1) for item in batch)  # longest sequence in this batch
input_ids = torch.zeros((len(batch), max_length), dtype=torch.long)
labels = torch.zeros(len(batch), dtype=torch.long)
    # pad manually with zeros up to the longest sequence in the batch
for index, item in enumerate(batch):
seq_len = item["input_ids"].size(1)
input_ids[index, :seq_len] = item["input_ids"]
labels[index] = item["labels"]
batch_size, seq_len = input_ids.size()
causal_masks = construct_dummy_batch_causal_masks(batch_size, seq_len)
return input_ids, labels, causal_masks
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate_for_unidirectional)
valid_dataloader = DataLoader(valid_dataset, batch_size=2, shuffle=False, collate_fn=collate_for_unidirectional)
for batch in train_dataloader:
input_ids, labels, causal_masks = batch
pprint(input_ids)
pprint(labels)
pprint(causal_masks)
break
tensor([[47117,   351,   262],
        [   37,  3732,   680]])
tensor([0, 0])
tensor([[[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]],


        [[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]]])
Model Architecture#
class DecoderForSequenceClassificationConfig(DecoderConfig):
num_labels: int
head_bias: bool = False
pre_head_pooling: bool = True
class GPTPretrainedModel(nn.Module):
def _init_weights(self, module: nn.Module) -> None:
normal_init_modules = (nn.Linear, nn.Embedding)
if isinstance(module, normal_init_modules):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if hasattr(module, "bias") and module.bias is not None:
torch.nn.init.zeros_(module.bias)
class GPTDecoderBlock(nn.Module):
"""GPTDecoderBlock focuses on masked self-attention and feed-forward layers.
The architecture follows the GPT-style decoder, which only has masked
self-attention and position-wise feed-forward layers, omitting the
encoder-decoder cross-attention.
"""
def __init__(self, config: DecoderConfig) -> None:
super().__init__()
self.masked_self_attention_mha = MultiHeadedAttention(
**config.decoder_block.masked_self_attention_mha.model_dump(mode="python")
)
self.feed_forward = PositionwiseFeedForward(**config.decoder_block.feed_forward.model_dump(mode="python"))
self.add_norm_1 = AddNorm(**config.decoder_block.add_norm_1.model_dump(mode="python"))
self.add_norm_2 = AddNorm(**config.decoder_block.add_norm_2.model_dump(mode="python"))
def forward(self, z: torch.Tensor, causal_masks: torch.BoolTensor) -> torch.Tensor:
"""
Parameters
----------
z: Input sequence.
type: torch.Tensor
shape: (B, S or T, D)
Returns
-------
z: Output tensor after masked self-attention and feed-forward layers.
type: torch.Tensor
shape: (B, S or T, D)
"""
z = self.add_norm_1(
z,
lambda z: self.masked_self_attention_mha(query=z, key=z, value=z, mask=causal_masks),
)
z = self.add_norm_2(z, self.feed_forward)
return z
class GPTBackbone(GPTPretrainedModel):
def __init__(self, config: DecoderConfig) -> None:
super().__init__()
self.d_model: int = config.d_model
self.tok_embed: nn.Embedding = nn.Embedding(config.vocab_size, config.d_model)
self.pos_embed: nn.Parameter = nn.Parameter(torch.zeros(1, config.context_length, config.d_model))
self.decoder_blocks: nn.ModuleList = nn.ModuleList(
[GPTDecoderBlock(config) for _ in range(config.num_decoder_blocks)]
) # PyTorch did not make ModuleList a proper container, maybe open a PR to make it inherit Generic[T]???
self.dropout: nn.Dropout = nn.Dropout(config.dropout)
self.layer_norm: nn.LayerNorm = nn.LayerNorm(config.d_model)
self.apply(self._init_weights)
context_projections = ("context_projection.weight", "W_O.weight")
# apply special scaled init to the residual projections, per GPT-2 paper
for parameter_name, parameter in self.named_parameters():
# NOTE: W_O is also projection but I did not have foresight to name it as such.
if parameter_name.endswith(context_projections):
mean = 0.0
std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
torch.nn.init.normal_(parameter, mean=mean, std=std_dev)
def forward(
self, input_tokens: torch.LongTensor, *, causal_masks: torch.BoolTensor
) -> torch.FloatTensor:
seq_len: int = input_tokens.size(1) # note seq_len <= context_length in decoder
causal_masks = causal_masks.to(input_tokens.device) # type: ignore[assignment]
z = self.tok_embed(input_tokens) # TODO: * math.sqrt(self.d_model) for better optimization landscape
z = z + self.pos_embed[:, :seq_len, :]
z = self.dropout(z)
for decoder_block in self.decoder_blocks:
z = decoder_block(z, causal_masks=causal_masks)
z = self.layer_norm(z)
return z
class LastTokenPooling(nn.Module):
def __init__(self, pre_head_pooling: bool = True) -> None:
super().__init__()
self.pre_head_pooling = pre_head_pooling
@overload
def forward(self, last_hidden_state: torch.Tensor, logits: None = None) -> torch.Tensor: ...
@overload
def forward(self, last_hidden_state: None, logits: torch.Tensor) -> torch.Tensor: ...
def forward(self, last_hidden_state: torch.Tensor | None = None, logits: torch.Tensor | None = None) -> torch.Tensor:
"""Forward pass for the pooling layer.
Parameters
----------
last_hidden_state: Hidden state of the last layer.
type: torch.Tensor
shape: (B, T, D)
logits: Logits from the last layer.
type: torch.Tensor
shape: (B, T, C)
Notes
-----
In both cases, we will slice the `T` dimension to get the last token's
hidden state or logits. For example, if `last_hidden_state` is provided,
then we have `[B, T, D] -> [B, D]` and if `logits` is provided, then we
have `[B, T, C] -> [B, C]`.
"""
if self.pre_head_pooling:
assert last_hidden_state is not None, "last_hidden_state must be provided when pre_head is True"
pooled_hidden_state = last_hidden_state[:, -1, :]
return pooled_hidden_state
else:
assert logits is not None, "logits must be provided when pre_head is False"
pooled_logits = logits[:, -1, :]
return pooled_logits
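Before wiring the pooler into the model, here is a shape check (a minimal sketch with random tensors): pooling the hidden state maps (B, T, D) to (B, D), while pooling the logits maps (B, T, C) to (B, C).
pre_pooler = LastTokenPooling(pre_head_pooling=True)
dummy_hidden_state = torch.randn(2, 3, 4)  # (B=2, T=3, D=4)
print(pre_pooler(dummy_hidden_state).shape)  # torch.Size([2, 4])

post_pooler = LastTokenPooling(pre_head_pooling=False)
dummy_logits = torch.randn(2, 3, 3)  # (B=2, T=3, C=3)
print(post_pooler(logits=dummy_logits).shape)  # torch.Size([2, 3])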
class GPTForSequenceClassification(GPTPretrainedModel):
def __init__(self, config: DecoderForSequenceClassificationConfig) -> None:
super().__init__()
self.config = config
self.backbone = GPTBackbone(config)
self.pooler = LastTokenPooling(pre_head_pooling=config.pre_head_pooling)
self.head = nn.Linear(config.d_model, config.num_labels, bias=config.head_bias)
self.apply(self._init_weights)
# apply special scaled init to the residual projections, per GPT-2 paper
for parameter_name, parameter in self.named_parameters():
if parameter_name.endswith("context_projection.weight"):
mean = 0.0
std_dev = 0.02 / torch.sqrt(torch.tensor(2 * config.num_decoder_blocks, dtype=torch.float))
torch.nn.init.normal_(parameter, mean=mean, std=std_dev)
def forward(
self,
input_tokens: torch.LongTensor,
*, # force keyword only arguments to prevent errors
causal_masks: torch.BoolTensor,
) -> torch.FloatTensor:
"""
Notations
---------
B: Batch size
S or L: Source sequence length
T or L: Target sequence length
D: Embedding dimension
C: Vocabulary size (Class size)
Parameters
----------
input_tokens: Input sequence.
type: torch.Tensor
shape: (B, T)
causal_masks: Future mask.
type: torch.BoolTensor
shape: (B, 1, T, T)
Variables
---------
z: Input sequence after token and position embedding.
type: torch.Tensor
shape: (B, T, D)
causal_masks: Target mask.
type: torch.BoolTensor
shape: (B, 1, T, T)
logits: Output logits.
type: torch.FloatTensor
shape: (B, T, C)
pooled_logits: Pooled logits.
type: torch.FloatTensor
shape: (B, C)
"""
backbone_last_layer_hidden_state = self.backbone(input_tokens, causal_masks=causal_masks)
if self.config.pre_head_pooling:
pooled_hidden_state = self.pooler(backbone_last_layer_hidden_state)
pooled_logits = self.head(pooled_hidden_state)
else:
logits = self.head(backbone_last_layer_hidden_state)
pooled_logits = self.pooler(logits)
return pooled_logits
model_config = DecoderForSequenceClassificationConfig(
d_model=32,
vocab_size=tokenizer.vocab_size,
context_length=MAX_LENGTH,
num_decoder_blocks=1,
dropout=0.0,
decoder_block=DecoderBlockConfig(
masked_self_attention_mha=MultiHeadedAttentionConfig(
attention=ScaledDotProductAttention(), d_model=32, H=1, dropout=0.0
),
feed_forward=PositionwiseFeedForwardConfig(
d_model=32, d_ff=32 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
),
add_norm_1=AddNormConfig(feature_dim=32, dropout=0.0),
add_norm_2=AddNormConfig(feature_dim=32, dropout=0.0),
),
num_labels=3,
)
model = GPTForSequenceClassification(model_config).to(DEVICE)
pprint(model)
GPTForSequenceClassification(
  (backbone): GPTBackbone(
    (tok_embed): Embedding(50257, 32)
    (decoder_blocks): ModuleList(
      (0): GPTDecoderBlock(
        (masked_self_attention_mha): MultiHeadedAttention(
          (W_Q): Linear(in_features=32, out_features=32, bias=False)
          (W_K): Linear(in_features=32, out_features=32, bias=False)
          (W_V): Linear(in_features=32, out_features=32, bias=False)
          (W_O): Linear(in_features=32, out_features=32, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (ffn): ModuleDict(
            (context_fc): Linear(in_features=32, out_features=64, bias=True)
            (activation): GELU(approximate='tanh')
            (context_projection): Linear(in_features=64, out_features=32, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (add_norm_1): AddNorm(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (add_norm_2): AddNorm(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (pooler): LastTokenPooling()
  (head): Linear(in_features=32, out_features=3, bias=False)
)
Dry Run#
In our dry run, we make the following assumptions:
- batch_size = 2, which means we have \(2\) samples in a batch.
- MAX_LEN = 3, which means the context length \(T\) is \(3\).
- d_model = 4, which means the model dimension is \(4\) for the hidden layers.

Consequently, the final output dimension of the backbone is \(\mathcal{B} \times T \times D \rightarrow 2 \times 3 \times 4\).
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
dry_run_model_config = DecoderForSequenceClassificationConfig(
d_model=4,
vocab_size=tokenizer.vocab_size,
context_length=MAX_LENGTH,
num_decoder_blocks=1,
dropout=0.0,
decoder_block=DecoderBlockConfig(
masked_self_attention_mha=MultiHeadedAttentionConfig(
attention=ScaledDotProductAttention(), d_model=4, H=1, dropout=0.0
),
feed_forward=PositionwiseFeedForwardConfig(
d_model=4, d_ff=4 * 2, activation=nn.GELU(approximate="tanh"), dropout=0.0, bias=True
),
add_norm_1=AddNormConfig(feature_dim=4, dropout=0.0),
add_norm_2=AddNormConfig(feature_dim=4, dropout=0.0),
),
num_labels=3,
)
dry_run_model = GPTForSequenceClassification(dry_run_model_config).to(DEVICE)
pprint(dry_run_model)
GPTForSequenceClassification(
  (backbone): GPTBackbone(
    (tok_embed): Embedding(50257, 4)
    (decoder_blocks): ModuleList(
      (0): GPTDecoderBlock(
        (masked_self_attention_mha): MultiHeadedAttention(
          (W_Q): Linear(in_features=4, out_features=4, bias=False)
          (W_K): Linear(in_features=4, out_features=4, bias=False)
          (W_V): Linear(in_features=4, out_features=4, bias=False)
          (W_O): Linear(in_features=4, out_features=4, bias=False)
          (attention): ScaledDotProductAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (ffn): ModuleDict(
            (context_fc): Linear(in_features=4, out_features=8, bias=True)
            (activation): GELU(approximate='tanh')
            (context_projection): Linear(in_features=8, out_features=4, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (add_norm_1): AddNorm(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        )
        (add_norm_2): AddNorm(
          (dropout): Dropout(p=0.0, inplace=False)
          (layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (layer_norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  )
  (pooler): LastTokenPooling()
  (head): Linear(in_features=4, out_features=3, bias=False)
)
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
batch = next(iter(train_dataloader))
input_ids, labels, causal_masks = batch
First, we inspect the input ids, labels and causal masks, which have the format below.
pprint(input_ids)
pprint(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
pprint(labels)
pprint(causal_masks)
tensor([[47117,   351,   262],
        [   37,  3732,   680]])
'Relations with the'
tensor([0, 0])
tensor([[[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]],


        [[[ True, False, False],
          [ True,  True, False],
          [ True,  True,  True]]]])
dry_run_backbone = dry_run_model.backbone
dry_run_backbone_last_layer_hidden_state = dry_run_backbone(input_ids, causal_masks=causal_masks)
dry_run_backbone_last_layer_hidden_state = dry_run_backbone_last_layer_hidden_state.detach().cpu()
pprint(dry_run_backbone_last_layer_hidden_state)
pprint(dry_run_backbone_last_layer_hidden_state.shape)
tensor([[[ 1.0732,  0.3017, -1.6063,  0.2314],
         [-0.6302,  1.6116, -0.0369, -0.9445],
         [-0.0674,  1.2399, -1.4665,  0.2940]],

        [[ 0.7063,  0.4734, -1.6868,  0.5071],
         [-0.5821,  1.1785,  0.6935, -1.2899],
         [ 0.0365, -0.7257, -0.8955,  1.5847]]])
torch.Size([2, 3, 4])
Here the output of the backbone is indeed of shape [2, 3, 4]. More concretely, the first sequence/example in the batch is [47117, 351, 262], whose underlying text is 'Relations with the' and whose label is 0. You see there are 3 tokens in the sequence, which is expected: in autoregressive modelling we predict the next token given the previous tokens. However, when we move on to sequence-level classification, we want to predict a label for the entire sequence, not, say, predict the second token given the first token and so on. Fundamentally, the backbone is not designed for this task: it outputs a hidden state for every token in the sequence.
For example,
[ 1.0732, 0.3017, -1.6063, 0.2314] # -> token embedding for `Relations`
[-0.6302, 1.6116, -0.0369, -0.9445] # -> token embedding for `with`
[-0.0674, 1.2399, -1.4665, 0.2940] # -> token embedding for `the`
We introduce the idea of pooling the hidden states to get a single representation for the entire sequence. You can think of it as transforming the hidden states of all 3 tokens in the sequence into a single sentence/sequence representation.
[-0.0674, 1.2399, -1.4665, 0.2940] # -> pooled embedding for `Relations with the`
However, in decoder-only models we do not have a [CLS] token to pool from. Instead, recall the causal mask format for the first sequence:
[ True, False, False]
[ True, True, False]
[ True, True, True]
Oh, so the last token in the sequence is the only one whose attention row is fully unmasked: it attends to every token in the sequence and therefore carries information about all of them. So we can simply pool the last token to get the sequence representation.
dry_run_pooler = dry_run_model.pooler
dry_run_pooler_output = dry_run_pooler(dry_run_backbone_last_layer_hidden_state)
dry_run_pooler_output = dry_run_pooler_output.detach().cpu()
pprint(dry_run_pooler_output)
pprint(dry_run_pooler_output.shape)
tensor([[-0.0674,  1.2399, -1.4665,  0.2940],
        [ 0.0365, -0.7257, -0.8955,  1.5847]])
torch.Size([2, 4])
And we get the pooled representation for each sequence in the batch. Earlier I commented that this is the pooled embedding for the sequence 'Relations with the'. To be more pedantic, it is merely the hidden state of the last token; but because the last token is aware of all tokens in the sequence, it can be treated as the pooled embedding of the entire sequence.
[-0.0674, 1.2399, -1.4665, 0.2940]
So we went from [3, 4] to [1, 4] per sequence by pooling the last token. This is a lossy compression, but it is good enough. If you have done encoder pooling before, you will know there are many ways to pool the hidden states "better": mean pooling, max pooling, and so on. In decoder-only models, however, only the last token has seen the whole sequence, so our pooling is limited to it, unless you replace the causal attention with bidirectional (full) self-attention, which some people do in order to benefit from the large number of parameters in the decoder.
Another thing to note: HuggingFace applies last-token pooling after the head layer by default. We offer the option to pool before the head layer as well, and for a linear head the results are equivalent, as the quick check below shows.
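Because the head is a position-wise linear layer, slicing out the last token before or after applying it yields the same result up to floating point error. A minimal sketch using the dry-run modules above (the head itself is introduced properly just below):
hidden = dry_run_backbone_last_layer_hidden_state.to(DEVICE)
pooled_then_head = dry_run_model.head(hidden[:, -1, :])    # pool before the head
head_then_pooled = dry_run_model.head(hidden)[:, -1, :]    # pool after the head
print(torch.allclose(pooled_then_head, head_then_pooled))  # True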
Lastly, we pass the pooled embeddings through the linear head to get the logits for the classification task. For the first sequence in our current example, the logits are:
[0.0287, 0.0123, 0.0312]
head = dry_run_model.head
logits = head(dry_run_pooler_output)
logits = logits.detach().cpu()
pprint(logits)
tensor([[ 0.0287,  0.0123,  0.0312],
        [ 0.0237, -0.0079,  0.0117]])
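To close the loop before training, here is a small sketch of how these pooled logits feed into the loss: CrossEntropyLoss expects raw logits of shape (B, C) and integer labels of shape (B,).
dry_run_criterion = nn.CrossEntropyLoss()
dry_run_loss = dry_run_criterion(logits, labels)
print(dry_run_loss)  # a scalar close to -log(1/3) ≈ 1.0986 since the logits are near-uniform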
Training#
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0048)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6, eta_min=0.0)
num_epochs = 6
train_dataset = FinancialDataset(train_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
valid_dataset = FinancialDataset(valid_df, tokenizer=tokenizer, max_length=MAX_LENGTH, padding=PADDING, truncation=TRUNCATION, return_tensors=RETURN_TENSORS)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_for_unidirectional, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_for_unidirectional, pin_memory=True)
def train_one_epoch(
model: nn.Module,
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
device: torch.device,
) -> Tuple[float, float]:
model.train()
total_loss = 0.0
correct_predictions = 0
for batch in tqdm(dataloader, desc="Training", leave=True):
input_ids, labels, causal_masks = (x.to(device) for x in batch)
optimizer.zero_grad()
outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
_, preds = torch.max(outputs, dim=1)
correct_predictions += torch.sum(preds == labels).item()
avg_loss = total_loss / len(dataloader)
accuracy = correct_predictions / len(dataloader.dataset)
return avg_loss, accuracy
def validate_one_epoch(
model: nn.Module, dataloader: DataLoader, criterion: nn.Module, device: torch.device
) -> Tuple[float, float]:
model.eval()
val_loss = 0.0
val_correct_predictions = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Validation", leave=True):
input_ids, labels, causal_masks = (x.to(device) for x in batch)
outputs = model(input_tokens=input_ids, causal_masks=causal_masks)
loss = criterion(outputs, labels)
val_loss += loss.item()
_, preds = torch.max(outputs, dim=1)
val_correct_predictions += torch.sum(preds == labels).item()
avg_val_loss = val_loss / len(dataloader)
val_accuracy = val_correct_predictions / len(dataloader.dataset)
return avg_val_loss, val_accuracy
def train_model(
model: nn.Module,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
criterion: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler._LRScheduler,
num_epochs: int,
device: torch.device,
) -> None:
for epoch in range(num_epochs):
train_loss, train_accuracy = train_one_epoch(model, train_dataloader, optimizer, criterion, device)
scheduler.step()
val_loss, val_accuracy = validate_one_epoch(model, valid_dataloader, criterion, device)
LOGGER.info(
"Epoch %d/%d - Training loss: %.4f, Training accuracy: %.4f - Validation loss: %.4f, Validation accuracy: %.4f",
epoch + 1, num_epochs, train_loss, train_accuracy, val_loss, val_accuracy
)
train_model(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, num_epochs, DEVICE)
2024-07-20 21:12:02,603 - __main__ - INFO - Epoch 1/6 - Training loss: 0.7196, Training accuracy: 0.6966 - Validation loss: 0.5611, Validation accuracy: 0.7841
2024-07-20 21:12:03,551 - __main__ - INFO - Epoch 2/6 - Training loss: 0.4483, Training accuracy: 0.7973 - Validation loss: 0.5083, Validation accuracy: 0.7974
2024-07-20 21:12:04,381 - __main__ - INFO - Epoch 3/6 - Training loss: 0.2279, Training accuracy: 0.9116 - Validation loss: 0.6285, Validation accuracy: 0.7930
2024-07-20 21:12:05,294 - __main__ - INFO - Epoch 4/6 - Training loss: 0.1285, Training accuracy: 0.9607 - Validation loss: 0.4113, Validation accuracy: 0.8722
2024-07-20 21:12:06,251 - __main__ - INFO - Epoch 5/6 - Training loss: 0.0344, Training accuracy: 0.9936 - Validation loss: 0.4243, Validation accuracy: 0.8943
2024-07-20 21:12:07,058 - __main__ - INFO - Epoch 6/6 - Training loss: 0.0220, Training accuracy: 0.9961 - Validation loss: 0.4257, Validation accuracy: 0.9031
The results are just decent, but you can see the model is learning. Tuning decoder-only models needs some experimentation.
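As a final check of the custom model, here is a sketch of single-sentence inference; predict_sentiment is a hypothetical helper, the example sentence is made up, and the predicted class depends on your training run.
@torch.no_grad()
def predict_sentiment(sentence: str) -> int:
    # Hypothetical helper: tokenize, build a causal mask, and return the argmax class.
    model.eval()
    input_ids = tokenizer.encode(
        text=sentence, max_length=MAX_LENGTH, truncation=TRUNCATION, return_tensors=RETURN_TENSORS
    ).long().to(DEVICE)
    causal_masks = construct_dummy_batch_causal_masks(batch_size=1, seq_len=input_ids.size(1)).to(DEVICE)
    pooled_logits = model(input_tokens=input_ids, causal_masks=causal_masks)
    return int(pooled_logits.argmax(dim=-1).item())

predict_sentiment("Operating profit rose compared to the previous year.")  # e.g. 2 (positive)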
Using HuggingFace#
class Batch(TypedDict):
sentence: List[str]
labels: List[int]
class TokenizedBatch(TypedDict):
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
labels: List[int]
def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
return tokenizer(batch["sentence"], **kwargs)
dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")
train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")
train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
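A quick note on padding_side = "left": since the classification head pools a last token, left padding keeps the final position of every row a real token, and GPT2ForSequenceClassification additionally uses pad_token_id to locate the last non-padding token (which is why we also set it on the model config below). A small sketch with two made-up sentences of different lengths:
padded_example = tokenizer(
    ["Profit fell.", "Operating profit for the period rose to EUR 13 million."],
    padding="longest",
    return_tensors="pt",
)
print(padded_example["input_ids"])       # pad ids (50256) sit on the left of the shorter row
print(padded_example["attention_mask"])  # ... and its mask starts with zeros on the left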
tokenized_train_dataset = train_dataset.map(
preprocess_function,
fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
batched=True,
num_proc=psutil.cpu_count(logical=True),
batch_size=1000,
).remove_columns(["sentence"])
tokenized_valid_dataset = valid_dataset.map(
preprocess_function,
fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
batched=True,
num_proc=psutil.cpu_count(logical=True),
batch_size=1000,
).remove_columns(["sentence"])
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
/opt/conda/lib/python3.10/site-packages/multiprocess/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()
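DataCollatorWithPadding re-pads whatever list of tokenized examples it receives into a rectangular batch at collation time. A small sketch with two made-up, unequal-length sentences:
raw_features = [
    tokenizer("Profit fell."),
    tokenizer("Operating profit for the period rose to EUR 13 million."),
]
collated = data_collator(raw_features)
pprint({key: value.shape for key, value in collated.items()})  # input_ids and attention_mask share one (2, longest) shape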
id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}
num_labels = len(id2label)
base_model = GPT2ForSequenceClassification.from_pretrained(
"gpt2",
id2label=id2label,
label2id=label2id,
num_labels=num_labels,
problem_type="single_label_classification",
)
base_model.config.pad_token_id = tokenizer.pad_token_id
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
total_trainable_parameters(base_model) / 1e6
124.442112
training_args = TrainingArguments(
do_eval=True,
do_predict=False,
do_train=True,
warmup_ratio=0.1,
learning_rate=2e-5,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=5,
report_to="none",
    output_dir='./artifacts',
overwrite_output_dir=True,
gradient_accumulation_steps=1,
logging_steps=25,
evaluation_strategy='epoch',
eval_steps=25,
save_strategy="epoch",
save_steps=25,
load_best_model_at_end=True,
metric_for_best_model='accuracy',
lr_scheduler_type='cosine',
weight_decay=0.01,
save_total_limit=2,
seed=42,
data_seed=42,
half_precision_backend="auto",
optim="adamw_torch",
label_smoothing_factor=0.0,
    max_grad_norm=1.0,
)
/opt/conda/lib/python3.10/site-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
warnings.warn(
def compute_metrics_for_single_label_classification(eval_prediction: EvalPrediction) -> Dict[str, float | List[float]]:
logits, labels = eval_prediction.predictions, eval_prediction.label_ids
probs = softmax(logits, axis=-1)
num_classes = logits.shape[1]
preds = np.argmax(probs, axis=1)
metrics = {
"eval_log_loss": log_loss(labels, probs),
"eval_accuracy": accuracy_score(labels, preds),
"eval_precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
"eval_recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
"eval_f1_score_macro": f1_score(labels, preds, average="macro", zero_division=0),
"eval_precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
"eval_recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
"eval_f1_score_micro": f1_score(labels, preds, average="micro", zero_division=0),
"eval_confusion_matrix": confusion_matrix(labels, preds).tolist(),
"eval_roc_auc": roc_auc_score(labels, probs, multi_class="ovr"),
# "eval_pr_auc": average_precision_score(labels, probs, average="macro"),
}
if num_classes == 2:
metrics["eval_brier_score"] = brier_score_loss(labels, probs[:, 1], pos_label=1)
else:
brier_scores = [brier_score_loss(labels == i, probs[:, i]) for i in range(num_classes)]
metrics["eval_brier_score"] = np.mean(brier_scores)
if num_classes > 2:
for class_index in range(num_classes):
fpr, tpr, _ = roc_curve(labels == class_index, probs[:, class_index])
roc_auc = auc(fpr, tpr)
precision, recall, _ = precision_recall_curve(labels == class_index, probs[:, class_index])
pr_auc = auc(recall, precision)
metrics[f"eval_roc_auc_class_{class_index}"] = roc_auc
metrics[f"eval_pr_auc_class_{class_index}"] = pr_auc
return metrics
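Before handing this function to the Trainer, a quick sanity check on made-up logits and labels (a sketch, not real model output):
dummy_eval_logits = np.array(
    [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 1.8], [1.2, 0.3, 0.4],
     [0.2, 2.1, 0.3], [0.1, 0.4, 1.1], [0.3, 1.0, 0.2], [0.4, 0.9, 0.6]]
)
dummy_eval_labels = np.array([0, 1, 2, 0, 1, 2, 1, 1])  # all three classes present
dummy_eval_prediction = EvalPrediction(predictions=dummy_eval_logits, label_ids=dummy_eval_labels)
pprint(compute_metrics_for_single_label_classification(dummy_eval_prediction))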
trainer = Trainer(
model=base_model,
args=training_args,
train_dataset=tokenized_train_dataset,
eval_dataset=tokenized_valid_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics_for_single_label_classification,
)
trainer.train()
Epoch | Training Loss | Validation Loss | Log Loss | Accuracy | Precision Macro | Recall Macro | F1 Score Macro | Precision Micro | Recall Micro | F1 Score Micro | Confusion Matrix | Roc Auc | Brier Score | Roc Auc Class 0 | Pr Auc Class 0 | Roc Auc Class 1 | Pr Auc Class 1 | Roc Auc Class 2 | Pr Auc Class 2 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 0.654800 | 0.575435 | 0.575435 | 0.740088 | 0.612724 | 0.596449 | 0.599999 | 0.740088 | 0.740088 | 0.740088 | [[9, 5, 16], [0, 125, 15], [12, 11, 34]] | 0.876099 | 0.113814 | 0.860237 | 0.379456 | 0.942775 | 0.963399 | 0.825284 | 0.608643 |
2 | 0.478300 | 0.219012 | 0.219012 | 0.929515 | 0.909526 | 0.901044 | 0.903625 | 0.929515 | 0.929515 | 0.929515 | [[27, 3, 0], [0, 137, 3], [5, 5, 47]] | 0.979735 | 0.038522 | 0.991540 | 0.935628 | 0.987603 | 0.991196 | 0.960062 | 0.933540 |
3 | 0.158000 | 0.221089 | 0.221089 | 0.938326 | 0.906788 | 0.942398 | 0.921905 | 0.938326 | 0.938326 | 0.938326 | [[30, 0, 0], [2, 133, 5], [4, 3, 50]] | 0.986326 | 0.031093 | 0.993739 | 0.949336 | 0.993103 | 0.995271 | 0.972136 | 0.952987 |
4 | 0.212800 | 0.227551 | 0.227551 | 0.942731 | 0.914779 | 0.944779 | 0.927814 | 0.942731 | 0.942731 | 0.942731 | [[30, 0, 0], [1, 134, 5], [4, 3, 50]] | 0.989035 | 0.033271 | 0.994755 | 0.961793 | 0.994745 | 0.996582 | 0.977606 | 0.959318 |
5 | 0.101200 | 0.212625 | 0.212625 | 0.942731 | 0.915999 | 0.930785 | 0.923062 | 0.942731 | 0.942731 | 0.942731 | [[28, 0, 2], [1, 135, 4], [3, 3, 51]] | 0.990026 | 0.029624 | 0.994755 | 0.962396 | 0.994828 | 0.996647 | 0.980495 | 0.961750 |
TrainOutput(global_step=1275, training_loss=0.3979669969222125, metrics={'train_runtime': 125.2594, 'train_samples_per_second': 81.311, 'train_steps_per_second': 10.179, 'total_flos': 332666429276160.0, 'train_loss': 0.3979669969222125, 'epoch': 5.0})
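With training done, here is a sketch of how one might fetch predictions on the validation split from the trainer; the exact predicted labels will differ from run to run.
predictions = trainer.predict(tokenized_valid_dataset)
predicted_label_ids = np.argmax(predictions.predictions, axis=-1)
pprint([base_model.config.id2label[int(label_id)] for label_id in predicted_label_ids[:5]])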