Registry#
What problem does the registry design pattern solve? Suppose you have to implement many schedulers in your deep learning project. Every time you add one, you also have to add it to some global dictionary (a singleton) so that it can be instantiated by name elsewhere. Maintaining that dictionary by hand is error-prone and becomes a hassle as the project grows.
The registry provides a central place to manage all available scheduler types. It makes it easy to see which schedulers exist and ensures they are instantiated consistently across the application. With a simple decorator, you register your scheduler once and use it anywhere in your project.
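Before the full PyTorch implementation below, here is a minimal sketch of the idea; the names (SCHEDULERS, register, ConstantScheduler) are made up for illustration and are not part of the later code.

from typing import Callable, Dict, Type

# A registry is just a module-level dictionary plus a decorator that records
# each class under a string name at import time.
SCHEDULERS: Dict[str, Type] = {}


def register(name: str) -> Callable[[Type], Type]:
    def wrapper(cls: Type) -> Type:
        SCHEDULERS[name] = cls  # later lookups happen by this name
        return cls

    return wrapper


@register("constant")
class ConstantScheduler:
    def __init__(self, lr: float) -> None:
        self.lr = lr


# Anywhere else in the project, instantiate purely by name:
scheduler = SCHEDULERS["constant"](lr=0.1)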
It is often combined with the factory design pattern. First, we define:
Registry Pattern:

- The SchedulerRegistry maintains a collection (dictionary) of scheduler classes.
- It provides a method for registering new scheduler types (the register method).
- It allows retrieval of registered scheduler classes (the get_scheduler method).

Factory Pattern:

- The create_scheduler method acts as a factory method.
- It creates and returns scheduler instances based on the provided name and parameters.

So, in essence, this SchedulerRegistry is behaving as both:

- A registry: it keeps track of the available scheduler types.
- A factory: it creates instances of schedulers.

Putting it all together:

- Scheduler types are registered with the registry (using the @SchedulerRegistry.register decorator).
- When you need a scheduler, you call SchedulerRegistry.create_scheduler(name, optimizer, **kwargs).
- The registry looks up the correct scheduler class based on the name.
- It then uses that class to create and return a scheduler instance.
A Registry For PyTorch Schedulers#
"""Module for creating PyTorch scheduler instances dynamically with an enhanced Registry pattern."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Literal, Type
import torch
from pydantic import BaseModel
from rich.pretty import pprint
RegisteredSchedulers = Literal[
    "StepLR",
    "CosineAnnealingLR",
    "CosineAnnealingWarmRestarts",
]


class SchedulerRegistry:
    _schedulers: Dict[str, Type[SchedulerConfig]] = {}

    @classmethod
    def register(cls: Type[SchedulerRegistry], name: str) -> Callable[[Type[SchedulerConfig]], Type[SchedulerConfig]]:
        def register_scheduler_cls(scheduler_cls: Type[SchedulerConfig]) -> Type[SchedulerConfig]:
            if name in cls._schedulers:
                raise ValueError(f"Cannot register duplicate scheduler {name}")
            if not issubclass(scheduler_cls, SchedulerConfig):
                raise ValueError(f"Scheduler (name={name}, class={scheduler_cls.__name__}) must extend SchedulerConfig")
            cls._schedulers[name] = scheduler_cls
            return scheduler_cls

        return register_scheduler_cls

    @classmethod
    def get_scheduler(cls, name: str) -> Type[SchedulerConfig]:
        scheduler_cls = cls._schedulers.get(name)
        if not scheduler_cls:
            raise ValueError(f"Scheduler {name} not found in registry")
        return scheduler_cls

    @classmethod
    def create_scheduler(
        cls: Type[SchedulerRegistry], name: str, optimizer: torch.optim.Optimizer, **kwargs: Any
    ) -> torch.optim.lr_scheduler.LRScheduler:
        scheduler_cls = cls.get_scheduler(name)
        scheduler_config = scheduler_cls(**kwargs)
        return scheduler_config.build(optimizer)


class SchedulerConfig(BaseModel, ABC):
    """Base class for creating PyTorch scheduler instances dynamically."""

    @abstractmethod
    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.LRScheduler:
        """Builder method for creating a scheduler instance."""
        pass

    class Config:
        extra = "forbid"


@SchedulerRegistry.register("StepLR")
class StepLRConfig(SchedulerConfig):
    step_size: int
    gamma: float = 0.1
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.StepLR:
        return torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=self.step_size, gamma=self.gamma, last_epoch=self.last_epoch, verbose=self.verbose
        )


@SchedulerRegistry.register("CosineAnnealingLR")
class CosineAnnealingLRConfig(SchedulerConfig):
    T_max: int
    eta_min: float = 0
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.CosineAnnealingLR:
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.T_max, eta_min=self.eta_min, last_epoch=self.last_epoch, verbose=self.verbose
        )


@SchedulerRegistry.register("CosineAnnealingWarmRestarts")
class CosineAnnealingWarmRestartsConfig(SchedulerConfig):
    T_0: int
    T_mult: int = 1
    eta_min: float = 0
    last_epoch: int = -1
    verbose: bool = False

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.CosineAnnealingWarmRestarts:
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=self.T_0,
            T_mult=self.T_mult,
            eta_min=self.eta_min,
            last_epoch=self.last_epoch,
            verbose=self.verbose,
        )
if __name__ == "__main__":
    # Create a dummy optimizer for demonstration
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    pprint(SchedulerRegistry._schedulers)

    # Create a StepLR scheduler
    step_lr = SchedulerRegistry.create_scheduler("StepLR", optimizer, step_size=30, gamma=0.1)
    print(f"Created StepLR scheduler: {step_lr}")

    # Create a CosineAnnealingLR scheduler
    cosine_lr = SchedulerRegistry.create_scheduler("CosineAnnealingLR", optimizer, T_max=100, eta_min=0.001)
    print(f"Created CosineAnnealingLR scheduler: {cosine_lr}")

    # Create a CosineAnnealingWarmRestarts scheduler
    cosine_warm_restarts = SchedulerRegistry.create_scheduler(
        "CosineAnnealingWarmRestarts", optimizer, T_0=100, T_mult=2
    )
    print(f"Created CosineAnnealingWarmRestarts scheduler: {cosine_warm_restarts}")
{
│   'StepLR': <class '__main__.StepLRConfig'>,
│   'CosineAnnealingLR': <class '__main__.CosineAnnealingLRConfig'>,
│   'CosineAnnealingWarmRestarts': <class '__main__.CosineAnnealingWarmRestartsConfig'>
}
Created StepLR scheduler: <torch.optim.lr_scheduler.StepLR object at 0x7f747c650730>
Created CosineAnnealingLR scheduler: <torch.optim.lr_scheduler.CosineAnnealingLR object at 0x7f7472ea3a00>
Created CosineAnnealingWarmRestarts scheduler: <torch.optim.lr_scheduler.CosineAnnealingWarmRestarts object at 0x7f7472e03730>
/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:62: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
warnings.warn(
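The payoff is extensibility: supporting a new scheduler only requires a config class and a decorator, with no change to the factory itself. Below is a sketch of how you might register PyTorch's ExponentialLR; the import path scheduler_registry is an assumption about how the module above is packaged.

import torch

# Assumption: the module above is saved as scheduler_registry.py.
from scheduler_registry import SchedulerConfig, SchedulerRegistry


@SchedulerRegistry.register("ExponentialLR")
class ExponentialLRConfig(SchedulerConfig):
    gamma: float
    last_epoch: int = -1

    def build(self, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler.ExponentialLR:
        return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma, last_epoch=self.last_epoch)


# Any call site can now request the new scheduler by name, e.g. driven by a config file:
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = SchedulerRegistry.create_scheduler("ExponentialLR", optimizer, gamma=0.95)

Because SchedulerConfig is a pydantic model with extra = "forbid", a mistyped keyword argument fails loudly when the config is constructed instead of being silently ignored.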
Singleton#
Because _schedulers is a class-level attribute, the registry already acts globally and behaves like a singleton. If you want to make that explicit, you can do so either with __new__ as shown below or with a metaclass.
from __future__ import annotations

from typing import Any, Callable, Dict, Type

import torch

# SchedulerConfig is the same pydantic base class defined in the previous section.


class SchedulerRegistry:
    _instance = None
    _schedulers: Dict[str, Type[SchedulerConfig]] = {}

    def __new__(cls: Type[SchedulerRegistry]) -> SchedulerRegistry:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def register(self, name: str) -> Callable[[Type[SchedulerConfig]], Type[SchedulerConfig]]:
        def register_scheduler_cls(scheduler_cls: Type[SchedulerConfig]) -> Type[SchedulerConfig]:
            if name in self._schedulers:
                raise ValueError(f"Cannot register duplicate scheduler {name}")
            if not issubclass(scheduler_cls, SchedulerConfig):
                raise ValueError(f"Scheduler (name={name}, class={scheduler_cls.__name__}) must extend SchedulerConfig")
            self._schedulers[name] = scheduler_cls
            return scheduler_cls

        return register_scheduler_cls

    def get_scheduler(self, name: str) -> Type[SchedulerConfig]:
        scheduler_cls = self._schedulers.get(name)
        if not scheduler_cls:
            raise ValueError(f"Scheduler {name} not found in registry")
        return scheduler_cls

    def create_scheduler(
        self, name: str, optimizer: torch.optim.Optimizer, **kwargs: Any
    ) -> torch.optim.lr_scheduler.LRScheduler:
        scheduler_cls = self.get_scheduler(name)
        scheduler_config = scheduler_cls(**kwargs)
        return scheduler_config.build(optimizer)
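For completeness, here is a sketch of the metaclass route mentioned above; the SingletonMeta name is illustrative, and the registration and factory methods would be identical to the __new__-based version.

from typing import Any, Dict


class SingletonMeta(type):
    """Metaclass that creates at most one instance per class and returns it on every call."""

    _instances: Dict[type, Any] = {}

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class SchedulerRegistry(metaclass=SingletonMeta):
    # register, get_scheduler and create_scheduler exactly as in the version above
    ...


assert SchedulerRegistry() is SchedulerRegistry()  # both calls return the same instance

Either way, every part of the code base that imports the registry talks to the same underlying dictionary of scheduler classes.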