Registry#

Twitter Handle LinkedIn Profile GitHub Profile Tag Code

What does the registry design pattern solve? Suppose you have to implement many schedulers in your deep learning project. Every time you add one, you also have to add it to a global dictionary (a singleton) so that you can easily instantiate the scheduler by name. Maintaining that dictionary by hand is error-prone and becomes a hassle as your project grows.

The registry provides a central place to manage all available scheduler types. This makes it easier to keep track of what schedulers are available and ensures consistent instantiation across the application. With a simple decorator, you can register your scheduler and use it in your project globally.

It is often combined with the factory design pattern. First, we define:

  1. Registry Pattern:

    • The SchedulerRegistry maintains a collection (dictionary) of scheduler classes.

    • It provides methods to register new scheduler types (register method).

    • It allows retrieval of registered scheduler classes (get_scheduler method).

  2. Factory Pattern:

    • The create_scheduler method acts as a factory method.

    • It creates and returns instances of schedulers based on the provided name and parameters.

So, in essence, this SchedulerRegistry is behaving as both:

  1. A registry: It keeps track of available scheduler types.

  2. A factory: It creates instances of schedulers.

So basically,

  1. Scheduler types are registered with the registry (using the @scheduler_registry.register decorator).

  2. When you need a scheduler, you call scheduler_registry.create_scheduler(name, optimizer, **kwargs).

  3. The registry looks up the correct scheduler class based on the name.

  4. It then uses that class to create and return a scheduler instance.

A Registry For PyTorch Schedulers#

"""Module for creating PyTorch scheduler instances dynamically with an enhanced Registry pattern."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Literal, Type

import torch
from pydantic import BaseModel
from rich.pretty import pprint

RegisteredSchedulers = Literal[
    "StepLR",
    "CosineAnnealingLR",
    "CosineAnnealingWarmRestarts",
]


class SchedulerRegistry:
    """Class-level registry and factory for scheduler config classes.

    Config classes are stored in a single class-level dictionary keyed by
    name, so every caller of the classmethods below sees the same entries.
    """

    # Registered name -> config class; shared through the class itself.
    _schedulers: Dict[str, Type[SchedulerConfig]] = {}

    @classmethod
    def register(cls: Type[SchedulerRegistry], name: str) -> Callable[[Type[SchedulerConfig]], Type[SchedulerConfig]]:
        """Return a class decorator that stores the decorated class under ``name``.

        Args:
            name: Key under which the decorated config class is registered.

        Returns:
            A decorator that registers the class and returns it unchanged.

        Raises:
            ValueError: Raised by the returned decorator when ``name`` is
                already registered, or when the decorated class does not
                extend ``SchedulerConfig``.
        """

        def register_scheduler_cls(scheduler_cls: Type[SchedulerConfig]) -> Type[SchedulerConfig]:
            if name in cls._schedulers:
                raise ValueError(f"Cannot register duplicate scheduler {name}")
            if not issubclass(scheduler_cls, SchedulerConfig):
                raise ValueError(f"Scheduler (name={name}, class={scheduler_cls.__name__}) must extend SchedulerConfig")
            cls._schedulers[name] = scheduler_cls
            return scheduler_cls

        return register_scheduler_cls

    @classmethod
    def get_scheduler(cls, name: str) -> Type[SchedulerConfig]:
        """Look up the config class registered under ``name``.

        Args:
            name: Key used when the class was registered.

        Returns:
            The registered config class.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        scheduler_cls = cls._schedulers.get(name)
        # Test for a missing entry explicitly. A bare truthiness check would
        # also reject a registered class whose metaclass makes it falsy.
        if scheduler_cls is None:
            raise ValueError(f"Scheduler {name} not found in registry")
        return scheduler_cls

    @classmethod
    def create_scheduler(
        cls: Type[SchedulerRegistry], name: str, optimizer: torch.optim.Optimizer, **kwargs: Any
    ) -> torch.optim.lr_scheduler.LRScheduler:
        """Instantiate the config registered under ``name`` and build its scheduler.

        Args:
            name: Key of the registered config class.
            optimizer: Passed to the config's ``build`` method.
            **kwargs: Forwarded to the config class constructor.

        Returns:
            Whatever the config's ``build(optimizer)`` returns.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        scheduler_cls = cls.get_scheduler(name)
        scheduler_config = scheduler_cls(**kwargs)
        return scheduler_config.build(optimizer)


class SchedulerConfig(BaseModel, ABC):
    """Abstract pydantic model from which concrete scheduler configs derive."""

    class Config:
        # pydantic model setting: do not accept fields the model does not declare.
        extra = "forbid"

    @abstractmethod
    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler:
        """Create the scheduler described by this config for ``optimizer``.

        Subclasses must override this method.
        """


@SchedulerRegistry.register("StepLR")
class StepLRConfig(SchedulerConfig):
    """Settings used to build a ``torch.optim.lr_scheduler.StepLR``."""

    # Each field is forwarded to the StepLR keyword argument of the same name.
    step_size: int
    gamma: float = 0.1
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.StepLR:
        """Create a ``StepLR`` scheduler for ``optimizer`` from this config.

        Args:
            optimizer: Optimizer the scheduler is attached to.

        Returns:
            The constructed ``StepLR`` instance.
        """
        scheduler_kwargs: Dict[str, Any] = {
            "step_size": self.step_size,
            "gamma": self.gamma,
            "last_epoch": self.last_epoch,
        }
        # PyTorch has deprecated ``verbose`` and warns whenever it is passed
        # explicitly, even as False. Forward it only when the user enabled it,
        # so the default configuration stays warning-free.
        if self.verbose:
            scheduler_kwargs["verbose"] = self.verbose
        return torch.optim.lr_scheduler.StepLR(optimizer, **scheduler_kwargs)


@SchedulerRegistry.register("CosineAnnealingLR")
class CosineAnnealingLRConfig(SchedulerConfig):
    """Settings used to build a ``torch.optim.lr_scheduler.CosineAnnealingLR``."""

    # Each field is forwarded to the CosineAnnealingLR keyword argument of the same name.
    T_max: int
    eta_min: float = 0
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.CosineAnnealingLR:
        """Create a ``CosineAnnealingLR`` scheduler for ``optimizer`` from this config.

        Args:
            optimizer: Optimizer the scheduler is attached to.

        Returns:
            The constructed ``CosineAnnealingLR`` instance.
        """
        scheduler_kwargs: Dict[str, Any] = {
            "T_max": self.T_max,
            "eta_min": self.eta_min,
            "last_epoch": self.last_epoch,
        }
        # PyTorch has deprecated ``verbose`` and warns whenever it is passed
        # explicitly, even as False. Forward it only when the user enabled it,
        # so the default configuration stays warning-free.
        if self.verbose:
            scheduler_kwargs["verbose"] = self.verbose
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **scheduler_kwargs)


@SchedulerRegistry.register("CosineAnnealingWarmRestarts")
class CosineAnnealingWarmRestartsConfig(SchedulerConfig):
    """Settings used to build a ``torch.optim.lr_scheduler.CosineAnnealingWarmRestarts``."""

    # Each field is forwarded to the CosineAnnealingWarmRestarts keyword
    # argument of the same name.
    T_0: int
    T_mult: int = 1
    eta_min: float = 0
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.CosineAnnealingWarmRestarts:
        """Create a ``CosineAnnealingWarmRestarts`` scheduler for ``optimizer``.

        Args:
            optimizer: Optimizer the scheduler is attached to.

        Returns:
            The constructed ``CosineAnnealingWarmRestarts`` instance.
        """
        scheduler_kwargs: Dict[str, Any] = {
            "T_0": self.T_0,
            "T_mult": self.T_mult,
            "eta_min": self.eta_min,
            "last_epoch": self.last_epoch,
        }
        # PyTorch has deprecated ``verbose`` and warns whenever it is passed
        # explicitly, even as False. Forward it only when the user enabled it,
        # so the default configuration stays warning-free.
        if self.verbose:
            scheduler_kwargs["verbose"] = self.verbose
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, **scheduler_kwargs)


if __name__ == "__main__":
    # Demo: a tiny model and optimizer for the schedulers to attach to.
    model = torch.nn.Linear(10, 2)
    trainable_params = model.parameters()
    optimizer = torch.optim.SGD(trainable_params, lr=0.1)

    # Show which config classes the decorators have registered.
    pprint(SchedulerRegistry._schedulers)

    # StepLR, built by name from its keyword settings.
    step_lr_settings = {"step_size": 30, "gamma": 0.1}
    step_lr = SchedulerRegistry.create_scheduler("StepLR", optimizer, **step_lr_settings)
    print(f"Created StepLR scheduler: {step_lr}")

    # CosineAnnealingLR, built the same way.
    cosine_lr_settings = {"T_max": 100, "eta_min": 0.001}
    cosine_lr = SchedulerRegistry.create_scheduler("CosineAnnealingLR", optimizer, **cosine_lr_settings)
    print(f"Created CosineAnnealingLR scheduler: {cosine_lr}")

    # CosineAnnealingWarmRestarts, built the same way.
    warm_restart_settings = {"T_0": 100, "T_mult": 2}
    cosine_warm_restarts = SchedulerRegistry.create_scheduler(
        "CosineAnnealingWarmRestarts", optimizer, **warm_restart_settings
    )
    print(f"Created CosineAnnealingWarmRestarts scheduler: {cosine_warm_restarts}")
{
'StepLR': <class '__main__.StepLRConfig'>,
'CosineAnnealingLR': <class '__main__.CosineAnnealingLRConfig'>,
'CosineAnnealingWarmRestarts': <class '__main__.CosineAnnealingWarmRestartsConfig'>
}
Created StepLR scheduler: <torch.optim.lr_scheduler.StepLR object at 0x7f747c650730>
Created CosineAnnealingLR scheduler: <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f7472ea3a00>
Created CosineAnnealingWarmRestarts scheduler: <torch.optim.lr_scheduler.CosineAnnealingWarmRestarts object at 0x7f7472e03730>
/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
  warnings.warn(

Singleton#

The registry is shared globally, so it already behaves like a singleton. If you want to make it an explicit singleton, you can do so either by overriding `__new__` as shown below or by using a metaclass.

from typing import Dict, Type, Callable, Any


class SchedulerRegistry:
    """Singleton registry and factory for scheduler config classes.

    ``__new__`` caches the first instance on the class and returns it for
    every later construction. The name -> class mapping is a class attribute.
    """

    # The one instance handed out by __new__; None until first construction.
    _instance = None
    # Registered name -> config class.
    _schedulers: Dict[str, Type[SchedulerConfig]] = {}

    def __new__(cls: Type[SchedulerRegistry]) -> SchedulerRegistry:
        """Create the instance on first use and return the cached one afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register(self, name: str) -> Callable[[Type[SchedulerConfig]], Type[SchedulerConfig]]:
        """Return a class decorator that stores the decorated class under ``name``.

        Args:
            name: Key under which the decorated config class is registered.

        Returns:
            A decorator that registers the class and returns it unchanged.

        Raises:
            ValueError: Raised by the returned decorator when ``name`` is
                already registered, or when the decorated class does not
                extend ``SchedulerConfig``.
        """

        def register_scheduler_cls(scheduler_cls: Type[SchedulerConfig]) -> Type[SchedulerConfig]:
            if name in self._schedulers:
                raise ValueError(f"Cannot register duplicate scheduler {name}")
            if not issubclass(scheduler_cls, SchedulerConfig):
                raise ValueError(f"Scheduler (name={name}, class={scheduler_cls.__name__}) must extend SchedulerConfig")
            self._schedulers[name] = scheduler_cls
            return scheduler_cls

        return register_scheduler_cls

    def get_scheduler(self, name: str) -> Type[SchedulerConfig]:
        """Look up the config class registered under ``name``.

        Args:
            name: Key used when the class was registered.

        Returns:
            The registered config class.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        scheduler_cls = self._schedulers.get(name)
        # Test for a missing entry explicitly. A bare truthiness check would
        # also reject a registered class whose metaclass makes it falsy.
        if scheduler_cls is None:
            raise ValueError(f"Scheduler {name} not found in registry")
        return scheduler_cls

    def create_scheduler(
        self, name: str, optimizer: torch.optim.Optimizer, **kwargs: Any
    ) -> torch.optim.lr_scheduler.LRScheduler:
        """Instantiate the config registered under ``name`` and build its scheduler.

        Args:
            name: Key of the registered config class.
            optimizer: Passed to the config's ``build`` method.
            **kwargs: Forwarded to the config class constructor.

        Returns:
            Whatever the config's ``build(optimizer)`` returns.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        scheduler_cls = self.get_scheduler(name)
        scheduler_config = scheduler_cls(**kwargs)
        return scheduler_config.build(optimizer)

References And Further Readings#