Decorator#

Twitter Handle LinkedIn Profile GitHub Profile Tag

Decorator#

from typing import Callable, TypeVar
from typing_extensions import ParamSpec

# https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
P = ParamSpec("P")
T = TypeVar("T")


def trace(func: Callable[P, T]) -> Callable[P, T]:
    """Decorator to log function calls."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        """Me is wrapper."""
        result = func(*args, **kwargs)
        print(f"{func.__name__}({args!r}, {kwargs!r}) " f"-> {result!r}")
        return result
    return wrapper


def greet(name: str) -> str:
    """Function to greet a person."""
    # Build and return the greeting in one expression; no temporary needed.
    return f"Hello, {name}!"


# print() joins its arguments with a single space, giving the same text as
# the "label: {value}" f-string form.
print("Before decorating func name:", greet.__name__)
print("Before decorating func docstring:", greet.__doc__)
# Manual decoration: this assignment is exactly what the @trace syntax does.
greet = trace(greet)
msg = greet("Bob")
print("After decorating func name:", greet.__name__)
print("After decorating func docstring:", greet.__doc__)
Before decorating func name: greet
Before decorating func docstring: Function to greet a person.
greet(('Bob',), {}) -> 'Hello, Bob!'
After decorating func name: wrapper
After decorating func docstring: Me is wrapper.

The flow is pretty simple:

  1. greet is passed in to trace as func = greet.

  2. The trace returns a wrapper function.

  3. The wrapper function accepts any positional args and keyword kwargs and passes them on to func. It can still reach func because func lives in the enclosing (local) scope of trace, which makes wrapper a closure.

  4. Now we can intuitively see why it is called wrapper, as it wraps around the original greet.

  5. So after greet = trace(greet), the name greet refers to the wrapper. Calling the patched greet, say msg = greet("Bob"), actually invokes wrapper("Bob"). The wrapper calls the original greet, then runs its additional print statement, print(f"{func.__name__}({args!r}, {kwargs!r}) " f"-> {result!r}"), and finally returns the result.

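Below, we apply trace with the @ syntax on top of greet's definition. The @ is just syntactic sugar for what we did earlier: greet = trace(greet). Because greet is decorated as soon as it is defined, even the "Before decorating" lines below already report the wrapper's name and docstring.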

@trace
def greet(name: str) -> str:
    """Function to greet a person."""
    # The decorator rebinds `greet` to trace's wrapper right after this def.
    return f"Hello, {name}!"

# greet was decorated at definition time, so "Before" and "After" report the
# same (wrapper) metadata here.
print("Before decorating func name:", greet.__name__)
print("Before decorating func docstring:", greet.__doc__)
msg = greet("Bob")
print("After decorating func name:", greet.__name__)
print("After decorating func docstring:", greet.__doc__)
Before decorating func name: wrapper
Before decorating func docstring: Me is wrapper.
greet(('Bob',), {}) -> 'Hello, Bob!'
After decorating func name: wrapper
After decorating func docstring: Me is wrapper.

Introspection Is Not Preserved#

But notice that introspection is broken: metadata of the original function, such as __name__ and __doc__, is no longer preserved.

Before decorating func name: wrapper
Before decorating func docstring: Me is wrapper.
greet(('Bob',), {}) -> 'Hello, Bob!'
After decorating func name: wrapper
After decorating func docstring: Me is wrapper.

So the new greet is no longer the original greet: it is now called wrapper, and its docstring is the wrapper's rather than greet's. This causes problems for tools that rely heavily on introspection; debuggers such as pdb, for example, may not work as expected.

We can fix this with functools.wraps from the standard library. It is itself a decorator, applied to the wrapper, that copies the original function's metadata onto the wrapper so introspection keeps working.

Using functools.wraps#

import functools

def trace(func: Callable[P, T]) -> Callable[P, T]:
    """Return a logging wrapper around *func* that keeps *func*'s metadata.

    ``functools.wraps`` copies metadata such as ``__name__`` and ``__doc__``
    from *func* onto the wrapper, so introspection still reports the original
    function.
    """

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        """Me is wrapper."""
        value = func(*args, **kwargs)
        print(f"{func.__name__}({args!r}, {kwargs!r}) -> {value!r}")
        return value

    # Equivalent to decorating `wrapper` with @functools.wraps(func).
    return functools.wraps(func)(wrapper)

@trace
def greet(name: str) -> str:
    """Function to greet a person."""
    # trace applies functools.wraps, so this name and docstring stay visible.
    return f"Hello, {name}!"


# With functools.wraps, the decorated greet still reports greet's own metadata.
print("Before decorating func name:", greet.__name__)
print("Before decorating func docstring:", greet.__doc__)
msg = greet("Bob")
print("After decorating func name:", greet.__name__)
print("After decorating func docstring:", greet.__doc__)
Before decorating func name: greet
Before decorating func docstring: Function to greet a person.
greet(('Bob',), {}) -> 'Hello, Bob!'
After decorating func name: greet
After decorating func docstring: Function to greet a person.

Next, an example with recursion shows the power of the trace decorator. Every recursive call goes through the wrapper, so the whole call tree is logged. Logging and tracing are among the most common use cases for decorators.

@trace
def fibonacci(n: int) -> int:
    """Return the n-th Fibonacci number.

    Parameters
    ----------
    n: int
        The index of the Fibonacci number; must be non-negative.

    Returns
    -------
    int
        The n-th Fibonacci number.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    # A negative n never reaches the {0, 1} base case below and would recurse
    # until RecursionError; fail fast with a clear message instead.
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n!r}")
    if n in {0, 1}:
        return n
    # Exponential-time double recursion; acceptable here because the point of
    # the example is that @trace logs every call in the call tree.
    return fibonacci(n - 2) + fibonacci(n - 1)

# Every recursive call goes through the traced wrapper, which prints after the
# call returns, so the innermost calls are logged first.
fibonacci(4)
fibonacci((0,), {}) -> 0
fibonacci((1,), {}) -> 1
fibonacci((2,), {}) -> 1
fibonacci((1,), {}) -> 1
fibonacci((0,), {}) -> 0
fibonacci((1,), {}) -> 1
fibonacci((2,), {}) -> 1
fibonacci((3,), {}) -> 2
fibonacci((4,), {}) -> 3
3

Decorators For Registry Design Pattern#

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Type, TypeVar

from rich.pretty import pprint


class Recipe(ABC):
    """Interface every registered recipe must implement."""

    @abstractmethod
    def ingredients(self) -> List[str]:
        """Return the ingredients needed for this recipe."""

    @abstractmethod
    def instructions(self) -> List[str]:
        """Return the preparation steps, in order."""


T = TypeVar("T", bound=Recipe)


class RecipeBook:
    """Registry mapping a category name to a ``Recipe`` subclass."""

    # Class-level dict shared by all users of RecipeBook: the global registry.
    _recipes: Dict[str, Type[Recipe]] = {}

    @classmethod
    def register(cls: Type[RecipeBook], category: str) -> Callable[[Type[T]], Type[T]]:
        """Return a class decorator that registers a recipe under *category*.

        The decorated class is returned unchanged, so the decorator can sit
        directly on a class definition.

        Raises
        ------
        TypeError
            (from the returned decorator) if the decorated object is not a
            class that inherits from ``Recipe``.
        ValueError
            (from the returned decorator) if *category* is already taken.
        """

        def decorator(recipe_cls: Type[T]) -> Type[T]:
            # issubclass() raises its own unhelpful TypeError for non-class
            # arguments (e.g. a decorated function), so check for a class first.
            if not (isinstance(recipe_cls, type) and issubclass(recipe_cls, Recipe)):
                name = getattr(recipe_cls, "__name__", repr(recipe_cls))
                raise TypeError(f"{name} must inherit from Recipe")
            if category in cls._recipes:
                raise ValueError(f"A recipe is already registered for {category}")
            cls._recipes[category] = recipe_cls
            return recipe_cls

        return decorator

    @classmethod
    def get_recipe(cls: Type[RecipeBook], category: str) -> Type[Recipe]:
        """Return the ``Recipe`` subclass registered under *category*.

        Raises
        ------
        KeyError
            If no recipe is registered for *category*.
        """
        try:
            return cls._recipes[category]
        except KeyError:
            raise KeyError(f"No recipe found for {category}") from None


@RecipeBook.register("pasta")
class PastaRecipe(Recipe):
    """Pasta with tomato sauce and cheese, registered under the pasta category."""

    # Immutable source data; each method call hands out a fresh list copy.
    _INGREDIENTS = ("Pasta", "Tomato sauce", "Cheese")
    _STEPS = (
        "Boil pasta according to package instructions",
        "Heat tomato sauce in a pan",
        "Drain pasta and mix with sauce",
        "Sprinkle cheese on top",
    )

    def ingredients(self) -> List[str]:
        """Return the ingredient list."""
        return list(self._INGREDIENTS)

    def instructions(self) -> List[str]:
        """Return the cooking steps in order."""
        return list(self._STEPS)


@RecipeBook.register("salad")
class SaladRecipe(Recipe):
    """Simple tossed salad, registered under the salad category."""

    # Immutable source data; each method call hands out a fresh list copy.
    _INGREDIENTS = ("Lettuce", "Tomatoes", "Cucumber", "Dressing")
    _STEPS = (
        "Wash and chop lettuce, tomatoes, and cucumber",
        "Mix vegetables in a bowl",
        "Add dressing and toss",
    )

    def ingredients(self) -> List[str]:
        """Return the ingredient list."""
        return list(self._INGREDIENTS)

    def instructions(self) -> List[str]:
        """Return the preparation steps in order."""
        return list(self._STEPS)


def print_recipe(category: str) -> None:
    """Print the ingredients and numbered steps of the recipe for *category*.

    Any ``KeyError`` from ``RecipeBook.get_recipe`` propagates unchanged.
    """
    recipe = RecipeBook.get_recipe(category)()
    print(f"\n{category.capitalize()} Recipe:")

    print("Ingredients:")
    for ingredient in recipe.ingredients():
        print(f"- {ingredient}")

    print("\nInstructions:")
    for number, step in enumerate(recipe.instructions(), start=1):
        print(f"{number}. {step}")


def main() -> None:
    """Print the demo recipes: pasta first, then salad."""
    for category in ("pasta", "salad"):
        print_recipe(category)


if __name__ == "__main__":
    # Show the registry filled by the @RecipeBook.register decorators when the
    # recipe classes were defined; reads the private attribute for demo purposes.
    pprint(RecipeBook._recipes)  # global registry
    main()
{'pasta': <class '__main__.PastaRecipe'>, 'salad': <class '__main__.SaladRecipe'>}
Pasta Recipe:
Ingredients:
- Pasta
- Tomato sauce
- Cheese

Instructions:
1. Boil pasta according to package instructions
2. Heat tomato sauce in a pan
3. Drain pasta and mix with sauce
4. Sprinkle cheese on top

Salad Recipe:
Ingredients:
- Lettuce
- Tomatoes
- Cucumber
- Dressing

Instructions:
1. Wash and chop lettuce, tomatoes, and cucumber
2. Mix vegetables in a bowl
3. Add dressing and toss