How To Do Teacher-Student Knowledge Distillation?#


Dependencies#

pip install -U omniverse==0.0.63
Hardware: P100 16GB GPU
# %pip install omniverse==0.0.63
from __future__ import annotations

import logging
from collections import Counter, OrderedDict
from typing import Any, Dict, List, Tuple, TypedDict, overload, cast

import gc
import numpy as np
import pandas as pd
import psutil
import torch
from datasets import load_dataset
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    brier_score_loss,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm  # Use notebook version for better UI in notebooks
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    GPT2ForSequenceClassification,
    GPT2Tokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments,
)

from omnivault.transformer.config.decoder import (
    AddNormConfig,
    DecoderBlockConfig,
    DecoderConfig,
    MultiHeadedAttentionConfig,
    PositionwiseFeedForwardConfig,
)
from omnivault.transformer.modules.attention.core import MultiHeadedAttention, ScaledDotProductAttention
from omnivault.transformer.modules.layers.addnorm import AddNorm
from omnivault.transformer.modules.layers.mlp import PositionwiseFeedForward
from omnivault.utils.reproducibility.seed import seed_all
from omnivault.utils.torch_utils.model_utils import total_trainable_parameters

Setting Up#

seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)
2024
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
LOGGER.addHandler(handler)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 64
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"

Dataset#

class Batch(TypedDict):
    sentence: List[str]
    labels: List[int]


class TokenizedBatch(TypedDict):
    input_ids: List[List[int]]
    attention_mask: List[List[int]]
    labels: List[int]


def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
    return cast(TokenizedBatch, tokenizer(batch["sentence"], **kwargs))


dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")

train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")

train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]
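
Although we tokenize on the fly with a custom Dataset and collator below, preprocess_function above follows the signature expected by datasets.map. A minimal sketch of that alternative path is shown here for completeness; it assumes the DeBERTa tokenizer loaded further down and is not used in the manual Dataset/collator path that follows.

# Hypothetical alternative: tokenize the whole split up front with .map
# (assumes `tokenizer` is the DebertaV2Tokenizer loaded later in this section).
tokenized_train = train_dataset.map(
    preprocess_function,
    batched=True,
    fn_kwargs={"max_length": MAX_LENGTH, "truncation": TRUNCATION},
)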
class FinancialPhraseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        sentences: List[str],
        labels: List[int],
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        **tokenizer_kwargs: Any,
    ) -> None:
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.tokenizer_kwargs = tokenizer_kwargs or {
            "max_length": MAX_LENGTH,
            "padding": PADDING,
            "truncation": TRUNCATION,
            "return_tensors": RETURN_TENSORS,
        }

    def __len__(self) -> int:
        return len(self.sentences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sentence = self.sentences[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(sentence, **self.tokenizer_kwargs)

        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }

Note again that the causal mask is handled internally by Hugging Face models, so we do not need to create it ourselves; the attention mask here simply marks real tokens with 1s and padded positions with 0s.

from transformers import DebertaV2Tokenizer

tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-xsmall")

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Set padding side (usually 'right' for DeBERTa)
tokenizer.padding_side = "right"
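To make the masking concrete, here is a quick check (the sentence is a made-up example) showing that with right padding the attention mask is 1 for real tokens and 0 for padded positions:

encoded = tokenizer(
    "Operating profit rose in the third quarter.",  # made-up example sentence
    max_length=16,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)
print(encoded["input_ids"].shape)    # torch.Size([1, 16])
print(encoded["attention_mask"])     # 1s for real tokens, 0s for the padded tail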
train_dataset = FinancialPhraseDataset(
    sentences=train_dataset["sentence"],
    labels=train_dataset["labels"],
    tokenizer=tokenizer,
)

valid_dataset = FinancialPhraseDataset(
    sentences=valid_dataset["sentence"],
    labels=valid_dataset["labels"],
    tokenizer=tokenizer,
)
class FinancialPhraseCollator:
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        max_length: int,
        padding_side: str = "right",
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.padding_side = padding_side

    def pad_sequence(self, sequence: torch.Tensor, target_length: int) -> torch.Tensor:
        pad_length = target_length - sequence.size(0)
        if pad_length <= 0:
            return sequence[:target_length]

        pad_tensor = torch.full((pad_length,), self.tokenizer.pad_token_id, dtype=sequence.dtype)

        if self.padding_side == "right":
            return torch.cat([sequence, pad_tensor])
        else:  # padding_side == "left"
            return torch.cat([pad_tensor, sequence])

    def create_attention_mask(self, sequence: torch.Tensor, target_length: int) -> torch.Tensor:
        attention_mask = torch.ones(target_length, dtype=torch.long)
        if self.padding_side == "right":
            attention_mask[sequence.size(0) :] = 0
        else:  # padding_side == "left"
            attention_mask[: target_length - sequence.size(0)] = 0
        return attention_mask

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        max_length = max(len(item["input_ids"]) for item in batch)
        max_length = min(max_length, self.max_length)

        input_ids = torch.stack([self.pad_sequence(item["input_ids"], max_length) for item in batch])
        attention_mask = torch.stack([self.create_attention_mask(item["input_ids"], max_length) for item in batch])
        labels = torch.stack([item["labels"] for item in batch])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
collator = FinancialPhraseCollator(
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
    padding_side=tokenizer.padding_side
)
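As a quick sanity check of the collator (a small sketch, not part of the training path), we can collate two dataset items by hand and inspect the resulting shapes:

# Collate two items manually and print the padded tensor shapes.
toy_batch = collator([train_dataset[0], train_dataset[1]])
print({name: tensor.shape for name, tensor in toy_batch.items()})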
seed_all(seed=2024, seed_torch=True, set_torch_deterministic=False)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collator
)

valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collator
)

for batch in train_dataloader:
    pprint(batch)
    break
{
'input_ids': tensor([[    1, 10029,   275,  ...,     0,     0,     0],
│   │   [    1, 16246,  8156,  ...,     2,     0,     0],
│   │   [    1,   486,   262,  ...,     0,     0,     0],
│   │   ...,
│   │   [    1,  5670,  5174,  ...,     0,     0,     0],
│   │   [    1, 11764, 48850,  ...,     0,     0,     0],
│   │   [    1,   585,  2784,  ...,     0,     0,     0]]),
'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 1, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   ...,
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0],
│   │   [1, 1, 1,  ..., 0, 0, 0]]),
'labels': tensor([0, 0, 1, 1, 1, 2, 1, 2, 1, 1, 0, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 1, 2, 0,
│   │   1, 1, 0, 2, 1, 1, 1, 1])
}
from __future__ import annotations

from typing import Tuple

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm  # keep the notebook progress bar consistent with the earlier import
from transformers import DebertaV2ForSequenceClassification, get_linear_schedule_with_warmup
from transformers.modeling_outputs import SequenceClassifierOutput


def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
) -> Tuple[float, float]:
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs: SequenceClassifierOutput = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        _, preds = torch.max(logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()
        total_predictions += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy


def valid_one_epoch(model: nn.Module, dataloader: DataLoader, device: torch.device) -> Tuple[float, float]:
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation", leave=False):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs: SequenceClassifierOutput = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            total_loss += loss.item()
            _, preds = torch.max(logits, dim=1)
            correct_predictions += torch.sum(preds == labels).item()
            total_predictions += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy


def train_model(
    model: nn.Module,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    num_epochs: int,
    device: torch.device,
) -> nn.Module:
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_dataloader, optimizer, scheduler, device)
        val_loss, val_accuracy = valid_one_epoch(model, valid_dataloader, device)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Training loss: {train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")
        print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")
        print()

    return model
seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)

NUM_LABELS = 3
NUM_EPOCHS = 2

# Create student model (DeBERTa-v3-xsmall)
student_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-xsmall", num_labels=NUM_LABELS)
student_model.to(DEVICE)

# Cross-entropy loss for the hard-label term (reused later in the distillation loss)
ce_loss = nn.CrossEntropyLoss()

# Train student model naively (without distillation)
print("Training student model naively...")
student_optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
student_scheduler = get_linear_schedule_with_warmup(
    student_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
_ = train_model(student_model, train_dataloader, valid_dataloader, student_optimizer, student_scheduler, NUM_EPOCHS, DEVICE)

gc.collect()

# Delete unnecessary objects
del student_model
del student_optimizer
del student_scheduler
del _

# Empty CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-xsmall and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training student model naively...
                                                         
Epoch 1/2
Training loss: 0.8688, Training accuracy: 0.5886
Validation loss: 0.7205, Validation accuracy: 0.6167
                                                         
Epoch 2/2
Training loss: 0.6349, Training accuracy: 0.6937
Validation loss: 0.6428, Validation accuracy: 0.8018
seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)

# Create teacher model (DeBERTa-v3-large)
teacher_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-large", num_labels=NUM_LABELS)
teacher_model.to(DEVICE)

# Train teacher model
print("Training teacher model...")
teacher_optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=2e-5)
teacher_scheduler = get_linear_schedule_with_warmup(
    teacher_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
teacher_model = train_model(teacher_model, train_dataloader, valid_dataloader, teacher_optimizer, teacher_scheduler, NUM_EPOCHS, DEVICE)

gc.collect()

# Delete unnecessary objects
del teacher_optimizer
del teacher_scheduler

# Empty CUDA cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Training teacher model...
                                                         
Epoch 1/2
Training loss: 0.5893, Training accuracy: 0.7320
Validation loss: 0.3581, Validation accuracy: 0.8238
                                                         
Epoch 2/2
Training loss: 0.2083, Training accuracy: 0.9293
Validation loss: 0.1159, Validation accuracy: 0.9692
def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> torch.Tensor:
    """
    Compute the knowledge distillation loss.

    This loss combines the soft targets from the teacher model with the
    hard targets from the true labels.

    Parameters
    ----------
    student_logits : torch.Tensor
        Raw, unnormalized output scores from the student model.
    teacher_logits : torch.Tensor
        Raw, unnormalized output scores from the teacher model.
    labels : torch.Tensor
        True class labels for the input data.
    temperature : float, optional
        Controls the softness of probability distributions (default 2.0).
    alpha : float, optional
        Balances contribution of soft and hard losses (default 0.5).

    Returns
    -------
    torch.Tensor
        The computed distillation loss.

    Variables
    ---------
    soft_targets : torch.Tensor
        Teacher's predictions converted to a smoothed probability
        distribution.
    soft_prob : torch.Tensor
        Log of the student's smoothed probability predictions.
    soft_loss : torch.Tensor
        KL-Divergence between soft probabilities and soft targets,
        scaled by temperature^2.
    hard_loss : torch.Tensor
        Cross-entropy loss between student logits and true labels.

    Notes
    -----
    The loss is computed as:
    L = α * L_soft + (1 - α) * L_hard

    Where:
    - L_soft is the KL divergence between soft student and teacher probs
    - L_hard is the cross-entropy between student logits and true labels
    """
    soft_targets = (teacher_logits / temperature).softmax(dim=-1)
    soft_prob = (student_logits / temperature).log_softmax(dim=-1)
    soft_loss = nn.KLDivLoss(reduction="batchmean")(soft_prob, soft_targets) * (temperature**2)
    hard_loss = ce_loss(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss


def train_student_one_epoch(
    student_model: nn.Module,
    teacher_model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    device: torch.device,
    *,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> Tuple[float, float]:
    student_model.train()  # set student model to train mode
    teacher_model.eval()  # set teacher model to eval mode
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for batch in tqdm(dataloader, desc="Training", leave=False):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        with torch.no_grad():
            teacher_outputs: SequenceClassifierOutput = teacher_model(
                input_ids=input_ids, attention_mask=attention_mask
            )
        student_outputs: SequenceClassifierOutput = student_model(input_ids=input_ids, attention_mask=attention_mask)

        loss = distillation_loss(
            student_outputs.logits, teacher_outputs.logits, labels, temperature=temperature, alpha=alpha
        )
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        _, preds = torch.max(student_outputs.logits, dim=1)
        correct_predictions += torch.sum(preds == labels).item()
        total_predictions += labels.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy


def train_student_model(
    student_model: nn.Module,
    teacher_model: nn.Module,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.LambdaLR,
    num_epochs: int,
    device: torch.device,
    *,
    temperature: float = 2.0,
    alpha: float = 0.5,
) -> None:
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train_student_one_epoch(
            student_model,
            teacher_model,
            train_dataloader,
            optimizer,
            scheduler,
            device,
            temperature=temperature,
            alpha=alpha,
        )
        val_loss, val_accuracy = valid_one_epoch(student_model, valid_dataloader, device)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Training loss: {train_loss:.4f}, Training accuracy: {train_accuracy:.4f}")
        print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_accuracy:.4f}")
        print()

| Term | Definition | Explanation |
|------|------------|-------------|
| Student Logits (\(z_s\)) | Raw unnormalized scores from the student model | \(z_s = f_s(x)\) where \(f_s\) is the student model and \(x\) is the input |
| Teacher Logits (\(z_t\)) | Raw unnormalized scores from the teacher model | \(z_t = f_t(x)\) where \(f_t\) is the teacher model; its parameter space \(\theta_t\) is larger than the student's parameter space \(\theta_s\) |
| Labels (\(y\)) | True class labels for the input data | \(y \in \{1, \ldots, K\}\) for \(K\) classes |
| Temperature (\(T\)) | Hyperparameter controlling distribution softness | \(T \in \mathbb{R}^+\), typically \(T > 1\). Higher temperatures produce softer probability distributions, emphasizing the relative differences between class probabilities. |
| Alpha (\(\alpha\)) | Balancing hyperparameter for soft and hard losses | \(\alpha \in [0, 1]\); \(\alpha\) weights the distillation (soft) loss and \(1 - \alpha\) weights the hard loss |
| Soft Targets (\(p_t\)) | \(p_t = \text{softmax}(z_t / T)\) | \(p_t^{(i)} = \frac{\exp\left(z_t^{(i)}/T\right)}{\sum_j \exp\left(z_t^{(j)}/T\right)}\): the teacher's predictions converted to a probability distribution and smoothed by the temperature |
| Soft Probabilities (\(p_s\)) | \(p_s = \log\text{softmax}(z_s / T)\) | \(p_s^{(i)} = \log\frac{\exp\left(z_s^{(i)}/T\right)}{\sum_j \exp\left(z_s^{(j)}/T\right)}\) |
| Soft Loss (\(L_\text{soft}\)) | \(L_\text{soft} = T^2 \cdot \text{KL}(p_t \Vert p_s)\) | Measures how well the student's smoothed predictions match the teacher's smoothed predictions |
| Hard Loss (\(L_\text{hard}\)) | \(L_\text{hard} = \text{CE}(z_s, y)\) | The standard supervised loss, measuring how well the student's predictions match the true labels |
| Distillation Loss (\(L_\text{distill}\)) | \(L_\text{distill} = \alpha L_\text{soft} + (1-\alpha) L_\text{hard}\) | The combined objective the student optimizes, balancing mimicking the teacher against predicting the true labels |
| KL-Divergence | \(\text{KL}(p \Vert q) = \sum_i p_i \log\frac{p_i}{q_i}\) | Measures the difference between two probability distributions; here, how much the student's prediction distribution differs from the teacher's |
| Cross-Entropy Loss | \(\text{CE}(z, y) = -\sum_i y_i \log(\text{softmax}(z)_i)\) | Standard loss for multi-class classification |

Note: In the table, \(i\) and \(j\) are used as indices for classes, \(\exp\) denotes the exponential function, and \(\log\) is the natural logarithm. The softmax function is defined as \(\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}\).
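
Before wiring the loss into training, a tiny numeric sanity check can confirm that distillation_loss behaves as expected. The logits below are made-up, illustrative values:

# Illustrative sanity check of distillation_loss on made-up logits.
dummy_student_logits = torch.tensor([[2.0, 0.5, -1.0]])
dummy_teacher_logits = torch.tensor([[3.0, 1.0, -2.0]])
dummy_labels = torch.tensor([0])
loss = distillation_loss(dummy_student_logits, dummy_teacher_logits, dummy_labels, temperature=2.0, alpha=0.5)
print(loss)  # a scalar tensor combining the soft (KL) and hard (CE) terms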

seed_all(seed=42, seed_torch=True, set_torch_deterministic=False)

student_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-xsmall", num_labels=NUM_LABELS)
student_model.to(DEVICE)

# Train student model
student_optimizer = torch.optim.AdamW(student_model.parameters(), lr=2e-5)
student_scheduler = get_linear_schedule_with_warmup(
    student_optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader) * NUM_EPOCHS
)
train_student_model(
    student_model,
    teacher_model,
    train_dataloader,
    valid_dataloader,
    student_optimizer,
    student_scheduler,
    NUM_EPOCHS,
    DEVICE,
    temperature=0.5,
    alpha=0.8
)
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-xsmall and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                         
Epoch 1/2
Training loss: 0.3327, Training accuracy: 0.5925
Validation loss: 0.7193, Validation accuracy: 0.6167
                                                         
Epoch 2/2
Training loss: 0.2407, Training accuracy: 0.7472
Validation loss: 0.6435, Validation accuracy: 0.8150

We see an improvement over the vanilla student model: validation accuracy rises from 0.8018 to 0.8150. For comparison, the vanilla student's results are repeated below.

Epoch 1/2
Training loss: 0.8688, Training accuracy: 0.5886
Validation loss: 0.7205, Validation accuracy: 0.6167

                                                         
Epoch 2/2
Training loss: 0.6349, Training accuracy: 0.6937
Validation loss: 0.6428, Validation accuracy: 0.8018
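
To put the gain in context, it helps to quantify the size gap between teacher and student. Below is a plain-PyTorch sketch; count_trainable is a hypothetical helper, not part of the code above:

def count_trainable(model: nn.Module) -> int:
    # Sum the number of elements across all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Teacher trainable parameters: {count_trainable(teacher_model):,}")
print(f"Student trainable parameters: {count_trainable(student_model):,}")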

References And Further Readings#

There are many variants of knowledge distillation. For example, you can also distill on unsupervised (unlabelled) data, training the student purely against the teacher's predictions. See the example in Hugging Face.