State And Metadata Management#


Besides the configuration that drives the whole machine learning pipeline, whether in training or inference, we also keep a metadata or state object that encapsulates the current state of the process. This is especially useful when we want to save the state of the training process and resume training from a checkpoint. If you are not convinced, have a look at MosaicML's Composer for some insights.

Metadata Management#

In the context of a machine learning pipeline, having a mutable Metadata class can be beneficial. This is because the pipeline is a dynamic process, and different stages of the pipeline may need to update different parts of the metadata.

For example:

  • At the start of the pipeline, you might set the start_time and run_id.

  • When loading the data, you might update the dataset_version.

  • During training, you might update the history with the training loss and metrics at the end of each epoch.

  • When the pipeline finishes, you might set the end_time and upload the final metadata as an artifact.

This kind of incremental updating is easier with a mutable object. However, it’s important to manage the mutable state carefully to avoid bugs. For example, if two parts of your code are both updating the metadata at the same time, they might interfere with each other. To avoid this, you might want to:

  • Make sure that different parts of your code are responsible for updating different parts of the metadata, and that they don’t interfere with each other.

  • Use locks or other synchronization mechanisms if you’re updating the metadata from multiple threads or processes.

  • Create a new copy of the metadata before making changes, if you need to keep the old state around for some reason.

An example of such a mutable Metadata class is shown below.

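from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=False)  # mutable, so each stage can fill in its own fields
class Metadata:
    # pipeline run time and usage; not known until the run starts or finishes,
    # so they default to None like the stage-specific fields below
    start_time: Optional[str] = None
    end_time: Optional[str] = None
    time_taken: Optional[str] = None
    memory_usage: Optional[str] = None
    gpu_usage: Optional[str] = None

    # inside load.py
    raw_file_size: Optional[int] = None
    raw_file_format: Optional[str] = None
    raw_filepath: Optional[str] = None
    raw_dvc_metadata: Optional[Dict[str, Any]] = None

    # inside train.py
    model_artifacts: Optional[Dict[str, Any]] = None
    run_id: Optional[str] = None
    model_version: Optional[str] = None
    experiment_id: Optional[str] = None
    experiment_name: Optional[str] = None
    artifact_uri: Optional[str] = None

    # inside evaluate.py
    holdout_performance: Optional[Dict[str, float]] = None
    avg_expected_loss: Optional[float] = None
    avg_bias: Optional[float] = None
    avg_variance: Optional[float] = None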
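The incremental updates and the safeguards listed earlier (a lock around writes, a copy when the old state must be kept) can be sketched roughly as follows. This is only an illustration under stated assumptions: `RunMetadata` is a hypothetical, trimmed-down stand-in for the Metadata class, and `MetadataStore`, `update`, `log_metric` and `snapshot` are names invented for this sketch, not part of any library.

```python
import copy
import threading
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Dict, List, Optional


@dataclass
class RunMetadata:
    """Hypothetical, trimmed-down stand-in for the Metadata class."""

    start_time: Optional[str] = None
    end_time: Optional[str] = None
    run_id: Optional[str] = None
    dataset_version: Optional[str] = None
    history: Dict[str, List[float]] = field(default_factory=dict)


class MetadataStore:
    """Hypothetical wrapper that guards one RunMetadata instance with a lock."""

    def __init__(self) -> None:
        self._metadata = RunMetadata()
        self._lock = threading.Lock()

    def update(self, **fields: object) -> None:
        """Set one or more existing attributes while holding the lock."""
        with self._lock:
            for name, value in fields.items():
                if not hasattr(self._metadata, name):
                    raise AttributeError(f"Unknown metadata field: {name}")
                setattr(self._metadata, name, value)

    def log_metric(self, name: str, value: float) -> None:
        """Append a metric value to the history while holding the lock."""
        with self._lock:
            self._metadata.history.setdefault(name, []).append(value)

    def snapshot(self) -> RunMetadata:
        """Return a deep copy so later updates do not change the caller's view."""
        with self._lock:
            return copy.deepcopy(self._metadata)


store = MetadataStore()

# Start of the pipeline: set start_time and run_id.
store.update(start_time=datetime.now(timezone.utc).isoformat(), run_id="run-001")

# Data loading: record the dataset version.
store.update(dataset_version="v1")

# Keep a copy of the state before training starts.
before_training = store.snapshot()

# Training: append the loss at the end of each epoch.
for epoch_loss in (0.9, 0.7, 0.5):
    store.log_metric("train_loss", epoch_loss)

# End of the pipeline: set end_time and take the final snapshot.
store.update(end_time=datetime.now(timezone.utc).isoformat())
final = store.snapshot()
```

Because `snapshot` returns a deep copy, `before_training.history` stays empty even after the training loop has appended to the live history.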

State Management#

The State object represents the current state of the training process. It encapsulates key components such as the model, optimizer, criterion, scheduler, and additional metadata. The inspiration for State comes from MosaicML's Composer, a library that has been beneficial in the context of pretraining large language models (LLMs) and is also the library that my team and I adopted for our LLM pretraining project.

Here is a snippet of the State object when I was training a GPT model for fun. Note this is a very naive and bare version and does not include many features that Composer’s State object has.

Adder State

Why is State useful? It allows me to easily access the model, optimizer, criterion, scheduler, and other metadata from any part of the codebase. I can also serialize the State object and save it to disk, allowing me to resume training from a checkpoint.

Perhaps not the best design pattern, but I added a history attribute to the State object, which is a dictionary that stores the training history. This is to mimic what could be a separate History object, which is a common component in many deep learning libraries.
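For comparison, a separate History object could look roughly like the sketch below. The class name `History` and its methods (`add`, `latest`, `best`, `state_dict`, `load_state_dict`) are hypothetical and chosen for illustration only; they are not taken from Composer or any other library.

```python
from typing import Dict, List


class History:
    """Hypothetical container for metric values, keyed by metric name."""

    def __init__(self) -> None:
        self._records: Dict[str, List[float]] = {}

    def add(self, name: str, value: float) -> None:
        """Append one value, e.g. at the end of an epoch."""
        self._records.setdefault(name, []).append(float(value))

    def latest(self, name: str) -> float:
        """Return the most recently recorded value for a metric."""
        return self._records[name][-1]

    def best(self, name: str, mode: str = "min") -> float:
        """Return the smallest (mode='min') or largest (mode='max') value."""
        values = self._records[name]
        return min(values) if mode == "min" else max(values)

    def state_dict(self) -> Dict[str, List[float]]:
        """Return a plain-dict copy that can be stored in a checkpoint."""
        return {name: list(values) for name, values in self._records.items()}

    def load_state_dict(self, state: Dict[str, List[float]]) -> None:
        """Restore the records from a plain dict."""
        self._records = {name: list(values) for name, values in state.items()}


history = History()
for loss, accuracy in [(0.9, 0.6), (0.7, 0.8)]:
    history.add("train_loss", loss)
    history.add("valid_accuracy", accuracy)

# Round trip through a plain dict, as one would when checkpointing.
restored = History()
restored.load_state_dict(history.state_dict())
```

Keeping the history as a plain dictionary inside State, as done here, avoids the extra class at the cost of helpers such as `best`.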

A sample implementation of the State object is shown below.

from __future__ import annotations

from typing import Any, Dict, List, Type, Union

import torch
from pydantic import BaseModel, Field
from rich.pretty import pprint
from torch import nn


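def compare_models(model_a: nn.Module, model_b: nn.Module) -> bool:
    """Return True if both models have the same state-dict keys and tensor values."""
    state_a, state_b = model_a.state_dict(), model_b.state_dict()
    # Compare keys first: zipping the two dicts would silently ignore extra entries.
    if state_a.keys() != state_b.keys():
        return False
    return all(torch.equal(state_a[key], state_b[key]) for key in state_a)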


class State(BaseModel):
    model: nn.Module = Field(default=None, description="Model.")

    criterion: nn.Module = Field(default=None, description="Loss function.")
    optimizer: torch.optim.Optimizer = Field(default=None, description="Optimizer.")
    scheduler: Union[torch.optim.lr_scheduler.LRScheduler, None] = Field(default=None, description="Scheduler.")

    epoch_index: int = Field(default=0, description="Current epoch index.")
    train_batch_index: int = Field(
        default=0, description="Current batch index and is only referring to the training batch index."
    )
    step_index: int = Field(
        default=0,
        description="Current training step index, i.e. the number of optimizer steps taken so far. There is no `train` prefix because steps only exist in training. Step and batch indices coincide during the first epoch; after that, the batch index resets to 0 at the start of each epoch, while the step index keeps increasing across epochs.",
    )
    history: Dict[str, List[float]] = Field(default_factory=dict, description="History of metrics.")

    # FIXME: loosen `Vocabularies` and `Tokenizers` to `Any` for now as it is too strict.
    vocabulary: Any = Field(default=None, description="Vocabulary.")
    tokenizer: Any = Field(default=None, description="Tokenizer.")

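    def __eq__(self, other: object) -> bool:
        """Check if two State instances are equal, judged by their model weights only."""
        # Returning NotImplemented lets Python fall back to its default comparison
        # instead of crashing (asserts are also stripped under `python -O`).
        if not isinstance(other, State):
            return NotImplemented

        return compare_models(self.model, other.model)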

    class Config:
        """Pydantic config."""

        arbitrary_types_allowed = True

    def pretty_print(self) -> None:
        """Pretty print the state."""
        pprint(self)

    def save_snapshots(self, filepath: str) -> None:
        """Save the state dictionaries of the components to a file."""
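        # Use explicit `is not None` checks: some modules (e.g. an empty
        # nn.Sequential) define __len__ and would evaluate as falsy.
        state = {
            "model": self.model.state_dict() if self.model is not None else None,
            "criterion": self.criterion.state_dict() if self.criterion is not None else None,
            "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
            "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,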
            "epoch_index": self.epoch_index,
            "train_batch_index": self.train_batch_index,
            "step_index": self.step_index,
            "history": self.history,
            "vocabulary": self.vocabulary,
            "tokenizer": self.tokenizer,
        }
        torch.save(state, filepath)

    @classmethod
    def load_snapshots(
        cls: Type[State],
        filepath: str,
        device: torch.device,
        *,
        model: nn.Module,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
    ) -> State:
        """Load state dictionaries from a file and return a new State instance."""
        state = torch.load(filepath, map_location=device)

        epoch_index = state["epoch_index"]
        train_batch_index = state["train_batch_index"]
        step_index = state["step_index"]
        history = state["history"]
        vocabulary = state["vocabulary"]
        tokenizer = state["tokenizer"]

        # Create a new instance of State with loaded state
        new_state = cls(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch_index=epoch_index,
            train_batch_index=train_batch_index,
            step_index=step_index,
            history=history,
            vocabulary=vocabulary,
            tokenizer=tokenizer,
        )

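        # Load state dicts into the model, criterion, etc., if both the component
        # and its saved state exist. `save_snapshots` always writes every key
        # (possibly as None), so check the value rather than key presence.
        if new_state.model is not None and state.get("model") is not None:
            new_state.model.load_state_dict(state["model"])
        if new_state.criterion is not None and state.get("criterion") is not None:
            new_state.criterion.load_state_dict(state["criterion"])
        if new_state.optimizer is not None and state.get("optimizer") is not None:
            new_state.optimizer.load_state_dict(state["optimizer"])
        if new_state.scheduler is not None and state.get("scheduler") is not None:
            new_state.scheduler.load_state_dict(state["scheduler"])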

        return new_state
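To make the difference between `train_batch_index` and `step_index` concrete, here is a minimal framework-free sketch of how the three counters might evolve. It assumes one optimizer step per batch (no gradient accumulation) and that `step_index` counts completed steps; `Counters` and `run_training` are hypothetical names used only for this illustration.

```python
from dataclasses import dataclass


@dataclass
class Counters:
    """Hypothetical stand-in for the three index fields of State."""

    epoch_index: int = 0
    train_batch_index: int = 0
    step_index: int = 0


def run_training(num_epochs: int, batches_per_epoch: int) -> Counters:
    """Simulate a training loop and return the final counters."""
    counters = Counters()
    for epoch in range(num_epochs):
        counters.epoch_index = epoch
        for batch in range(batches_per_epoch):
            # The batch index restarts from 0 in every epoch.
            counters.train_batch_index = batch
            # Forward pass, backward pass and optimizer.step() would go here.
            # The step index is never reset; it counts steps across epochs.
            counters.step_index += 1
    return counters


counters = run_training(num_epochs=3, batches_per_epoch=4)
print(counters)
```

After three epochs of four batches each, the batch index is back in the range 0 to 3, while the step index has reached 12. A checkpoint that stores both can resume in the middle of an epoch and still keep step-based schedules (for example learning-rate warmup) aligned.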

References and Further Readings#