Implementation#


Merge And Quantize#

# %pip install -U -q omniverse==0.0.57

Dependencies#

from __future__ import annotations

import copy
import math
from typing import Any, Dict, List, Optional, TypedDict, Union

import numpy as np
import psutil
import torch
from datasets import load_dataset
from pydantic import BaseModel, Field
from rich.pretty import pprint
from scipy.special import softmax
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    brier_score_loss,
    confusion_matrix,
    f1_score,
    log_loss,
    precision_recall_curve,
    precision_score,
    recall_score,
    roc_auc_score,
    roc_curve,
)
from torch import nn
from transformers import (
    DataCollatorWithPadding,
    Qwen2ForSequenceClassification,
    Qwen2Tokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.trainer_utils import EvalPrediction

from omnivault.utils.reproducibility.seed import seed_all

Setting Up#

seed_all(42, seed_torch=True, set_torch_deterministic=False)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

MAX_LENGTH = 32
PADDING = "longest"
BATCH_SIZE = 32
TRUNCATION = True
RETURN_TENSORS = "pt"

Dataset Preparation#

class Batch(TypedDict):
    sentence: List[str]
    labels: List[int]


class TokenizedBatch(TypedDict):
    input_ids: List[int]
    attention_mask: List[int]
    labels: List[int]

tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen1.5-0.5B", padding_side="left")

def preprocess_function(batch: Batch, **kwargs: Any) -> TokenizedBatch:
    return tokenizer(batch["sentence"], **kwargs)

dataset = load_dataset("financial_phrasebank", "sentences_allagree", trust_remote_code=True)["train"]
dataset = dataset.rename_column("label", "labels")

train_valid_split = dataset.train_test_split(test_size=0.1, shuffle=True, stratify_by_column="labels")

train_dataset = train_valid_split["train"]
valid_dataset = train_valid_split["test"]

tokenized_train_dataset = train_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

tokenized_valid_dataset = valid_dataset.map(
    preprocess_function,
    fn_kwargs={"truncation": TRUNCATION, "padding": PADDING, "max_length": MAX_LENGTH},
    batched=True,
    num_proc=psutil.cpu_count(logical=True),
    batch_size=1000,
).remove_columns(["sentence"])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

id2label = {0: "negative", 1: "neutral", 2: "positive"}
label2id = {"negative": 0, "neutral": 1, "positive": 2}
num_labels = len(id2label)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Base Model#

base_model = Qwen2ForSequenceClassification.from_pretrained(
    "Qwen/Qwen1.5-0.5B",
    id2label=id2label,
    label2id=label2id,
    num_labels=num_labels,
    problem_type="single_label_classification",
)
base_model.config.pad_token_id = tokenizer.pad_token_id

base_model = base_model.to(DEVICE)
pprint(base_model)
Some weights of Qwen2ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen1.5-0.5B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Qwen2ForSequenceClassification(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (score): Linear(in_features=1024, out_features=3, bias=False)
)
def total_trainable_parameters(module: nn.Module) -> int:
    """Returns the number of trainable parameters in the model."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def total_parameters(module: nn.Module) -> int:
    """Returns the total number of parameters in the model, including non-trainable."""
    return sum(p.numel() for p in module.parameters())

base_model_total_trainable = total_trainable_parameters(base_model)
print(f"Total trainable parameters before LoRA: {base_model_total_trainable:,}")
Total trainable parameters before LoRA: 463,990,784

Metrics#

def compute_metrics_for_single_label_classification(eval_prediction: EvalPrediction) -> Dict[str, float | List[float]]:
    logits, labels = eval_prediction.predictions, eval_prediction.label_ids
    probs = softmax(logits, axis=-1)

    num_classes = logits.shape[1]
    preds = np.argmax(probs, axis=1)

    metrics = {
        "eval_log_loss": log_loss(labels, probs),
        "eval_accuracy": accuracy_score(labels, preds),
        "eval_precision_macro": precision_score(labels, preds, average="macro", zero_division=0),
        "eval_recall_macro": recall_score(labels, preds, average="macro", zero_division=0),
        "eval_f1_score_macro": f1_score(labels, preds, average="macro", zero_division=0),
        "eval_precision_micro": precision_score(labels, preds, average="micro", zero_division=0),
        "eval_recall_micro": recall_score(labels, preds, average="micro", zero_division=0),
        "eval_f1_score_micro": f1_score(labels, preds, average="micro", zero_division=0),
        "eval_confusion_matrix": confusion_matrix(labels, preds).tolist(),
        "eval_roc_auc": roc_auc_score(labels, probs, multi_class="ovr"),
        "eval_pr_auc": average_precision_score(labels, probs, average="macro")
    }

    if num_classes == 2:
        metrics["eval_brier_score"] = brier_score_loss(labels, probs[:, 1], pos_label=1)
    else:
        brier_scores = [brier_score_loss(labels == i, probs[:, i]) for i in range(num_classes)]
        metrics["eval_brier_score"] = np.mean(brier_scores)

    if num_classes > 2:
        for class_index in range(num_classes):
            fpr, tpr, _ = roc_curve(labels == class_index, probs[:, class_index])
            roc_auc = auc(fpr, tpr)
            precision, recall, _ = precision_recall_curve(labels == class_index, probs[:, class_index])
            pr_auc = auc(recall, precision)
            metrics[f"eval_roc_auc_class_{class_index}"] = roc_auc
            metrics[f"eval_pr_auc_class_{class_index}"] = pr_auc

    return metrics
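
As a quick smoke test of the metric function, we can call it on a synthetic EvalPrediction. This snippet is purely illustrative (made-up logits and labels) and assumes the same scikit-learn version used throughout this notebook.

# Illustrative smoke test with made-up logits and labels (not part of the original run).
dummy_logits = np.array([[2.0, 0.1, -1.0], [0.2, 1.5, 0.3], [-0.5, 0.0, 2.2], [1.0, 0.9, 0.8]])
dummy_labels = np.array([0, 1, 2, 0])
dummy_prediction = EvalPrediction(predictions=dummy_logits, label_ids=dummy_labels)
pprint(compute_metrics_for_single_label_classification(dummy_prediction))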

Evaluate With Pretrained Model#

trainer = Trainer(
    model=base_model,
    args=TrainingArguments(output_dir="./artifacts", report_to="none"),
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    compute_metrics=compute_metrics_for_single_label_classification,
)

valid_metrics = trainer.predict(tokenized_valid_dataset, metric_key_prefix="eval")
pprint(valid_metrics.metrics)
{
    'eval_log_loss': 7.40178591009753,
    'eval_accuracy': 0.14096916299559473,
    'eval_precision_macro': 0.3000285877644368,
    'eval_recall_macro': 0.3223057644110276,
    'eval_f1_score_macro': 0.11305118925439782,
    'eval_precision_micro': 0.14096916299559473,
    'eval_recall_micro': 0.14096916299559473,
    'eval_f1_score_micro': 0.14096916299559473,
    'eval_confusion_matrix': [[27, 1, 2], [132, 2, 6], [53, 1, 3]],
    'eval_roc_auc': 0.5319817168701279,
    'eval_pr_auc': 0.3596197342614926,
    'eval_brier_score': 0.55610225252113,
    'eval_roc_auc_class_0': 0.4730964467005076,
    'eval_pr_auc_class_0': 0.12580756501576662,
    'eval_roc_auc_class_1': 0.5628899835796387,
    'eval_pr_auc_class_1': 0.6492525648321494,
    'eval_roc_auc_class_2': 0.5599587203302373,
    'eval_pr_auc_class_2': 0.2816445108714582,
    'eval_loss': 7.431046962738037,
    'eval_runtime': 1.8501,
    'eval_samples_per_second': 122.698,
    'eval_steps_per_second': 15.675
}

LoRA Implementation#

class LoraConfig(BaseModel):
    r: int = Field(..., description="Lora attention dimension (the 'rank').")
    lora_alpha: int = Field(..., description="The alpha parameter for Lora scaling.")
    lora_dropout: float = Field(..., description="The dropout probability for Lora layers.")
    target_modules: List[str] = Field(
        default=None,
        description=(
            "The names of the modules to apply the adapter to. If specified, only the modules with the specified "
            "names will be replaced. When passing a string, a regex match will be performed. When passing a list of "
            "strings, either an exact match will be performed or it is checked if the name of the module ends with any "
            "of the passed strings. If specified as 'all-linear', all linear/Conv1D modules are chosen, excluding the "
            "output layer. If not specified, modules are chosen according to the model architecture. If the architecture "
            "is unknown, an error will be raised—manual specification of target modules is required in such cases."
        ),
    )
    modules_to_save: List[str] = Field(
        default=None,
        description=(
            """List of modules apart from adapter layers to be set as
               trainable and saved in the final checkpoint."""
        ),
    )
lora_config = LoraConfig(
    r=4, lora_alpha=8, lora_dropout=0.1, target_modules=["q_proj", "k_proj", "v_proj"], modules_to_save=["score"]
)
pprint(lora_config)
LoraConfig(
    r=4,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    modules_to_save=['score']
)

We print out the target modules below. For simplicity, we target only the q, k and v layers for now.

for module_name, _module in base_model.named_modules():
    if any(target_module in module_name for target_module in lora_config.target_modules):
        print(module_name)
model.layers.0.self_attn.q_proj
model.layers.0.self_attn.k_proj
model.layers.0.self_attn.v_proj
model.layers.1.self_attn.q_proj
model.layers.1.self_attn.k_proj
model.layers.1.self_attn.v_proj
model.layers.2.self_attn.q_proj
model.layers.2.self_attn.k_proj
model.layers.2.self_attn.v_proj
model.layers.3.self_attn.q_proj
model.layers.3.self_attn.k_proj
model.layers.3.self_attn.v_proj
model.layers.4.self_attn.q_proj
model.layers.4.self_attn.k_proj
model.layers.4.self_attn.v_proj
model.layers.5.self_attn.q_proj
model.layers.5.self_attn.k_proj
model.layers.5.self_attn.v_proj
model.layers.6.self_attn.q_proj
model.layers.6.self_attn.k_proj
model.layers.6.self_attn.v_proj
model.layers.7.self_attn.q_proj
model.layers.7.self_attn.k_proj
model.layers.7.self_attn.v_proj
model.layers.8.self_attn.q_proj
model.layers.8.self_attn.k_proj
model.layers.8.self_attn.v_proj
model.layers.9.self_attn.q_proj
model.layers.9.self_attn.k_proj
model.layers.9.self_attn.v_proj
model.layers.10.self_attn.q_proj
model.layers.10.self_attn.k_proj
model.layers.10.self_attn.v_proj
model.layers.11.self_attn.q_proj
model.layers.11.self_attn.k_proj
model.layers.11.self_attn.v_proj
model.layers.12.self_attn.q_proj
model.layers.12.self_attn.k_proj
model.layers.12.self_attn.v_proj
model.layers.13.self_attn.q_proj
model.layers.13.self_attn.k_proj
model.layers.13.self_attn.v_proj
model.layers.14.self_attn.q_proj
model.layers.14.self_attn.k_proj
model.layers.14.self_attn.v_proj
model.layers.15.self_attn.q_proj
model.layers.15.self_attn.k_proj
model.layers.15.self_attn.v_proj
model.layers.16.self_attn.q_proj
model.layers.16.self_attn.k_proj
model.layers.16.self_attn.v_proj
model.layers.17.self_attn.q_proj
model.layers.17.self_attn.k_proj
model.layers.17.self_attn.v_proj
model.layers.18.self_attn.q_proj
model.layers.18.self_attn.k_proj
model.layers.18.self_attn.v_proj
model.layers.19.self_attn.q_proj
model.layers.19.self_attn.k_proj
model.layers.19.self_attn.v_proj
model.layers.20.self_attn.q_proj
model.layers.20.self_attn.k_proj
model.layers.20.self_attn.v_proj
model.layers.21.self_attn.q_proj
model.layers.21.self_attn.k_proj
model.layers.21.self_attn.v_proj
model.layers.22.self_attn.q_proj
model.layers.22.self_attn.k_proj
model.layers.22.self_attn.v_proj
model.layers.23.self_attn.q_proj
model.layers.23.self_attn.k_proj
model.layers.23.self_attn.v_proj
"""LoRA: Low-Rank Adaptation of Large Language Models.

References
----------
[1] https://pytorch.org/torchtune/stable/tutorials/lora_finetune.html
"""


from __future__ import annotations

import math
from typing import List

import torch
from pydantic import BaseModel, Field
from torch import nn


class LoraConfig(BaseModel):
    r: int = Field(..., description="Lora attention dimension (the 'rank').")
    lora_alpha: int = Field(..., description="The alpha parameter for Lora scaling.")
    lora_dropout: float = Field(..., description="The dropout probability for Lora layers.")
    target_modules: List[str] = Field(
        default=None,
        description=(
            "The names of the modules to apply the adapter to. If specified, only the modules with the specified "
            "names will be replaced. When passing a string, a regex match will be performed. When passing a list of "
            "strings, either an exact match will be performed or it is checked if the name of the module ends with any "
            "of the passed strings. If specified as 'all-linear', all linear/Conv1D modules are chosen, excluding the "
            "output layer. If not specified, modules are chosen according to the model architecture. If the architecture "
            "is unknown, an error will be raised—manual specification of target modules is required in such cases."
        ),
    )
    modules_to_save: List[str] = Field(
        default=None,
        description=(
            """List of modules apart from adapter layers to be set as
               trainable and saved in the final checkpoint."""
        ),
    )


def _lora_a_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA A weight to Kaiming uniform.
    """
    nn.init.kaiming_uniform_(x.weight, a=math.sqrt(5))


def _lora_b_init_params(x: nn.Linear) -> None:
    """
    Initialize LoRA B weight to zeros.
    """
    nn.init.zeros_(x.weight)


class LoRALinear(nn.Module):
    """LoRA Linear layer."""

    def __init__(self, original_linear: nn.Linear, rank: int, alpha: float, dropout: float) -> None:
        """Initialize the `LoRALinear` layer.

        Parameters
        ----------
        original_linear : nn.Linear
            The original linear layer from the pretrained model.
        rank : int
            The rank of the LoRA layer.
        alpha : float
            The alpha parameter for LoRA scaling.
        dropout : float
            The dropout probability for the LoRA layer.
        """
        super().__init__()

        # These are the weights from the original pretrained model
        self.linear = original_linear  # weight shape=[out_dim, in_dim]

        in_dim = self.linear.in_features
        out_dim = self.linear.out_features

        # These are the new LoRA params. In general rank << in_dim, out_dim - do not put bias here
        self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)  # weight shape=[rank, in_dim]
        self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)  # weight shape=[out_dim, rank]

        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(p=dropout)

        self._init_weights()

    def _init_weights(self) -> None:
        """See https://github.com/microsoft/LoRA/blob/4c0333854cb905966f8cc4e9a74068c1e507c7b7/loralib/layers.py#L119."""

        _lora_a_init_params(self.lora_a)
        _lora_b_init_params(self.lora_b)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the `LoRALinear` layer."""
        frozen_out = x @ self.linear.weight.T  # This would be the output of the original model
        if self.linear.bias is not None:
            frozen_out += self.linear.bias

        # lora_a projects inputs down to the much smaller self.rank,
        # then lora_b projects back up to the output dimension
        x = self.dropout(x)
        lora_out = x @ (self.lora_a.weight.T @ self.lora_b.weight.T)  # [B, T, D1] @ ([D1, R] @ [R, D2]) = [B, T, D2]
        # Finally, scale by the alpha parameter (normalized by rank)
        # and add to the original model's outputs
        return frozen_out + (self.alpha / self.rank) * lora_out

    @torch.no_grad()
    def _merge(self) -> nn.Linear:
        """
        Merge the LoRA layers to the original linear layer.
        """

        # lora_a.weight has shape [R, D1] and lora_b.weight has shape [D2, R], so
        # lora_a.weight.T @ lora_b.weight.T is [D1, R] @ [R, D2] = [D1, D2].
        # Transposing gives [D2, D1], matching self.linear.weight's [out_dim, in_dim] layout.
        lora_weight = self.lora_a.weight.T @ self.lora_b.weight.T
        lora_weight = lora_weight.to(self.linear.weight.device)
        self.linear.weight += (self.alpha / self.rank) * lora_weight.T
        return self.linear


def merge_and_unload_model(model: nn.Module) -> nn.Module:
    """Recursively merge LoRA layers back into the original Linear layers in
    the model and unload LoRA parameters."""
    for module_name, module in model.named_children():
        if isinstance(module, LoRALinear):
            merged_linear = module._merge()
            setattr(model, module_name, merged_linear)
        else:
            merge_and_unload_model(module)
    return model

def apply_lora_to_base_model(
    model: nn.Module, rank: int, alpha: float, dropout: float, target_modules: List[str] | None = None
) -> None:
    """Recursively apply LoRA to a model. Only supports applying on `nn.Linear` layers.

    In the `if` condition, we first check if the module is an instance of
    `nn.Linear`. If it is, we then check if the `target_modules` is specified
    by user, if it is not, then `if target_modules is None` will return `True`
    and we apply LoRA to the module because we assume that the user wants to
    apply LoRA to all `nn.Linear` layers. If the `target_modules` is specified,
    then `if target_modules is None` will return `False` and we will check the
    second condition `any(target in module_name for target in target_modules)`
    which will return `True` if any of the target modules are in the module name.
    """
    for module_name, module in model.named_children():
        if isinstance(module, nn.Linear):
            if target_modules is None or any(target in module_name for target in target_modules):
                setattr(
                    model,
                    module_name,
                    LoRALinear(
                        original_linear=module,
                        rank=rank,
                        alpha=alpha,
                        dropout=dropout,
                    ),
                )
        else:
            # Recursively apply LoRA to children modules
            apply_lora_to_base_model(
                model=module, rank=rank, alpha=alpha, dropout=dropout, target_modules=target_modules
            )
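
For reference, the forward pass of LoRALinear above implements the standard LoRA update, and _merge folds the low-rank product back into the frozen weight:

\[
h = x W_0^{\top} + b + \frac{\alpha}{r}\,\mathrm{dropout}(x)\,A^{\top} B^{\top},
\qquad
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\,\bigl(A^{\top} B^{\top}\bigr)^{\top} = W_0 + \frac{\alpha}{r}\,B A,
\]

where \(W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}\) is the frozen pretrained weight, \(A \in \mathbb{R}^{r \times d_{\text{in}}}\) is the Kaiming-initialized lora_a weight, and \(B \in \mathbb{R}^{d_{\text{out}} \times r}\) is the zero-initialized lora_b weight, so the adapter contributes nothing at initialization.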

Note that an earlier version of this implementation mistakenly re-created the wrapped layer as self.linear = nn.Linear(...) instead of keeping a reference to the original layer (self.linear = original_linear, whose weight has shape [out_dim, in_dim]), which means the pretrained weights were never inherited. The bug showed up as unstable, suspicious-looking loss and metrics at the start of training, and revisiting the implementation quickly revealed the issue.
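
A quick guard against that regression is a check along the following lines (a minimal sketch, not part of the original notebook): confirm that the wrapper holds the exact same weight tensor as the pretrained layer it replaces, and that it behaves as an identity wrapper at initialization.

# Sanity check (illustrative): LoRALinear must reference the pretrained Linear's weight
# tensor rather than a freshly initialized copy, and must be a no-op at initialization.
_pretrained_linear = nn.Linear(1024, 1024, bias=True)
_wrapped = LoRALinear(original_linear=_pretrained_linear, rank=4, alpha=8, dropout=0.1).eval()

# Same underlying storage => the pretrained weights are inherited, not re-created.
assert _wrapped.linear.weight.data_ptr() == _pretrained_linear.weight.data_ptr()

# lora_b starts at zero, so the wrapped layer matches the original layer exactly.
_x = torch.randn(2, 8, 1024)
assert torch.allclose(_wrapped(_x), _pretrained_linear(_x), atol=1e-5)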

We take a deep copy of base_model so that applying LoRA does not mutate the original model.

base_model_with_adapter = copy.deepcopy(base_model)

We recursively apply the LoRA module to the q, k and v projections via apply_lora_to_base_model.

apply_lora_to_base_model(
    model=base_model_with_adapter,
    rank=lora_config.r,
    alpha=lora_config.lora_alpha,
    dropout=lora_config.lora_dropout,
    target_modules=lora_config.target_modules,
)
pprint(base_model_with_adapter)
Qwen2ForSequenceClassification(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=4, bias=False)
            (lora_b): Linear(in_features=4, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (k_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=4, bias=False)
            (lora_b): Linear(in_features=4, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (v_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=4, bias=False)
            (lora_b): Linear(in_features=4, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (score): Linear(in_features=1024, out_features=3, bias=False)
)

Pay attention to the architecture here: each of our chosen target modules (q_proj, k_proj and v_proj) originally appears as a plain Linear layer:

(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)

After the replacement, each of these Linear layers becomes a LoRALinear that wraps the original Linear alongside the additional lora_a and lora_b projections (plus a dropout):

(q_proj): LoRALinear(
 (linear): Linear(in_features=1024, out_features=1024, bias=True)
 (lora_a): Linear(in_features=1024, out_features=4, bias=False)
 (lora_b): Linear(in_features=4, out_features=1024, bias=False)
 (dropout): Dropout(p=0.1, inplace=False)
)
(k_proj): LoRALinear(
 (linear): Linear(in_features=1024, out_features=1024, bias=True)
 (lora_a): Linear(in_features=1024, out_features=4, bias=False)
 (lora_b): Linear(in_features=4, out_features=1024, bias=False)
 (dropout): Dropout(p=0.1, inplace=False)
)
(v_proj): LoRALinear(
 (linear): Linear(in_features=1024, out_features=1024, bias=True)
 (lora_a): Linear(in_features=1024, out_features=4, bias=False)
 (lora_b): Linear(in_features=4, out_features=1024, bias=False)
 (dropout): Dropout(p=0.1, inplace=False)
)
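
To verify the swap programmatically rather than by eyeballing the printed tree, a check along these lines works (illustrative, not in the original notebook):

# Targeted projections should now be LoRALinear; untargeted ones (e.g. o_proj) stay plain Linear.
first_self_attn = base_model_with_adapter.model.layers[0].self_attn
print(isinstance(first_self_attn.q_proj, LoRALinear))  # expected: True
print(isinstance(first_self_attn.o_proj, LoRALinear))  # expected: False
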
base_model_with_adapter_total_trainable = total_trainable_parameters(base_model_with_adapter)
print(f"Total trainable parameters after LoRA before freezing: {base_model_with_adapter_total_trainable:,}")
Total trainable parameters after LoRA before freezing: 464,580,608

Note that bias defaults to True in the original q/k/v projections, whereas the lora_a and lora_b layers are created with bias=False. You will also notice that the total number of trainable parameters is currently larger than in the base model. Why?

base_model_with_adapter_total_trainable - base_model_total_trainable
589824

Where do the additional \(589824\) parameters come from? Let’s find out below.

dim = base_model_with_adapter.model.layers[0].self_attn.q_proj.linear.weight.shape[0]
layers = base_model_with_adapter.model.layers.__len__()
rank = lora_config.r
num_target_modules = len(lora_config.target_modules)

qkv_lora_weight_params = (dim * rank * 2) * layers * num_target_modules  # the factor of 2 counts the A and B matrices (one of each)

base_model_with_adapter_total_trainable - base_model_total_trainable ==  qkv_lora_weight_params
True

The additional parameters come from applying LoRA to the q, k and v projections in each of the 24 layers: every targeted projection gains 1024 * 4 * 2 parameters, because matrix A (shape [rank, dim]) and matrix B (shape [dim, rank]) are mirrored versions of each other with 1024 * 4 entries apiece.
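
Spelled out:

\[
\underbrace{(1024 \times 4)}_{A} + \underbrace{(4 \times 1024)}_{B} = 8192
\quad\Longrightarrow\quad
8192 \times 24 \ \text{layers} \times 3 \ \text{target modules} = 589824.
\]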

Now the next step is, of course, to freeze the base pretrained weights. Note that we DO NOT want to freeze the score module, as that is our (newly initialized) classification head.

for parameter_name, parameter in base_model_with_adapter.named_parameters():
    # We will set requires_grad to False if 'lora_' is not in the parameter name AND the parameter name does not contain any of the module names specified in modules_to_save
    if "lora_" not in parameter_name and not any(
        module_name in parameter_name for module_name in lora_config.modules_to_save
    ):
        parameter.requires_grad = False
    else:
        # Safeguard here parameters that are part of LoRA or specified modules are trainable
        parameter.requires_grad = True
base_model_with_adapter_total_trainable = total_trainable_parameters(base_model_with_adapter)
print(f"Total trainable parameters after LoRA after freezing: {base_model_with_adapter_total_trainable:,}")
Total trainable parameters after LoRA after freezing: 592,896
(base_model_with_adapter_total_trainable / base_model_total_trainable) * 100
0.1277818483567122

We are training only about 0.128% of the total parameters.
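
To double-check that only the LoRA matrices and the classification head remain trainable, we can list the trainable parameter names (an illustrative snippet, not in the original notebook):

# Only 'lora_a', 'lora_b' and 'score' parameters should appear here.
trainable_parameter_names = [
    name for name, parameter in base_model_with_adapter.named_parameters() if parameter.requires_grad
]
print(len(trainable_parameter_names))
print(trainable_parameter_names[:3], trainable_parameter_names[-1])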

Train LoRA#

seed_all(42, seed_torch=True, set_torch_deterministic=False)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["score"],
)
pprint(lora_config)

base_model_with_adapter = copy.deepcopy(base_model)

apply_lora_to_base_model(
    model=base_model_with_adapter,
    rank=lora_config.r,
    alpha=lora_config.lora_alpha,
    dropout=lora_config.lora_dropout,
    target_modules=lora_config.target_modules,
)

for parameter_name, parameter in base_model_with_adapter.named_parameters():
    # We will set requires_grad to False if 'lora_' is not in the parameter name AND the parameter name does not contain any of the module names specified in modules_to_save
    if "lora_" not in parameter_name and not any(
        module_name in parameter_name for module_name in lora_config.modules_to_save
    ):
        parameter.requires_grad = False
    else:
        # Safeguard here parameters that are part of LoRA or specified modules are trainable
        parameter.requires_grad = True
LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'],
    modules_to_save=['score']
)
pprint(base_model_with_adapter)
Qwen2ForSequenceClassification(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (k_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (v_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=True)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (o_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=1024, bias=False)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=2816, bias=False)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=2816, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (up_proj): LoRALinear(
            (linear): Linear(in_features=1024, out_features=2816, bias=False)
            (lora_a): Linear(in_features=1024, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=2816, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (down_proj): LoRALinear(
            (linear): Linear(in_features=2816, out_features=1024, bias=False)
            (lora_a): Linear(in_features=2816, out_features=16, bias=False)
            (lora_b): Linear(in_features=16, out_features=1024, bias=False)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (score): Linear(in_features=1024, out_features=3, bias=False)
)
training_args = TrainingArguments(
    do_eval=True,
    do_predict=False,
    do_train=True,
    warmup_ratio=0.0,
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    report_to="none",
    output_dir="./artifacts",
    overwrite_output_dir=True,
    gradient_accumulation_steps=1,
    logging_steps=25,
    evaluation_strategy="steps",
    eval_steps=32,
    save_strategy="steps",
    save_steps=128,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    save_total_limit=2,
    seed=42,
    data_seed=42,
    half_precision_backend="auto",
    optim="adamw_torch",
    label_smoothing_factor=0.0,
    max_grad_norm=1.0,
)
/opt/conda/lib/python3.10/site-packages/transformers/training_args.py:1494: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
trainer = Trainer(
    model=base_model_with_adapter,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    compute_metrics=compute_metrics_for_single_label_classification,
)
trainer.train()
[192/192 02:15, Epoch 3/3]
Step Training Loss Validation Loss Log Loss Accuracy Precision Macro Recall Macro F1 Score Macro Precision Micro Recall Micro F1 Score Micro Confusion Matrix Roc Auc Pr Auc Brier Score Roc Auc Class 0 Pr Auc Class 0 Roc Auc Class 1 Pr Auc Class 1 Roc Auc Class 2 Pr Auc Class 2
32 1.313600 0.428587 0.428587 0.845815 0.867180 0.702005 0.695558 0.845815 0.845815 0.845815 [[7, 0, 23], [0, 132, 8], [0, 4, 53]] 0.970182 0.913684 0.080725 0.967851 0.848750 0.990476 0.993630 0.952219 0.895753
64 0.411400 0.181376 0.181376 0.942731 0.920626 0.909733 0.912348 0.942731 0.942731 0.942731 [[24, 0, 6], [1, 135, 4], [1, 1, 55]] 0.992234 0.975164 0.030542 0.990863 0.954921 0.996470 0.997859 0.989370 0.971857
96 0.164500 0.136885 0.136885 0.947137 0.941042 0.910443 0.924157 0.947137 0.947137 0.947137 [[25, 1, 4], [0, 138, 2], [1, 4, 52]] 0.994777 0.984542 0.026965 0.996616 0.977963 0.996798 0.997951 0.990918 0.977165
128 0.096400 0.139018 0.139018 0.960352 0.941155 0.931579 0.936106 0.960352 0.960352 0.960352 [[27, 0, 3], [0, 140, 0], [3, 3, 51]] 0.995570 0.985105 0.021838 0.995431 0.974730 0.998194 0.998862 0.993086 0.981205
160 0.027700 0.134256 0.134256 0.955947 0.934282 0.920468 0.927157 0.955947 0.955947 0.955947 [[26, 0, 4], [0, 140, 0], [3, 3, 51]] 0.996782 0.989355 0.022951 0.997293 0.983136 0.998522 0.999064 0.994530 0.985478
192 0.030300 0.160583 0.160583 0.955947 0.934282 0.920468 0.927157 0.955947 0.955947 0.955947 [[26, 0, 4], [0, 140, 0], [3, 3, 51]] 0.996349 0.988262 0.024319 0.997293 0.983136 0.998358 0.998957 0.993395 0.982279

TrainOutput(global_step=192, training_loss=0.280302738926063, metrics={'train_runtime': 136.3488, 'train_samples_per_second': 44.819, 'train_steps_per_second': 1.408, 'total_flos': 370740459995136.0, 'train_loss': 0.280302738926063, 'epoch': 3.0})

The accuracy reaches around \(96\%\) after 3 epochs. This is still a fair way off what an encoder like DeBERTa can achieve, but it is a good start and shows that the implementation is working.

Merge And Unload#

We do a sanity check on the trained model’s predictions so that, after the merge and unload, we can confirm the results do not change much. They may still differ slightly due to floating-point operations, but the difference should be minimal.
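
One way to make this check concrete is the following sketch (illustrative, not part of the original run): deep copy the adapter model so the model used below is untouched, compute logits on a small batch, merge the copy, and confirm the logits agree within a loose tolerance.

# Illustrative sketch: compare logits before and after merging on a throwaway deep copy.
_model_copy = copy.deepcopy(base_model_with_adapter).eval()

_features = list(tokenized_valid_dataset.select(range(8)).remove_columns(["labels"]))
_batch = data_collator(_features).to(DEVICE)

with torch.no_grad():
    _logits_before = _model_copy(**_batch).logits
    merge_and_unload_model(_model_copy)
    _logits_after = _model_copy(**_batch).logits

# Any difference should be at floating-point noise level (tolerance is deliberately loose).
print(torch.allclose(_logits_before, _logits_after, atol=1e-3))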

trainer.predict(tokenized_valid_dataset).metrics
{'test_loss': 0.13901817798614502,
 'test_eval_log_loss': 0.13901824007066116,
 'test_eval_accuracy': 0.960352422907489,
 'test_eval_precision_macro': 0.9411551411551411,
 'test_eval_recall_macro': 0.9315789473684211,
 'test_eval_f1_score_macro': 0.936106070735046,
 'test_eval_precision_micro': 0.960352422907489,
 'test_eval_recall_micro': 0.960352422907489,
 'test_eval_f1_score_micro': 0.960352422907489,
 'test_eval_confusion_matrix': [[27, 0, 3], [0, 140, 0], [3, 3, 51]],
 'test_eval_roc_auc': 0.9955702958862339,
 'test_eval_pr_auc': 0.9851052470652112,
 'test_eval_brier_score': 0.021837805972314456,
 'test_eval_roc_auc_class_0': 0.9954314720812182,
 'test_eval_pr_auc_class_0': 0.9747298715068023,
 'test_eval_roc_auc_class_1': 0.9981937602627258,
 'test_eval_pr_auc_class_1': 0.9988615953476088,
 'test_eval_roc_auc_class_2': 0.9930856553147576,
 'test_eval_pr_auc_class_2': 0.9812046605398257,
 'test_runtime': 1.7167,
 'test_samples_per_second': 132.231,
 'test_steps_per_second': 4.66}
base_model_merged_and_unloaded = merge_and_unload_model(base_model_with_adapter)
pprint(base_model_merged_and_unloaded)
Qwen2ForSequenceClassification(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (score): Linear(in_features=1024, out_features=3, bias=False)
)
trainer = Trainer(
    model=base_model_merged_and_unloaded,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_valid_dataset,
    compute_metrics=compute_metrics_for_single_label_classification,
)

trainer.predict(tokenized_valid_dataset).metrics
{'test_loss': 0.139018252491951,
 'test_eval_log_loss': 0.13901830174555815,
 'test_eval_accuracy': 0.960352422907489,
 'test_eval_precision_macro': 0.9411551411551411,
 'test_eval_recall_macro': 0.9315789473684211,
 'test_eval_f1_score_macro': 0.936106070735046,
 'test_eval_precision_micro': 0.960352422907489,
 'test_eval_recall_micro': 0.960352422907489,
 'test_eval_f1_score_micro': 0.960352422907489,
 'test_eval_confusion_matrix': [[27, 0, 3], [0, 140, 0], [3, 3, 51]],
 'test_eval_roc_auc': 0.9955702958862339,
 'test_eval_pr_auc': 0.9851052470652112,
 'test_eval_brier_score': 0.021837813003862862,
 'test_eval_roc_auc_class_0': 0.9954314720812182,
 'test_eval_pr_auc_class_0': 0.9747298715068023,
 'test_eval_roc_auc_class_1': 0.9981937602627258,
 'test_eval_pr_auc_class_1': 0.9988615953476088,
 'test_eval_roc_auc_class_2': 0.9930856553147576,
 'test_eval_pr_auc_class_2': 0.9812046605398257,
 'test_runtime': 0.9137,
 'test_samples_per_second': 248.448,
 'test_steps_per_second': 8.756}

See Hugging Face PEFT’s merge_and_unload for the reference implementation.
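
For comparison, the same workflow with the PEFT library looks roughly like the sketch below (not executed here; the argument names follow peft’s LoraConfig, and the import is aliased to avoid clashing with our own LoraConfig):

# Rough PEFT equivalent (sketch): wrap the base model, train as usual, then merge the adapters.
from peft import LoraConfig as PeftLoraConfig
from peft import TaskType, get_peft_model

peft_config = PeftLoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["score"],
)
peft_model = get_peft_model(copy.deepcopy(base_model), peft_config)
peft_model.print_trainable_parameters()

# ... train with Trainer as above, then fold the adapters back into the base weights:
merged_model = peft_model.merge_and_unload()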