How To Do Teacher-Student Knowledge Distillation?#
Dependencies#
pip install -U omniverse==0.0.63
P100 16GB
# %pip install omniverse==0.0.63
from __future__ import annotations
import logging
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Tuple, TypedDict, overload, cast
import gc
import numpy as np
import pandas as pd
import psutil
import torch
from datasets import load_dataset
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
accuracy_score,
auc,
average_precision_score,
brier_score_loss,
confusion_matrix,
f1_score,
log_loss,
precision_recall_curve,
precision_score,
recall_score,
roc_auc_score,
roc_curve,
)
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm # Use notebook version for better UI in notebooks
from transformers import (
DataCollatorWithPadding,
EvalPrediction,
GPT2ForSequenceClassification,
GPT2Tokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
Trainer,
TrainingArguments,
)
from omnivault.transformer.config.decoder import (
AddNormConfig,
DecoderBlockConfig,
DecoderConfig,
MultiHeadedAttentionConfig,
PositionwiseFeedForwardConfig,
)
from omnivault.transformer.modules.attention.core import MultiHeadedAttention, ScaledDotProductAttention
from omnivault.transformer.modules.layers.addnorm import AddNorm
from omnivault.transformer.modules.layers.mlp import PositionwiseFeedForward
from omnivault.utils.reproducibility.seed import seed_all
from omnivault.utils.torch_utils.model_utils import total_trainable_parameters
Setting Up#
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
2024
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MAX_LENGTH = 64
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"
Dataset#
class Batch(TypedDict):
sentence: List[str]
labels: List[int]
class TokenizedBatch(TypedDict):
input_ids: List[int]
attention_mask: List[int]
labels: List[int]
def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
return cast(TokenizedBatch, tokenizer(batch["sentence"], **kwargs))
dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")
train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")
train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]
class FinancialPhraseDataset(torch.utils.data.Dataset):
def __init__(
self,
sentences: List[str],
labels: List[int],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
**tokenizer_kwargs: Any,
) -> None:
self.sentences = sentences
self.labels = labels
self.tokenizer = tokenizer
self.tokenizer_kwargs = tokenizer_kwargs or {
"max_length": 64,
"padding": "longest",
"truncation": True,
"return_tensors": "pt",
}
def __len__(self) -> int:
return len(self.sentences)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
sentence = self.sentences[idx]
label = self.labels[idx]
encoding = self.tokenizer(sentence, **self.tokenizer_kwargs)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"labels": torch.tensor(label, dtype=torch.long),
}
Note again that the causal mask is handled internally by Hugging Face models, so we do not need to create it ourselves. Since each sentence is tokenized individually without padding, the attention masks here are all 1s; the padding zeros are added later by the collator (as the quick check after the tokenizer cell below confirms).
from transformers import DebertaV2Tokenizer
tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-xsmall")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# Set padding side (usually 'right' for DeBERTa)
tokenizer.padding_side = "right"
/opt/conda/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884
warnings.warn(
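As a quick check of the note above (a minimal sketch, assuming the `tokenizer` and the constants defined earlier in this notebook, with an arbitrary example sentence), tokenizing a single sentence yields an attention mask of all 1s; padding zeros only appear once the collator pads a batch:

# Minimal sanity check: a single un-padded sentence has only real tokens,
# so its attention mask is all 1s. The example sentence is arbitrary.
example = tokenizer(
    "Operating profit increased in the first quarter.",
    max_length=MAX_LENGTH,
    padding=PADDING,  # "longest" has no effect for a single sentence
    truncation=TRUNCATION,
    return_tensors=RETURN_TENSORS,
)
assert example["attention_mask"].eq(1).all().item()
print(example["input_ids"].shape, example["attention_mask"])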
train_dataset = FinancialPhraseDataset(
sentences=train_dataset["sentence"],
labels=train_dataset["labels"],
tokenizer=tokenizer,
)
valid_dataset = FinancialPhraseDataset(
sentences=valid_dataset["sentence"],
labels=valid_dataset["labels"],
tokenizer=tokenizer,
)
class FinancialPhraseCollator:
def __init__(
self,
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
max_length: int,
padding_side: str = "right",
) -> None:
self.tokenizer = tokenizer
self.max_length = max_length
self.padding_side = padding_side
def pad_sequence(self, sequence: torch.Tensor, target_length: int) -> torch.Tensor:
pad_length = target_length - sequence.size(0)
if pad_length <= 0:
return sequence[:target_length]
pad_tensor = torch.full((pad_length,), self.tokenizer.pad_token_id, dtype=sequence.dtype)
if self.padding_side == "right":
return torch.cat([sequence, pad_tensor])
else: # padding_side == "left"
return torch.cat([pad_tensor, sequence])
def create_attention_mask(self, sequence: torch.Tensor, target_length: int) -> torch.Tensor:
attention_mask = torch.ones(target_length, dtype=torch.long)
if self.padding_side == "right":
attention_mask[sequence.size(0) :] = 0
else: # padding_side == "left"
attention_mask[: target_length - sequence.size(0)] = 0
return attention_mask
def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
max_length = max(len(item["input_ids"]) for item in batch)
max_length = min(max_length, self.max_length)
input_ids = torch.stack([self.pad_sequence(item["input_ids"], max_length) for item in batch])
attention_mask = torch.stack([self.create_attention_mask(item["input_ids"], max_length) for item in batch])
labels = torch.stack([item["labels"] for item in batch])
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
collator = FinancialPhraseCollator(
tokenizer=tokenizer,
max_length=MAX_LENGTH,
padding_side=tokenizer.padding_side
)
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
train_dataloader = DataLoader(
train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
collate_fn=collator
)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=BATCH_SIZE,
shuffle=False,
collate_fn=collator
)
for batch in train_dataloader:
pprint(batch)
break
{
│   'input_ids': tensor([[    1, 10029,   275,  ...,     0,     0,     0],
│   │   [    1, 16246,  8156,  ...,     2,     0,     0],
│   │   [    1,   486,   262,  ...,     0,     0,     0],
│   │   ...,
│   │   [    1,  5670,  5174,  ...,     0,     0,     0],
│   │   [    1, 11764, 48850,  ...,     0,     0,     0],
│   │   [    1,   585,  2784,  ...,     0,     0,     0]]),
│   'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 1, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   ...,
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0]]),
│   'labels': tensor([0, 0, 1, 1, 1, 2, 1, 2, 1, 1, 0, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 0,
│   │   1, 1, 0, 2, 1, 1, 1, 1])
}
from __future__ import annotations
from typing import Tuple
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DebertaV2ForSequenceClassification, get_linear_schedule_with_warmup
from transformers.modeling_outputs import SequenceClassifierOutput
def train_one_epoch(
model: nn.Module,
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR,
device: torch.device,
) -> Tuple[float, float]:
model.train()
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
for batch in tqdm(dataloader, desc="Training", leave=False):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
optimizer.zero_grad()
outputs: SequenceClassifierOutput = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
logits = outputs.logits
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.item()
_, preds = torch.max(logits, dim=1)
correct_predictions += torch.sum(preds == labels).item()
total_predictions += labels.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct_predictions / total_predictions
return avg_loss, accuracy
def valid_one_epoch(model: nn.Module, dataloader: DataLoader, device: torch.device) -> Tuple[float, float]:
model.eval()
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
with torch.no_grad():
for batch in tqdm(dataloader, desc="Validation", leave=False):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
outputs: SequenceClassifierOutput = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
logits = outputs.logits
total_loss += loss.item()
_, preds = torch.max(logits, dim=1)
correct_predictions += torch.sum(preds == labels).item()
total_predictions += labels.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct_predictions / total_predictions
return avg_loss, accuracy
def train_model(
model: nn.Module,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR,
num_epochs: int,
device: torch.device,
) -> nn.Module:
for epoch in range(num_epochs):
train_loss, train_accuracy = train_one_epoch(model, train_dataloader, optimizer, scheduler, device)
val_loss, val_accuracy = valid_one_epoch(model, valid_dataloader, device)
print(f"Epoch {epoch+1}/{num_epochs}")
print(f"Training loss: {train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")
print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")
print()
return model
seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)
NUM_LABELS = 3
NUM_EPOCHS = 2
# Create student model (DeBERTa-v3-xsmall)
student_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-xsmall", num_labels=NUM_LABELS)
student_model.to(DEVICE)
# Cross Entropy Loss
ce_loss = nn.CrossEntropyLoss()
# Train student model naively (without distillation)
print("Training student model naively...")
student_optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
student_scheduler = get_linear_schedule_with_warmup(
student_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
_ = train_model(student_model, train_dataloader, valid_dataloader, student_optimizer, student_scheduler, NUM_EPOCHS, DEVICE)
gc.collect()
# Delete unnecessary objects
del student_model
del student_optimizer
del student_scheduler
del _
# Empty CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-xsmall and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training student model naively...
Epoch 1/2
Training loss: 0.8688, Training accuracy: 0.5886
Validation loss: 0.7205, Validation accuracy: 0.6167
Epoch 2/2
Training loss: 0.6349, Training accuracy: 0.6937
Validation loss: 0.6428, Validation accuracy: 0.8018
seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)
# Create teacher model (DeBERTa-v3-large)
teacher_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-large", num_labels=NUM_LABELS)
teacher_model.to(DEVICE)
# Train teacher model
print("Training teacher model...")
teacher_optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=2e-5)
teacher_scheduler = get_linear_schedule_with_warmup(
teacher_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
teacher_model = train_model(teacher_model, train_dataloader, valid_dataloader, teacher_optimizer, teacher_scheduler, NUM_EPOCHS, DEVICE)
gc.collect()
# Delete unnecessary objects
del teacher_optimizer
del teacher_scheduler
# Empty CUDA cache
if torch.cuda.is_available():
torch.cuda.empty_cache()
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training teacher model...
Epoch 1/2
Training loss: 0.5893, Training accuracy: 0.7320
Validation loss: 0.3581, Validation accuracy: 0.8238
Epoch 2/2
Training loss: 0.2083, Training accuracy: 0.9293
Validation loss: 0.1159, Validation accuracy: 0.9692
def distillation_loss(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
temperature: float = 2.0,
alpha: float = 0.5,
) -> torch.Tensor:
"""
Compute the knowledge distillation loss.
This loss combines the soft targets from the teacher model with the
hard targets from the true labels.
Parameters
----------
student_logits : torch.Tensor
Raw, unnormalized output scores from the student model.
teacher_logits : torch.Tensor
Raw, unnormalized output scores from the teacher model.
labels : torch.Tensor
True class labels for the input data.
temperature : float, optional
Controls the softness of probability distributions (default 2.0).
alpha : float, optional
Balances contribution of soft and hard losses (default 0.5).
Returns
-------
torch.Tensor
The computed distillation loss.
Variables
---------
soft_targets : torch.Tensor
Teacher's predictions converted to a smoothed probability
distribution.
soft_prob : torch.Tensor
Log of the student's smoothed probability predictions.
soft_loss : torch.Tensor
KL-Divergence between soft probabilities and soft targets,
scaled by temperature^2.
hard_loss : torch.Tensor
Cross-entropy loss between student logits and true labels.
Notes
-----
The loss is computed as:
L = α * L_soft + (1 - α) * L_hard
Where:
- L_soft is the KL divergence between soft student and teacher probs
- L_hard is the cross-entropy between student logits and true labels
"""
soft_targets = (teacher_logits / temperature).softmax(dim=-1)
soft_prob = (student_logits / temperature).log_softmax(dim=-1)
soft_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (temperature**2)
hard_loss = ce_loss(student_logits, labels)
return alpha * soft_loss + (1 - alpha) * hard_loss
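# Quick sanity check of `distillation_loss` above (a minimal sketch with
# made-up logits, not model outputs): with alpha=0.0 the loss should reduce
# to the plain cross-entropy `ce_loss`, and with alpha=1.0 only the
# temperature-scaled KL term remains.
_student_logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
_teacher_logits = torch.tensor([[3.0, 0.2, -2.0], [-0.5, 2.5, 0.0]])
_labels = torch.tensor([0, 1])
_hard_only = distillation_loss(_student_logits, _teacher_logits, _labels, temperature=2.0, alpha=0.0)
assert torch.isclose(_hard_only, ce_loss(_student_logits, _labels))
_soft_only = distillation_loss(_student_logits, _teacher_logits, _labels, temperature=2.0, alpha=1.0)
print(f"hard-only loss: {_hard_only.item():.4f}, soft-only loss: {_soft_only.item():.4f}")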
def train_student_one_epoch(
student_model: nn.Module,
teacher_model: nn.Module,
dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR,
device: torch.device,
*,
temperature: float = 2.0,
alpha: float = 0.5,
) -> Tuple[float, float]:
student_model.train() # set student model to train mode
teacher_model.eval() # set teacher model to eval mode
total_loss = 0.0
correct_predictions = 0
total_predictions = 0
for batch in tqdm(dataloader, desc="Training", leave=False):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
optimizer.zero_grad()
with torch.no_grad():
teacher_outputs: SequenceClassifierOutput = teacher_model(
input_ids=input_ids, attention_mask=attention_mask
)
student_outputs: SequenceClassifierOutput = student_model(input_ids=input_ids, attention_mask=attention_mask)
loss = distillation_loss(
student_outputs.logits, teacher_outputs.logits, labels, temperature=temperature, alpha=alpha
)
loss.backward()
optimizer.step()
scheduler.step()
total_loss += loss.item()
_, preds = torch.max(student_outputs.logits, dim=1)
correct_predictions += torch.sum(preds == labels).item()
total_predictions += labels.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct_predictions / total_predictions
return avg_loss, accuracy
def train_student_model(
student_model: nn.Module,
teacher_model: nn.Module,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR,
num_epochs: int,
device: torch.device,
*,
temperature: float = 2.0,
alpha: float = 0.5,
) -> None:
for epoch in range(num_epochs):
train_loss, train_accuracy = train_student_one_epoch(
student_model,
teacher_model,
train_dataloader,
optimizer,
scheduler,
device,
temperature=temperature,
alpha=alpha,
)
val_loss, val_accuracy = valid_one_epoch(student_model, valid_dataloader, device)
print(f"Epoch {epoch+1}/{num_epochs}")
print(f"Training loss: {train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")
print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")
print()
| Term | Definition | Explanation |
|---|---|---|
| Student Logits (\(z_s\)) | Raw, unnormalized scores from the student model | \(z_s = f_s(x)\), where \(f_s\) is the student model and \(x\) is the input |
| Teacher Logits (\(z_t\)) | Raw, unnormalized scores from the teacher model | \(z_t = f_t(x)\), where \(f_t\) is the teacher model; its parameter space \(\theta_t\) is larger than the student's parameter space \(\theta_s\) |
| Labels (\(y\)) | True class labels for the input data | \(y \in \{1, \ldots, K\}\) for \(K\) classes |
| Temperature (\(T\)) | Hyperparameter controlling distribution softness | \(T \in \mathbb{R}^+\), typically \(T > 1\). Higher temperatures produce softer probability distributions, emphasizing the relative differences between class probabilities |
| Alpha (\(\alpha\)) | Balancing hyperparameter for soft and hard losses | \(\alpha \in [0, 1]\); \(\alpha\) is the weight given to the distillation (soft) loss and \(1 - \alpha\) the weight given to the hard loss |
| Soft Targets (\(p_t\)) | \(p_t = \text{softmax}(z_t / T)\) | \(p_t^{(i)} = \frac{\exp\left(z_t^{(i)}/T\right)}{\sum_j \exp\left(z_t^{(j)}/T\right)}\): the teacher's predictions converted to a probability distribution and smoothed by the temperature |
| Soft Probabilities (\(p_s\)) | \(p_s = \log\text{softmax}(z_s / T)\) | \(p_s^{(i)} = \log\frac{\exp\left(z_s^{(i)}/T\right)}{\sum_j \exp\left(z_s^{(j)}/T\right)}\): the log of the student's smoothed probability predictions |
| Soft Loss (\(L_\text{soft}\)) | \(L_\text{soft} = T^2 \cdot \text{KL}(p_t \Vert p_s)\) | Measures how well the student's smoothed predictions match the teacher's smoothed predictions |
| Hard Loss (\(L_\text{hard}\)) | \(L_\text{hard} = \text{CE}(z_s, y)\) | The standard supervised learning loss, measuring how well the student's predictions match the true labels |
| Distillation Loss (\(L_\text{distill}\)) | \(L_\text{distill} = \alpha L_\text{soft} + (1-\alpha) L_\text{hard}\) | The combined loss the student model optimizes, balancing mimicking the teacher against predicting the true labels |
| KL-Divergence | \(\text{KL}(p \Vert q) = \sum_i p_i \log\frac{p_i}{q_i}\) | Measures the difference between two probability distributions; here it quantifies how much the student's prediction distribution differs from the teacher's |
| Cross-Entropy Loss | \(\text{CE}(z, y) = -\sum_i y_i \log(\text{softmax}(z)_i)\) | Standard loss for multi-class classification |
Note: In the table, \(i\) and \(j\) are used as indices for classes, \(\exp\) denotes the exponential function, and \(\log\) is the natural logarithm. The softmax function is defined as \(\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\).
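To make the effect of the temperature concrete, here is a minimal sketch with made-up teacher logits (not outputs of the models above) showing how \(\text{softmax}(z_t / T)\) flattens for \(T > 1\) and sharpens for \(T < 1\):

# Illustrative only: made-up logits for a 3-class problem.
_z_t = torch.tensor([4.0, 1.0, 0.5])
for _T in (0.5, 1.0, 2.0, 5.0):
    _p_t = (_z_t / _T).softmax(dim=-1)
    print(f"T={_T}: soft targets = {_p_t.numpy().round(3)}")

Note that the distillation run below uses `temperature=0.5` and `alpha=0.8`, i.e. sharpened (rather than softened) teacher targets, with most of the weight on the soft loss.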
seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)
student_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-xsmall", num_labels=NUM_LABELS)
student_model.to(DEVICE)
# Train student model
student_optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
student_scheduler = get_linear_schedule_with_warmup(
student_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
train_student_model(
student_model,
teacher_model,
train_dataloader,
valid_dataloader,
student_optimizer,
student_scheduler,
NUM_EPOCHS,
DEVICE,
temperature=0.5,
alpha=0.8
)
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-xsmall and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/2
Training loss: 0.3327, Training accuracy: 0.5925
Validation loss: 0.7193, Validation accuracy: 0.6167
Epoch 2/2
Training loss: 0.2407, Training accuracy: 0.7472
Validation loss: 0.6435, Validation accuracy: 0.8150
We see an improvement over the vanilla (non-distilled) student model, whose results from earlier are reproduced below for comparison:
Epoch 1/2
Training loss: 0.8688, Training accuracy: 0.5886
Validation loss: 0.7205, Validation accuracy: 0.6167
Epoch 2/2
Training loss: 0.6349, Training accuracy: 0.6937
Validation loss: 0.6428, Validation accuracy: 0.8018
References And Further Readings#
There are many variants of knowledge distillation. For example, you can also distill on unlabeled (unsupervised) data, since the soft loss only needs the teacher's predictions and not gold labels; see the Hugging Face SetFit example linked below, and the sketch that follows.
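As a rough sketch of that idea (assuming the `tokenizer`, `teacher_model`, `student_model`, and constants defined earlier in this notebook; the helper name `distill_unlabeled_step` is purely illustrative), the hard cross-entropy term is simply dropped and only the teacher's soft targets drive the student:

def distill_unlabeled_step(sentences: List[str], temperature: float = 2.0) -> torch.Tensor:
    """Compute the soft (KL) distillation loss on unlabeled sentences only."""
    encoding = tokenizer(
        sentences,
        max_length=MAX_LENGTH,
        padding=PADDING,
        truncation=TRUNCATION,
        return_tensors=RETURN_TENSORS,
    ).to(DEVICE)
    with torch.no_grad():
        teacher_logits = teacher_model(**encoding).logits
    student_logits = student_model(**encoding).logits
    soft_targets = (teacher_logits / temperature).softmax(dim=-1)
    soft_prob = (student_logits / temperature).log_softmax(dim=-1)
    return nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * temperature**2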
G. Hinton, O. Vinyals, and J. Dean, “Distilling the Knowledge in a Neural Network,” arXiv preprint arXiv:1503.02531, 2015. [Online]. Available: https://arxiv.org/abs/1503.02531
https://huggingface.co/docs/setfit/en/how_to/knowledge_distillation
https://douglasorr.github.io/2021-10-training-objectives/2-teacher/article.html
https://keras.io/examples/keras_recipes/better_knowledge_distillation/
https://pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html