How to Inspect Function and Class Signatures in Python?#
import inspect
from dataclasses import field, make_dataclass
from inspect import Parameter, Signature
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, _GenericAlias, get_type_hints, overload
from pydantic import BaseModel
from rich.pretty import pprint
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
Motivation#
There are two main motivations for inspecting function and class signatures in Python.
Code Introspection in Open Source Projects: When working on an open source project, especially a large codebase with many functions and classes, we often need to inspect signatures to understand how the code is structured and how its functions and classes are meant to be used.
Abstractions are often nested: a child class \(\mathcal{C}_N\) inherits from a parent class \(\mathcal{C}_{N-1}\), which in turn inherits from \(\mathcal{C}_{N-2}\), and so on. The child class alone may not reveal what arguments its constructor accepts, so we have to inspect the parent classes to understand the child's constructor signature.
Agent-Based Function Calling: In agent-based programming, where automated agents execute functions or interact with libraries, there is a notable risk of 'hallucination', i.e. an agent attempting to invoke non-existent methods or improperly structured calls from an external module or library. Equipping agents with the capability to query the actual signatures of a library's classes and functions significantly mitigates this risk: the agents operate on accurate, real-time information, which improves the reliability and effectiveness of automated tasks.
Construct Hypothetical Function, Child and Parent Classes#
class ParentClass:
"""This is the parent class."""
parent_class_attr = 'a parent class attribute'
def __init__(self, parent_instance_attr: str) -> None:
self.parent_instance_attr = parent_instance_attr
def parent_method(self) -> str:
"""This is a method in the parent class."""
return "Parent method called"
class ChildClass(ParentClass):
"""This is a subclass of ParentClass."""
# Class attribute
class_attr = 'a class attribute'
# Private and protected attributes
_protected_attr = 'a protected attribute'
__private_attr = 'a private attribute'
def __init__(self, instance_attr: str, parent_instance_attr: str) -> None:
"""Initialize the instance."""
super().__init__(parent_instance_attr)
# Instance attribute
self.instance_attr = instance_attr
self.instance_not_in_constructor_attr = 'an instance attribute not in the constructor'
self._private_instance_attr = 'a private instance attribute'
@property
def read_only_attr(self) -> str:
"""This is a read-only attribute."""
return 'You can read me, but you cannot change me.'
def instance_method(self, arg: str) -> str:
"""This is an instance method."""
return f'Instance method called with argument: {arg}'
@classmethod
def class_method(cls, arg: str) -> str:
"""This is a class method."""
return f'Class method called with argument: {arg}'
@staticmethod
def static_method(arg: str) -> str:
"""This is a static method."""
return f'Static method called with argument: {arg}'
def __str__(self) -> str:
"""Return a string representation of the instance."""
return f'MyClass(instance_attr={self.instance_attr})'
instance_child = ChildClass(instance_attr='an instance attribute', parent_instance_attr='a parent instance attribute')
class_child = ChildClass
instance_parent = ParentClass(parent_instance_attr='a parent instance attribute')
class_parent = ParentClass
def func(a: int, b: str, c: List[int], d: Tuple[str, str], e: Union[int, str], **kwargs: Any) -> str:
return a, b, c, d, e, kwargs
Inspect All Members#
@overload
def get_members_of_function_or_method(
func_or_class: Type[object], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
...
@overload
def get_members_of_function_or_method(
func_or_class: Callable[..., Any], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
...
def get_members_of_function_or_method(
func_or_class: Union[Type[object], Callable[..., Any]], predicate: Optional[Callable[[Any], bool]] = None
) -> List[Tuple[str, Any]]:
return inspect.getmembers(func_or_class, predicate)
def loop_through_members(members: List[Tuple[str, Any]], filter: Optional[str] = None) -> None:
if filter is not None:
members = [member for member in members if filter in member[0]]
for member in members:
name, value = member
print(f'{name}: {value}')
Our initial goal is to get all signatures and type annotations of a class or function, and the inspect module helps us do that. The getmembers function returns all members of a class or module, which we can then filter to inspect the signatures of the functions and classes we care about. For our purpose, however, it may be overkill: its scope is very broad, so inspecting even the simple func defined above also returns its entire __globals__ mapping, which is usually not what we want.
func_all_members = get_members_of_function_or_method(func, predicate=None)
loop_through_members(func_all_members)
__annotations__: {'a': <class 'int'>, 'b': <class 'str'>, 'c': typing.List[int], 'd': typing.Tuple[str, str], 'e': typing.Union[int, str], 'kwargs': typing.Any, 'return': <class 'str'>}
__call__: <method-wrapper '__call__' of function object at 0x7fe1770615e0>
__class__: <class 'function'>
__closure__: None
__code__: <code object func at 0x7fe17706bdf0, file "/tmp/ipykernel_3080/2139551385.py", line 1>
__defaults__: None
__delattr__: <method-wrapper '__delattr__' of function object at 0x7fe1770615e0>
__dict__: {}
__dir__: <built-in method __dir__ of function object at 0x7fe1770615e0>
__doc__: None
__eq__: <method-wrapper '__eq__' of function object at 0x7fe1770615e0>
__format__: <built-in method __format__ of function object at 0x7fe1770615e0>
__ge__: <method-wrapper '__ge__' of function object at 0x7fe1770615e0>
__get__: <method-wrapper '__get__' of function object at 0x7fe1770615e0>
__getattribute__: <method-wrapper '__getattribute__' of function object at 0x7fe1770615e0>
__globals__: {'__name__': '__main__', '__doc__': 'Automatically created module for IPython interactive environment', ...} (output truncated: this entry is the entire IPython global namespace, including the notebook's own input history)
__gt__: <method-wrapper '__gt__' of function object at 0x7fe1770615e0>
__hash__: <method-wrapper '__hash__' of function object at 0x7fe1770615e0>
__init__: <method-wrapper '__init__' of function object at 0x7fe1770615e0>
__init_subclass__: <built-in method __init_subclass__ of type object at 0x7fe28cf65960>
__kwdefaults__: None
__le__: <method-wrapper '__le__' of function object at 0x7fe1770615e0>
__lt__: <method-wrapper '__lt__' of function object at 0x7fe1770615e0>
__module__: __main__
__name__: func
__ne__: <method-wrapper '__ne__' of function object at 0x7fe1770615e0>
__new__: <built-in method __new__ of type object at 0x7fe28cf65960>
__qualname__: func
__reduce__: <built-in method __reduce__ of function object at 0x7fe1770615e0>
__reduce_ex__: <built-in method __reduce_ex__ of function object at 0x7fe1770615e0>
__repr__: <method-wrapper '__repr__' of function object at 0x7fe1770615e0>
__setattr__: <method-wrapper '__setattr__' of function object at 0x7fe1770615e0>
__sizeof__: <built-in method __sizeof__ of function object at 0x7fe1770615e0>
__str__: <method-wrapper '__str__' of function object at 0x7fe1770615e0>
__subclasshook__: <built-in method __subclasshook__ of type object at 0x7fe28cf65960>
And to get the signature and type annotations, we can simply filter for '__annotations__'.
loop_through_members(func_all_members, filter='__annotations__')
__annotations__: {'a': <class 'int'>, 'b': <class 'str'>, 'c': typing.List[int], 'd': typing.Tuple[str, str], 'e': typing.Union[int, str], 'kwargs': typing.Any, 'return': <class 'str'>}
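As an aside, if the annotations or the signature are all we need, we do not have to go through getmembers at all: typing.get_type_hints (imported above) returns the resolved annotations directly, and inspect.signature gives us the full parameter list. A minimal sketch:
pprint(get_type_hints(func))      # the same mapping as __annotations__, with forward references resolved
pprint(inspect.signature(func))   # the full signature object, e.g. (a: int, b: str, ...) -> str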
class_child_all_members = get_members_of_function_or_method(class_child, predicate=None)
loop_through_members(class_child_all_members)
_ChildClass__private_attr: a private attribute
__class__: <class 'type'>
__delattr__: <slot wrapper '__delattr__' of 'object' objects>
__dict__: {'__module__': '__main__', '__doc__': 'This is a subclass of ParentClass.', 'class_attr': 'a class attribute', '_protected_attr': 'a protected attribute', '_ChildClass__private_attr': 'a private attribute', '__init__': <function ChildClass.__init__ at 0x7fe1770619d0>, 'read_only_attr': <property object at 0x7fe1770647c0>, 'instance_method': <function ChildClass.instance_method at 0x7fe177061af0>, 'class_method': <classmethod object at 0x7fe17704b220>, 'static_method': <staticmethod object at 0x7fe17704bdc0>, '__str__': <function ChildClass.__str__ at 0x7fe177061ca0>}
__dir__: <method '__dir__' of 'object' objects>
__doc__: This is a subclass of ParentClass.
__eq__: <slot wrapper '__eq__' of 'object' objects>
__format__: <method '__format__' of 'object' objects>
__ge__: <slot wrapper '__ge__' of 'object' objects>
__getattribute__: <slot wrapper '__getattribute__' of 'object' objects>
__gt__: <slot wrapper '__gt__' of 'object' objects>
__hash__: <slot wrapper '__hash__' of 'object' objects>
__init__: <function ChildClass.__init__ at 0x7fe1770619d0>
__init_subclass__: <built-in method __init_subclass__ of type object at 0x55f0d55ad280>
__le__: <slot wrapper '__le__' of 'object' objects>
__lt__: <slot wrapper '__lt__' of 'object' objects>
__module__: __main__
__ne__: <slot wrapper '__ne__' of 'object' objects>
__new__: <built-in method __new__ of type object at 0x7fe28cf70840>
__reduce__: <method '__reduce__' of 'object' objects>
__reduce_ex__: <method '__reduce_ex__' of 'object' objects>
__repr__: <slot wrapper '__repr__' of 'object' objects>
__setattr__: <slot wrapper '__setattr__' of 'object' objects>
__sizeof__: <method '__sizeof__' of 'object' objects>
__str__: <function ChildClass.__str__ at 0x7fe177061ca0>
__subclasshook__: <built-in method __subclasshook__ of type object at 0x55f0d55ad280>
__weakref__: <attribute '__weakref__' of 'ParentClass' objects>
_protected_attr: a protected attribute
class_attr: a class attribute
class_method: <bound method ChildClass.class_method of <class '__main__.ChildClass'>>
instance_method: <function ChildClass.instance_method at 0x7fe177061af0>
parent_class_attr: a parent class attribute
parent_method: <function ParentClass.parent_method at 0x7fe177061940>
read_only_attr: <property object at 0x7fe1770647c0>
static_method: <function ChildClass.static_method at 0x7fe177061c10>
instance_child_all_members = get_members_of_function_or_method(instance_child, predicate=None)
loop_through_members(instance_child_all_members)
_ChildClass__private_attr: a private attribute
__class__: <class '__main__.ChildClass'>
__delattr__: <method-wrapper '__delattr__' of ChildClass object at 0x7fe17704bbe0>
__dict__: {'parent_instance_attr': 'a parent instance attribute', 'instance_attr': 'an instance attribute', 'instance_not_in_constructor_attr': 'an instance attribute not in the constructor', '_private_instance_attr': 'a private instance attribute'}
__dir__: <built-in method __dir__ of ChildClass object at 0x7fe17704bbe0>
__doc__: This is a subclass of ParentClass.
__eq__: <method-wrapper '__eq__' of ChildClass object at 0x7fe17704bbe0>
__format__: <built-in method __format__ of ChildClass object at 0x7fe17704bbe0>
__ge__: <method-wrapper '__ge__' of ChildClass object at 0x7fe17704bbe0>
__getattribute__: <method-wrapper '__getattribute__' of ChildClass object at 0x7fe17704bbe0>
__gt__: <method-wrapper '__gt__' of ChildClass object at 0x7fe17704bbe0>
__hash__: <method-wrapper '__hash__' of ChildClass object at 0x7fe17704bbe0>
__init__: <bound method ChildClass.__init__ of <__main__.ChildClass object at 0x7fe17704bbe0>>
__init_subclass__: <built-in method __init_subclass__ of type object at 0x55f0d55ad280>
__le__: <method-wrapper '__le__' of ChildClass object at 0x7fe17704bbe0>
__lt__: <method-wrapper '__lt__' of ChildClass object at 0x7fe17704bbe0>
__module__: __main__
__ne__: <method-wrapper '__ne__' of ChildClass object at 0x7fe17704bbe0>
__new__: <built-in method __new__ of type object at 0x7fe28cf70840>
__reduce__: <built-in method __reduce__ of ChildClass object at 0x7fe17704bbe0>
__reduce_ex__: <built-in method __reduce_ex__ of ChildClass object at 0x7fe17704bbe0>
__repr__: <method-wrapper '__repr__' of ChildClass object at 0x7fe17704bbe0>
__setattr__: <method-wrapper '__setattr__' of ChildClass object at 0x7fe17704bbe0>
__sizeof__: <built-in method __sizeof__ of ChildClass object at 0x7fe17704bbe0>
__str__: <bound method ChildClass.__str__ of <__main__.ChildClass object at 0x7fe17704bbe0>>
__subclasshook__: <built-in method __subclasshook__ of type object at 0x55f0d55ad280>
__weakref__: None
_private_instance_attr: a private instance attribute
_protected_attr: a protected attribute
class_attr: a class attribute
class_method: <bound method ChildClass.class_method of <class '__main__.ChildClass'>>
instance_attr: an instance attribute
instance_method: <bound method ChildClass.instance_method of <__main__.ChildClass object at 0x7fe17704bbe0>>
instance_not_in_constructor_attr: an instance attribute not in the constructor
parent_class_attr: a parent class attribute
parent_instance_attr: a parent instance attribute
parent_method: <bound method ParentClass.parent_method of <__main__.ChildClass object at 0x7fe17704bbe0>>
read_only_attr: You can read me, but you cannot change me.
static_method: <function ChildClass.static_method at 0x7fe177061c10>
trainer_all_members = get_members_of_function_or_method(Trainer, predicate=None)
loop_through_members(trainer_all_members)
__class__: <class 'type'>
__delattr__: <slot wrapper '__delattr__' of 'object' objects>
__dict__: {'__module__': 'transformers.trainer', '__doc__': '...', '__init__': <function Trainer.__init__ at 0x7fe17c8408b0>, ...} (output truncated: the class docstring and every method in this mapping are listed individually below)
__dir__: <method '__dir__' of 'object' objects>
__doc__:
Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
Args:
model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
<Tip>
[`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
models.
</Tip>
args ([`TrainingArguments`], *optional*):
The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
`output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
data_collator (`DataCollator`, *optional*):
The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
[`DataCollatorWithPadding`] otherwise.
train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a
`torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
sets the seed of the RNGs used.
eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
dataset prepending the dictionary key to the metric name.
tokenizer ([`PreTrainedTokenizerBase`], *optional*):
The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
interrupted training or reuse the fine-tuned model.
model_init (`Callable[[], PreTrainedModel]`, *optional*):
A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
from a new instance of the model as given by this function.
The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
be able to choose different architectures according to hyper parameters (such as layer count, sizes of
inner layers, dropout probabilities etc).
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
a dictionary string to metric values.
callbacks (List of [`TrainerCallback`], *optional*):
A list of callbacks to customize the training loop. Will add those to the list of default callbacks
detailed in [here](callback).
If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocess the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
Note that the labels (second parameter) will be `None` if the dataset does not have them.
Important attributes:
- **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
subclass.
- **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
- **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
data parallelism, this means some of the model layers are split on different GPUs).
- **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
to `False` if model parallel or deepspeed is used, or if the default
`TrainingArguments.place_model_on_device` is overridden to return `False` .
- **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
in `train`)
__eq__: <slot wrapper '__eq__' of 'object' objects>
__format__: <method '__format__' of 'object' objects>
__ge__: <slot wrapper '__ge__' of 'object' objects>
__getattribute__: <slot wrapper '__getattribute__' of 'object' objects>
__gt__: <slot wrapper '__gt__' of 'object' objects>
__hash__: <slot wrapper '__hash__' of 'object' objects>
__init__: <function Trainer.__init__ at 0x7fe17c8408b0>
__init_subclass__: <built-in method __init_subclass__ of type object at 0x55f0d6b62a00>
__le__: <slot wrapper '__le__' of 'object' objects>
__lt__: <slot wrapper '__lt__' of 'object' objects>
__module__: transformers.trainer
__ne__: <slot wrapper '__ne__' of 'object' objects>
__new__: <built-in method __new__ of type object at 0x7fe28cf70840>
__reduce__: <method '__reduce__' of 'object' objects>
__reduce_ex__: <method '__reduce_ex__' of 'object' objects>
__repr__: <slot wrapper '__repr__' of 'object' objects>
__setattr__: <slot wrapper '__setattr__' of 'object' objects>
__sizeof__: <method '__sizeof__' of 'object' objects>
__str__: <slot wrapper '__str__' of 'object' objects>
__subclasshook__: <built-in method __subclasshook__ of type object at 0x55f0d6b62a00>
__weakref__: <attribute '__weakref__' of 'Trainer' objects>
_activate_neftune: <function Trainer._activate_neftune at 0x7fe17705eca0>
_add_sm_patterns_to_gitignore: <function Trainer._add_sm_patterns_to_gitignore at 0x7fe177061430>
_deactivate_neftune: <function Trainer._deactivate_neftune at 0x7fe17705ed30>
_finish_current_push: <function Trainer._finish_current_push at 0x7fe1770611f0>
_gather_and_numpify: <function Trainer._gather_and_numpify at 0x7fe1770613a0>
_get_collator_with_removed_columns: <function Trainer._get_collator_with_removed_columns at 0x7fe17705f160>
_get_eval_sampler: <function Trainer._get_eval_sampler at 0x7fe17705f310>
_get_learning_rate: <function _get_learning_rate at 0x7fe1946c2a60>
_get_output_dir: <function Trainer._get_output_dir at 0x7fe17705fdc0>
_get_train_sampler: <function Trainer._get_train_sampler at 0x7fe17705f1f0>
_hp_search_setup: <function Trainer._hp_search_setup at 0x7fe17705f8b0>
_inner_training_loop: <function Trainer._inner_training_loop at 0x7fe17705fd30>
_issue_warnings_after_load: <function Trainer._issue_warnings_after_load at 0x7fe17705ff70>
_load_best_model: <function Trainer._load_best_model at 0x7fe17705fee0>
_load_from_checkpoint: <function Trainer._load_from_checkpoint at 0x7fe17705fe50>
_load_optimizer_and_scheduler: <function Trainer._load_optimizer_and_scheduler at 0x7fe17705d310>
_load_rng_state: <function Trainer._load_rng_state at 0x7fe17705d0d0>
_maybe_log_save_evaluate: <function Trainer._maybe_log_save_evaluate at 0x7fe17705d040>
_move_model_to_device: <function Trainer._move_model_to_device at 0x7fe17705ef70>
_nested_gather: <function Trainer._nested_gather at 0x7fe17705de50>
_prepare_input: <function Trainer._prepare_input at 0x7fe17705d4c0>
_prepare_inputs: <function Trainer._prepare_inputs at 0x7fe17705d550>
_push_from_checkpoint: <function Trainer._push_from_checkpoint at 0x7fe177061160>
_remove_unused_columns: <function Trainer._remove_unused_columns at 0x7fe17705f0d0>
_report_to_hp_search: <function Trainer._report_to_hp_search at 0x7fe17705f940>
_rotate_checkpoints: <function Trainer._rotate_checkpoints at 0x7fe17705dc10>
_save: <function Trainer._save at 0x7fe17705da60>
_save_checkpoint: <function Trainer._save_checkpoint at 0x7fe17705d160>
_save_optimizer_and_scheduler: <function Trainer._save_optimizer_and_scheduler at 0x7fe17705d280>
_save_rng_state: <function Trainer._save_rng_state at 0x7fe17705d1f0>
_save_tpu: <function Trainer._save_tpu at 0x7fe17705d9d0>
_set_signature_columns_if_needed: <function Trainer._set_signature_columns_if_needed at 0x7fe17705f040>
_sorted_checkpoints: <function Trainer._sorted_checkpoints at 0x7fe17705db80>
_tune_save_checkpoint: <function Trainer._tune_save_checkpoint at 0x7fe17705f9d0>
_wrap_model: <function Trainer._wrap_model at 0x7fe17705fc10>
add_callback: <function Trainer.add_callback at 0x7fe17705edc0>
autocast_smart_context_manager: <function Trainer.autocast_smart_context_manager at 0x7fe17705d670>
call_model_init: <function Trainer.call_model_init at 0x7fe17705fa60>
compute_loss: <function Trainer.compute_loss at 0x7fe17705d790>
compute_loss_context_manager: <function Trainer.compute_loss_context_manager at 0x7fe17705d5e0>
create_accelerator_and_postprocess: <function Trainer.create_accelerator_and_postprocess at 0x7fe1770614c0>
create_model_card: <function Trainer.create_model_card at 0x7fe1770610d0>
create_optimizer: <function Trainer.create_optimizer at 0x7fe17705f5e0>
create_optimizer_and_scheduler: <function Trainer.create_optimizer_and_scheduler at 0x7fe17705f4c0>
create_scheduler: <function Trainer.create_scheduler at 0x7fe17705f700>
evaluate: <function Trainer.evaluate at 0x7fe17705dca0>
evaluation_loop: <function Trainer.evaluation_loop at 0x7fe17705ddc0>
floating_point_ops: <function Trainer.floating_point_ops at 0x7fe17705df70>
get_decay_parameter_names: <function Trainer.get_decay_parameter_names at 0x7fe17705f550>
get_eval_dataloader: <function Trainer.get_eval_dataloader at 0x7fe17705f3a0>
get_optimizer_cls_and_kwargs: <function Trainer.get_optimizer_cls_and_kwargs at 0x7fe17705f670>
get_test_dataloader: <function Trainer.get_test_dataloader at 0x7fe17705f430>
get_train_dataloader: <function Trainer.get_train_dataloader at 0x7fe17705f280>
hyperparameter_search: <function Trainer.hyperparameter_search at 0x7fe17705d3a0>
init_hf_repo: <function Trainer.init_hf_repo at 0x7fe177061040>
ipex_optimize_model: <function Trainer.ipex_optimize_model at 0x7fe17705fb80>
is_local_process_zero: <function Trainer.is_local_process_zero at 0x7fe17705d820>
is_world_process_zero: <function Trainer.is_world_process_zero at 0x7fe17705d8b0>
log: <function Trainer.log at 0x7fe17705d430>
log_metrics: <function log_metrics at 0x7fe1946c93a0>
metrics_format: <function metrics_format at 0x7fe1946c9310>
num_examples: <function Trainer.num_examples at 0x7fe17705f790>
num_tokens: <function Trainer.num_tokens at 0x7fe17705f820>
pop_callback: <function Trainer.pop_callback at 0x7fe17705ee50>
predict: <function Trainer.predict at 0x7fe17705dd30>
prediction_loop: <function Trainer.prediction_loop at 0x7fe177061310>
prediction_step: <function Trainer.prediction_step at 0x7fe17705dee0>
propagate_args_to_deepspeed: <function Trainer.propagate_args_to_deepspeed at 0x7fe177061550>
push_to_hub: <function Trainer.push_to_hub at 0x7fe177061280>
remove_callback: <function Trainer.remove_callback at 0x7fe17705eee0>
save_metrics: <function save_metrics at 0x7fe1946c9430>
save_model: <function Trainer.save_model at 0x7fe17705d940>
save_state: <function save_state at 0x7fe1946c94c0>
store_flos: <function Trainer.store_flos at 0x7fe17705daf0>
torch_jit_model_eval: <function Trainer.torch_jit_model_eval at 0x7fe17705faf0>
train: <function Trainer.train at 0x7fe17705fca0>
training_step: <function Trainer.training_step at 0x7fe17705d700>
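In practice, when the class is as large as Trainer, what we usually want is just the constructor signature rather than the full member dump. A minimal sketch that goes straight to it with inspect.signature:
pprint(inspect.signature(Trainer.__init__))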
Retrieve All Methods of a Class#
There are a few ways to do it.
Using __dict__#
child_class_methods_using_dict = list(ChildClass.__dict__.keys())
pprint(sorted(child_class_methods_using_dict))
assert 'parent_method' not in child_class_methods_using_dict
assert 'read_only_attr' in child_class_methods_using_dict
assert 'class_method' in child_class_methods_using_dict
[ │ '_ChildClass__private_attr', │ '__doc__', │ '__init__', │ '__module__', │ '__str__', │ '_protected_attr', │ 'class_attr', │ 'class_method', │ 'instance_method', │ 'read_only_attr', │ 'static_method' ]
Notice that the parent class methods are not included!
pprint(instance_child.__class__.__dict__.keys() == ChildClass.__dict__.keys())
True
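Note that this compares the class-level __dict__ objects; an instance's own __dict__ is a different mapping that holds only the instance attributes set in (or after) the constructor. A quick sketch:
pprint(instance_child.__dict__)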
Using vars#
vars and __dict__ are equivalent, but people tend to prefer the former for efficiency reasons, as discussed in this post.
child_class_methods_using_vars = list(vars(ChildClass).keys())
pprint(sorted(child_class_methods_using_vars))
assert 'parent_method' not in child_class_methods_using_vars
assert 'read_only_attr' in child_class_methods_using_vars
assert 'class_method' in child_class_methods_using_vars
assert set(child_class_methods_using_dict) == set(child_class_methods_using_vars)
[ │ '_ChildClass__private_attr', │ '__doc__', │ '__init__', │ '__module__', │ '__str__', │ '_protected_attr', │ 'class_attr', │ 'class_method', │ 'instance_method', │ 'read_only_attr', │ 'static_method' ]
Using dir#
To include the methods inherited from the base/parent class, we can use dir instead.
child_class_methods_using_dir = list(dir(ChildClass))
pprint(sorted(child_class_methods_using_dir))
assert 'parent_method' in child_class_methods_using_dir
assert 'read_only_attr' in child_class_methods_using_dir
assert 'class_method' in child_class_methods_using_dir
[ │ '_ChildClass__private_attr', │ '__class__', │ '__delattr__', │ '__dict__', │ '__dir__', │ '__doc__', │ '__eq__', │ '__format__', │ '__ge__', │ '__getattribute__', │ '__gt__', │ '__hash__', │ '__init__', │ '__init_subclass__', │ '__le__', │ '__lt__', │ '__module__', │ '__ne__', │ '__new__', │ '__reduce__', │ '__reduce_ex__', │ '__repr__', │ '__setattr__', │ '__sizeof__', │ '__str__', │ '__subclasshook__', │ '__weakref__', │ '_protected_attr', │ 'class_attr', │ 'class_method', │ 'instance_method', │ 'parent_class_attr', │ 'parent_method', │ 'read_only_attr', │ 'static_method' ]
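dir only gives us names, however; to recover the member objects themselves we would pair it with getattr, which is essentially what inspect.getmembers (covered next) does for us. A small sketch of that pairing:
public_members = {name: getattr(ChildClass, name) for name in dir(ChildClass) if not name.startswith('_')}
pprint(public_members)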
Using inspect.getmembers#
We use inspect.getmembers to get all members of a class and then filter them with the predicate inspect.isroutine, a broader predicate than inspect.isfunction or inspect.ismethod alone, since it also matches built-in functions and method descriptors. We attach the source code of inspect.isroutine here for reference.
def isroutine(object):
"""Return true if the object is any kind of function or method."""
return (isbuiltin(object)
or isfunction(object)
or ismethod(object)
or ismethoddescriptor(object))
predicate = inspect.isroutine
child_class_methods_using_getmembers = list(get_members_of_function_or_method(ChildClass, predicate=predicate))
pprint(sorted(child_class_methods_using_getmembers))
[ │ ('__delattr__', <slot wrapper '__delattr__' of 'object' objects>), │ ('__dir__', <method '__dir__' of 'object' objects>), │ ('__eq__', <slot wrapper '__eq__' of 'object' objects>), │ ('__format__', <method '__format__' of 'object' objects>), │ ('__ge__', <slot wrapper '__ge__' of 'object' objects>), │ ('__getattribute__', <slot wrapper '__getattribute__' of 'object' objects>), │ ('__gt__', <slot wrapper '__gt__' of 'object' objects>), │ ('__hash__', <slot wrapper '__hash__' of 'object' objects>), │ ('__init__', <function ChildClass.__init__ at 0x7fe1770619d0>), │ ('__init_subclass__', <built-in method __init_subclass__ of type object at 0x55f0d55ad280>), │ ('__le__', <slot wrapper '__le__' of 'object' objects>), │ ('__lt__', <slot wrapper '__lt__' of 'object' objects>), │ ('__ne__', <slot wrapper '__ne__' of 'object' objects>), │ ('__new__', <built-in method __new__ of type object at 0x7fe28cf70840>), │ ('__reduce__', <method '__reduce__' of 'object' objects>), │ ('__reduce_ex__', <method '__reduce_ex__' of 'object' objects>), │ ('__repr__', <slot wrapper '__repr__' of 'object' objects>), │ ('__setattr__', <slot wrapper '__setattr__' of 'object' objects>), │ ('__sizeof__', <method '__sizeof__' of 'object' objects>), │ ('__str__', <function ChildClass.__str__ at 0x7fe177061ca0>), │ ('__subclasshook__', <built-in method __subclasshook__ of type object at 0x55f0d55ad280>), │ ('class_method', <bound method ChildClass.class_method of <class '__main__.ChildClass'>>), │ ('instance_method', <function ChildClass.instance_method at 0x7fe177061af0>), │ ('parent_method', <function ParentClass.parent_method at 0x7fe177061940>), │ ('static_method', <function ChildClass.static_method at 0x7fe177061c10>) ]
Of course, retrieving all methods at once is mainly a convenience: once we have them, we can iterate over them and inspect each method's signature, as sketched below.
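For example, a minimal loop of our own that walks the routines of ChildClass and prints each signature, skipping the ones that do not expose one:
for name, method in inspect.getmembers(ChildClass, predicate=inspect.isroutine):
    if name.startswith('__'):
        continue  # skip dunders for brevity
    try:
        print(f'{name}{inspect.signature(method)}')
    except (ValueError, TypeError):
        print(f'{name}: signature not available')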
Method Resolution Order#
The above examples do not take into account more complicated cases, such as a class that inherits from multiple classes; if you just print out its methods, it is hard to tell which method comes from which class. You can recover this with more filtering (a rough heuristic is sketched below), but a full treatment is beyond the scope of this notebook.
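As a rough heuristic (our own sketch, not exhaustive), a routine's __qualname__ usually records the class that defines it:
for name, method in inspect.getmembers(ChildClass, predicate=inspect.isroutine):
    qualname = getattr(method, '__qualname__', name)
    owner = qualname.split('.')[0] if '.' in qualname else '<unknown>'
    print(f'{name:25s} defined in {owner}')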
predicate = inspect.isroutine
GPT2LMHeadModel_methods_using_getmembers = list(get_members_of_function_or_method(GPT2LMHeadModel, predicate=predicate))
pprint(sorted(GPT2LMHeadModel_methods_using_getmembers))
[ │ ('__call__', <function Module._wrapped_call_impl at 0x7fe25c526160>), │ ('__delattr__', <function Module.__delattr__ at 0x7fe25c5264c0>), │ ('__dir__', <function Module.__dir__ at 0x7fe25c52a4c0>), │ ('__eq__', <slot wrapper '__eq__' of 'object' objects>), │ ('__format__', <method '__format__' of 'object' objects>), │ ('__ge__', <slot wrapper '__ge__' of 'object' objects>), │ ('__getattr__', <function Module.__getattr__ at 0x7fe25c5263a0>), │ ('__getattribute__', <slot wrapper '__getattribute__' of 'object' objects>), │ ('__getstate__', <function Module.__getstate__ at 0x7fe25c526280>), │ ('__gt__', <slot wrapper '__gt__' of 'object' objects>), │ ('__hash__', <slot wrapper '__hash__' of 'object' objects>), │ ('__init__', <function GPT2LMHeadModel.__init__ at 0x7fe194af7430>), │ ('__init_subclass__', <built-in method __init_subclass__ of type object at 0x55f0d559d500>), │ ('__le__', <slot wrapper '__le__' of 'object' objects>), │ ('__lt__', <slot wrapper '__lt__' of 'object' objects>), │ ('__ne__', <slot wrapper '__ne__' of 'object' objects>), │ ('__new__', <built-in method __new__ of type object at 0x7fe28cf70840>), │ ('__reduce__', <method '__reduce__' of 'object' objects>), │ ('__reduce_ex__', <method '__reduce_ex__' of 'object' objects>), │ ('__repr__', <function Module.__repr__ at 0x7fe25c52a430>), │ ('__setattr__', <function Module.__setattr__ at 0x7fe25c526430>), │ ('__setstate__', <function Module.__setstate__ at 0x7fe25c526310>), │ ('__sizeof__', <method '__sizeof__' of 'object' objects>), │ ('__str__', <slot wrapper '__str__' of 'object' objects>), │ ('__subclasshook__', <built-in method __subclasshook__ of type object at 0x55f0d559d500>), │ ('_apply', <function Module._apply at 0x7fe25c522430>), │ ( │ │ '_autoset_attn_implementation', │ │ <bound method PreTrainedModel._autoset_attn_implementation of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ( │ │ '_backward_compatibility_gradient_checkpointing', │ │ <function PreTrainedModel._backward_compatibility_gradient_checkpointing at 0x7fe194b8fdc0> │ ), │ ('_call_impl', <function Module._call_impl at 0x7fe25c5261f0>), │ ( │ │ '_check_and_enable_flash_attn_2', │ │ <bound method PreTrainedModel._check_and_enable_flash_attn_2 of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ( │ │ '_check_and_enable_sdpa', │ │ <bound method PreTrainedModel._check_and_enable_sdpa of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ('_convert_head_mask_to_5d', <function ModuleUtilsMixin._convert_head_mask_to_5d at 0x7fe194b8f940>), │ ( │ │ '_copy_lm_head_original_to_resized', │ │ <function PreTrainedModel._copy_lm_head_original_to_resized at 0x7fe194b92b80> │ ), │ ('_create_repo', <function PushToHubMixin._create_repo at 0x7fe196dccf70>), │ ('_dispatch_accelerate_model', <function PeftAdapterMixin._dispatch_accelerate_model at 0x7fe194be04c0>), │ ('_expand_inputs_for_generation', <function GenerationMixin._expand_inputs_for_generation at 0x7fe194bd0a60>), │ ( │ │ '_extract_past_from_model_output', │ │ <function GenerationMixin._extract_past_from_model_output at 0x7fe194bd0af0> │ ), │ ( │ │ '_from_config', │ │ <bound method PreTrainedModel._from_config of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ('_get_backward_hooks', <function Module._get_backward_hooks at 0x7fe25c522dc0>), │ ('_get_backward_pre_hooks', <function Module._get_backward_pre_hooks at 0x7fe25c522e50>), │ ('_get_candidate_generator', <function GenerationMixin._get_candidate_generator 
at 0x7fe194bd0ca0>), │ ('_get_decoder_start_token_id', <function GenerationMixin._get_decoder_start_token_id at 0x7fe194bd09d0>), │ ('_get_files_timestamps', <function PushToHubMixin._get_files_timestamps at 0x7fe196df7040>), │ ('_get_generation_mode', <function GenerationMixin._get_generation_mode at 0x7fe194bd0dc0>), │ ('_get_logits_processor', <function GenerationMixin._get_logits_processor at 0x7fe194bd0e50>), │ ('_get_logits_warper', <function GenerationMixin._get_logits_warper at 0x7fe194bd0d30>), │ ('_get_name', <function Module._get_name at 0x7fe25c52a310>), │ ('_get_no_split_modules', <function PreTrainedModel._get_no_split_modules at 0x7fe194b928b0>), │ ('_get_resized_embeddings', <function PreTrainedModel._get_resized_embeddings at 0x7fe194b92a60>), │ ('_get_resized_lm_head', <function PreTrainedModel._get_resized_lm_head at 0x7fe194b92af0>), │ ('_get_stopping_criteria', <function GenerationMixin._get_stopping_criteria at 0x7fe194bd0ee0>), │ ( │ │ '_hook_rss_memory_post_forward', │ │ <function ModuleUtilsMixin._hook_rss_memory_post_forward at 0x7fe194b8f430> │ ), │ ('_hook_rss_memory_pre_forward', <function ModuleUtilsMixin._hook_rss_memory_pre_forward at 0x7fe194b8f3a0>), │ ('_init_weights', <function GPT2PreTrainedModel._init_weights at 0x7fe194af3ca0>), │ ('_initialize_weights', <function PreTrainedModel._initialize_weights at 0x7fe194b92670>), │ ('_load_from_state_dict', <function Module._load_from_state_dict at 0x7fe25c5269d0>), │ ( │ │ '_load_pretrained_model', │ │ <bound method PreTrainedModel._load_pretrained_model of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ( │ │ '_load_pretrained_model_low_mem', │ │ <function PreTrainedModel._load_pretrained_model_low_mem at 0x7fe194b94670> │ ), │ ( │ │ '_maybe_initialize_input_ids_for_generation', │ │ <function GenerationMixin._maybe_initialize_input_ids_for_generation at 0x7fe194bd0790> │ ), │ ( │ │ '_maybe_warn_non_full_backward_hook', │ │ <function Module._maybe_warn_non_full_backward_hook at 0x7fe25c522ee0> │ ), │ ( │ │ '_merge_criteria_processor_list', │ │ <function GenerationMixin._merge_criteria_processor_list at 0x7fe194bd0f70> │ ), │ ('_named_members', <function Module._named_members at 0x7fe25c526af0>), │ ( │ │ '_prepare_attention_mask_for_generation', │ │ <function GenerationMixin._prepare_attention_mask_for_generation at 0x7fe194bd0820> │ ), │ ( │ │ '_prepare_decoder_input_ids_for_generation', │ │ <function GenerationMixin._prepare_decoder_input_ids_for_generation at 0x7fe194bd0940> │ ), │ ( │ │ '_prepare_encoder_decoder_kwargs_for_generation', │ │ <function GenerationMixin._prepare_encoder_decoder_kwargs_for_generation at 0x7fe194bd08b0> │ ), │ ('_prepare_model_inputs', <function GenerationMixin._prepare_model_inputs at 0x7fe194bd0700>), │ ( │ │ '_register_load_state_dict_pre_hook', │ │ <function Module._register_load_state_dict_pre_hook at 0x7fe25c526820> │ ), │ ('_register_state_dict_hook', <function Module._register_state_dict_hook at 0x7fe25c526550>), │ ('_reorder_cache', <function GPT2LMHeadModel._reorder_cache at 0x7fe194af7820>), │ ('_replicate_for_data_parallel', <function Module._replicate_for_data_parallel at 0x7fe25c52a550>), │ ('_resize_token_embeddings', <function PreTrainedModel._resize_token_embeddings at 0x7fe194b929d0>), │ ('_save_to_state_dict', <function Module._save_to_state_dict at 0x7fe25c526700>), │ ( │ │ '_set_default_torch_dtype', │ │ <bound method PreTrainedModel._set_default_torch_dtype of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ 
('_set_gradient_checkpointing', <function PreTrainedModel._set_gradient_checkpointing at 0x7fe194b92ee0>), │ ('_slow_forward', <function Module._slow_forward at 0x7fe25c5260d0>), │ ('_temporary_reorder_cache', <function GenerationMixin._temporary_reorder_cache at 0x7fe194bdb5e0>), │ ('_tie_encoder_decoder_weights', <function PreTrainedModel._tie_encoder_decoder_weights at 0x7fe194b92790>), │ ('_tie_or_clone_weights', <function PreTrainedModel._tie_or_clone_weights at 0x7fe194b92820>), │ ( │ │ '_update_model_kwargs_for_generation', │ │ <function GenerationMixin._update_model_kwargs_for_generation at 0x7fe194bd0b80> │ ), │ ('_upload_modified_files', <function PushToHubMixin._upload_modified_files at 0x7fe196df70d0>), │ ('_validate_generated_length', <function GenerationMixin._validate_generated_length at 0x7fe194bdb1f0>), │ ('_validate_model_class', <function GenerationMixin._validate_model_class at 0x7fe194bdb0d0>), │ ('_validate_model_kwargs', <function GenerationMixin._validate_model_kwargs at 0x7fe194bdb160>), │ ('_wrapped_call_impl', <function Module._wrapped_call_impl at 0x7fe25c526160>), │ ('active_adapter', <function PeftAdapterMixin.active_adapter at 0x7fe194be03a0>), │ ('active_adapters', <function PeftAdapterMixin.active_adapters at 0x7fe194be0310>), │ ('add_adapter', <function PeftAdapterMixin.add_adapter at 0x7fe194be00d0>), │ ('add_memory_hooks', <function ModuleUtilsMixin.add_memory_hooks at 0x7fe194b8f4c0>), │ ('add_model_tags', <function PreTrainedModel.add_model_tags at 0x7fe194b8fe50>), │ ('add_module', <function Module.add_module at 0x7fe25c520f70>), │ ('apply', <function Module.apply at 0x7fe25c5224c0>), │ ('assisted_decoding', <function GenerationMixin.assisted_decoding at 0x7fe194bdb8b0>), │ ('beam_sample', <function GenerationMixin.beam_sample at 0x7fe194bdb700>), │ ('beam_search', <function GenerationMixin.beam_search at 0x7fe194bdb670>), │ ('bfloat16', <function Module.bfloat16 at 0x7fe25c522a60>), │ ('buffers', <function Module.buffers at 0x7fe25c526ca0>), │ ( │ │ 'can_generate', │ │ <bound method PreTrainedModel.can_generate of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ('children', <function Module.children at 0x7fe25c526dc0>), │ ('compile', <function Module.compile at 0x7fe25c52a5e0>), │ ('compute_transition_scores', <function GenerationMixin.compute_transition_scores at 0x7fe194bdb040>), │ ('constrained_beam_search', <function GenerationMixin.constrained_beam_search at 0x7fe194bdb820>), │ ('contrastive_search', <function GenerationMixin.contrastive_search at 0x7fe194bdb430>), │ ('cpu', <function Module.cpu at 0x7fe25c522790>), │ ( │ │ 'create_extended_attention_mask_for_decoder', │ │ <function ModuleUtilsMixin.create_extended_attention_mask_for_decoder at 0x7fe194b8f790> │ ), │ ('cuda', <function Module.cuda at 0x7fe194b94280>), │ ('deparallelize', <function GPT2LMHeadModel.deparallelize at 0x7fe194af7670>), │ ('disable_adapters', <function PeftAdapterMixin.disable_adapters at 0x7fe194be01f0>), │ ('disable_input_require_grads', <function PreTrainedModel.disable_input_require_grads at 0x7fe194b923a0>), │ ('double', <function Module.double at 0x7fe25c522940>), │ ('enable_adapters', <function PeftAdapterMixin.enable_adapters at 0x7fe194be0280>), │ ('enable_input_require_grads', <function PreTrainedModel.enable_input_require_grads at 0x7fe194b92310>), │ ('estimate_tokens', <function ModuleUtilsMixin.estimate_tokens at 0x7fe194b8fa60>), │ ('eval', <function Module.eval at 0x7fe25c52a0d0>), │ ('extra_repr', <function Module.extra_repr at 
0x7fe25c52a3a0>), │ ('float', <function PreTrainedModel.float at 0x7fe194b94430>), │ ('floating_point_ops', <function ModuleUtilsMixin.floating_point_ops at 0x7fe194b8faf0>), │ ('forward', <function GPT2LMHeadModel.forward at 0x7fe194af7940>), │ ( │ │ 'from_pretrained', │ │ <bound method PreTrainedModel.from_pretrained of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ('generate', <function GenerationMixin.generate at 0x7fe194bdb310>), │ ('get_adapter_state_dict', <function PeftAdapterMixin.get_adapter_state_dict at 0x7fe194be0430>), │ ('get_buffer', <function Module.get_buffer at 0x7fe25c522280>), │ ('get_extended_attention_mask', <function ModuleUtilsMixin.get_extended_attention_mask at 0x7fe194b8f820>), │ ('get_extra_state', <function Module.get_extra_state at 0x7fe25c522310>), │ ('get_head_mask', <function ModuleUtilsMixin.get_head_mask at 0x7fe194b8f8b0>), │ ('get_input_embeddings', <function PreTrainedModel.get_input_embeddings at 0x7fe194b92430>), │ ('get_memory_footprint', <function PreTrainedModel.get_memory_footprint at 0x7fe194b941f0>), │ ('get_output_embeddings', <function GPT2LMHeadModel.get_output_embeddings at 0x7fe194af74c0>), │ ('get_parameter', <function Module.get_parameter at 0x7fe25c5221f0>), │ ('get_position_embeddings', <function PreTrainedModel.get_position_embeddings at 0x7fe194b92ca0>), │ ('get_submodule', <function Module.get_submodule at 0x7fe25c5220d0>), │ ( │ │ 'gradient_checkpointing_disable', │ │ <function PreTrainedModel.gradient_checkpointing_disable at 0x7fe194b92f70> │ ), │ ('gradient_checkpointing_enable', <function PreTrainedModel.gradient_checkpointing_enable at 0x7fe194b92e50>), │ ('greedy_search', <function GenerationMixin.greedy_search at 0x7fe194bdb4c0>), │ ('group_beam_search', <function GenerationMixin.group_beam_search at 0x7fe194bdb790>), │ ('half', <function PreTrainedModel.half at 0x7fe194b943a0>), │ ('init_weights', <function PreTrainedModel.init_weights at 0x7fe194b92d30>), │ ('invert_attention_mask', <function ModuleUtilsMixin.invert_attention_mask at 0x7fe194b8f700>), │ ('ipu', <function Module.ipu at 0x7fe25c5225e0>), │ ('load_adapter', <function PeftAdapterMixin.load_adapter at 0x7fe194be0040>), │ ('load_state_dict', <function Module.load_state_dict at 0x7fe25c526a60>), │ ('load_tf_weights', <function load_tf_weights_in_gpt2 at 0x7fe1965791f0>), │ ('modules', <function Module.modules at 0x7fe25c526ee0>), │ ('mtia', <function Module.mtia at 0x7fe25c522700>), │ ('named_buffers', <function Module.named_buffers at 0x7fe25c526d30>), │ ('named_children', <function Module.named_children at 0x7fe25c526e50>), │ ('named_modules', <function Module.named_modules at 0x7fe25c526f70>), │ ('named_parameters', <function Module.named_parameters at 0x7fe25c526c10>), │ ('num_parameters', <function ModuleUtilsMixin.num_parameters at 0x7fe194b8f9d0>), │ ('parallelize', <function GPT2LMHeadModel.parallelize at 0x7fe194af75e0>), │ ('parameters', <function Module.parameters at 0x7fe25c526b80>), │ ('post_init', <function PreTrainedModel.post_init at 0x7fe194b8fd30>), │ ('prepare_inputs_for_generation', <function GPT2LMHeadModel.prepare_inputs_for_generation at 0x7fe194af7790>), │ ('prune_heads', <function PreTrainedModel.prune_heads at 0x7fe194b92dc0>), │ ('push_to_hub', <function PushToHubMixin.push_to_hub at 0x7fe194b8f310>), │ ('register_backward_hook', <function Module.register_backward_hook at 0x7fe25c522ca0>), │ ('register_buffer', <function Module.register_buffer at 0x7fe25c520e50>), │ ( │ │ 'register_for_auto_class', │ │ <bound 
method PreTrainedModel.register_for_auto_class of <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>> │ ), │ ('register_forward_hook', <function Module.register_forward_hook at 0x7fe25c526040>), │ ('register_forward_pre_hook', <function Module.register_forward_pre_hook at 0x7fe25c522f70>), │ ('register_full_backward_hook', <function Module.register_full_backward_hook at 0x7fe25c522d30>), │ ('register_full_backward_pre_hook', <function Module.register_full_backward_pre_hook at 0x7fe25c522c10>), │ ( │ │ 'register_load_state_dict_post_hook', │ │ <function Module.register_load_state_dict_post_hook at 0x7fe25c526940> │ ), │ ('register_load_state_dict_pre_hook', <function Module.register_load_state_dict_pre_hook at 0x7fe25c5268b0>), │ ('register_module', <function Module.register_module at 0x7fe25c522040>), │ ('register_parameter', <function Module.register_parameter at 0x7fe25c520ee0>), │ ('register_state_dict_post_hook', <function Module.register_state_dict_post_hook at 0x7fe25c5265e0>), │ ('register_state_dict_pre_hook', <function Module.register_state_dict_pre_hook at 0x7fe25c526670>), │ ('requires_grad_', <function Module.requires_grad_ at 0x7fe25c52a160>), │ ('reset_memory_hooks_state', <function ModuleUtilsMixin.reset_memory_hooks_state at 0x7fe194b8f550>), │ ('resize_position_embeddings', <function PreTrainedModel.resize_position_embeddings at 0x7fe194b92c10>), │ ('resize_token_embeddings', <function PreTrainedModel.resize_token_embeddings at 0x7fe194b92940>), │ ('retrieve_modules_from_names', <function PreTrainedModel.retrieve_modules_from_names at 0x7fe194b945e0>), │ ('reverse_bettertransformer', <function PreTrainedModel.reverse_bettertransformer at 0x7fe194b94820>), │ ('sample', <function GenerationMixin.sample at 0x7fe194bdb550>), │ ('save_pretrained', <function PreTrainedModel.save_pretrained at 0x7fe194b940d0>), │ ('set_adapter', <function PeftAdapterMixin.set_adapter at 0x7fe194be0160>), │ ('set_extra_state', <function Module.set_extra_state at 0x7fe25c5223a0>), │ ('set_input_embeddings', <function PreTrainedModel.set_input_embeddings at 0x7fe194b924c0>), │ ('set_output_embeddings', <function GPT2LMHeadModel.set_output_embeddings at 0x7fe194af7700>), │ ('set_submodule', <function Module.set_submodule at 0x7fe25c522160>), │ ('share_memory', <function Module.share_memory at 0x7fe25c52a280>), │ ('state_dict', <function Module.state_dict at 0x7fe25c526790>), │ ('tie_weights', <function PreTrainedModel.tie_weights at 0x7fe194b92700>), │ ('to', <function Module.to at 0x7fe194b94310>), │ ('to_bettertransformer', <function PreTrainedModel.to_bettertransformer at 0x7fe194b94790>), │ ('to_empty', <function Module.to_empty at 0x7fe25c522af0>), │ ('train', <function Module.train at 0x7fe25c52a040>), │ ('type', <function Module.type at 0x7fe25c522820>), │ ( │ │ 'warn_if_padding_and_no_attention_mask', │ │ <function PreTrainedModel.warn_if_padding_and_no_attention_mask at 0x7fe194b948b0> │ ), │ ('xpu', <function Module.xpu at 0x7fe25c522670>), │ ('zero_grad', <function Module.zero_grad at 0x7fe25c52a1f0>) ]
You can get the method resolution order (MRO) of a class via cls.__mro__ or, equivalently, via inspect.getmro.
inspect.getmro(GPT2LMHeadModel)
(transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel,
transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel,
transformers.modeling_utils.PreTrainedModel,
torch.nn.modules.module.Module,
transformers.modeling_utils.ModuleUtilsMixin,
transformers.generation.utils.GenerationMixin,
transformers.utils.hub.PushToHubMixin,
transformers.integrations.peft.PeftAdapterMixin,
object)
Pseudocode to collect the constructor parameters of a class across its MRO is as follows:
def get_all_args(cls: Type[object]) -> Dict[str, inspect.Parameter]:
mro = inspect.getmro(cls)
all_args = {}
for base_class in mro[::-1]: # reverse to start from topmost class
if base_class is object: # skip the base 'object' class
continue
sig = inspect.signature(base_class.__init__)
all_args.update(sig.parameters)
return all_args
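As a quick illustration (our own call, assuming the pseudocode above is run as-is), on the toy classes it aggregates parameters from both constructors; note that self also shows up because we did not filter it out:
all_args = get_all_args(ChildClass)
pprint(list(all_args.keys()))  # e.g. ['self', 'parent_instance_attr', 'instance_attr']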
Get Class and Instance Attributes#
pprint(list(class_child.__dict__.keys())) # class attributes
pprint(list(instance_child.__dict__.keys())) # instance attributes
[ │ '__module__', │ '__doc__', │ 'class_attr', │ '_protected_attr', │ '_ChildClass__private_attr', │ '__init__', │ 'read_only_attr', │ 'instance_method', │ 'class_method', │ 'static_method', │ '__str__' ]
['parent_instance_attr', 'instance_attr', 'instance_not_in_constructor_attr', '_private_instance_attr']
union_class_and_instance_attributes = list(set(class_child.__dict__.keys()).union(set(instance_child.__dict__.keys())))
pprint(union_class_and_instance_attributes)
[ │ '__str__', │ '__init__', │ '_ChildClass__private_attr', │ 'read_only_attr', │ '_protected_attr', │ '_private_instance_attr', │ 'class_method', │ 'class_attr', │ '__module__', │ 'instance_not_in_constructor_attr', │ 'static_method', │ '__doc__', │ 'instance_method', │ 'parent_instance_attr', │ 'instance_attr' ]
Get Signature and Type Annotations of a Function#
func_sig: Signature = inspect.signature(func)
pprint(func_sig.parameters)
pprint(func_sig.return_annotation)
mappingproxy({ │ 'a': <Parameter "a: int">, │ 'b': <Parameter "b: str">, │ 'c': <Parameter "c: List[int]">, │ 'd': <Parameter "d: Tuple[str, str]">, │ 'e': <Parameter "e: Union[int, str]">, │ 'kwargs': <Parameter "**kwargs: Any"> })
<class 'str'>
Here are the 4 key properties of the Parameter object of the Signature object.
@property
def name(self):
return self._name
@property
def default(self):
return self._default
@property
def annotation(self):
return self._annotation
@property
def kind(self):
return self._kind
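Putting the four properties together on the hypothetical func from earlier (a small sketch of our own):
for name, param in inspect.signature(func).parameters.items():
    print(f'{name}: kind={param.kind}, default={param.default}, annotation={param.annotation}')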
We will also use get_type_hints to get the type hints of a function instead of reading each Parameter's annotation (or the function's raw __annotations__). The reason can be found in the docstring of get_type_hints:
def get_type_hints(obj, globalns=None, localns=None, include_extras=False):
"""Return type hints for an object.
This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
...
"""
def no_type_hints(a, b, c, d, e, **kwargs):
return a, b, c, d, e, kwargs
get_type_hints(no_type_hints), inspect.signature(no_type_hints)
({}, <Signature (a, b, c, d, e, **kwargs)>)
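A minimal illustration (our own example) of why get_type_hints is preferred: it resolves string forward references that the raw __annotations__ leave untouched, and, on the Python version whose docstring is quoted above, it also wraps None-defaulted parameters in Optional (newer Python versions dropped that behavior).
def forward_ref_example(x: 'int', y: int = None) -> 'str':
    return str(x + (y or 0))

pprint(forward_ref_example.__annotations__)  # the string annotations stay as strings
pprint(get_type_hints(forward_ref_example))  # forward references resolved; y may appear as Optional[int] on older Pythons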
How do we know whether a parameter is optional, i.e. has a default value? We can compare its default against the inspect.Parameter.empty sentinel.
for name, value in inspect.signature(func).parameters.items():
print(value.default)
print(value.default is inspect.Parameter.empty)
<class 'inspect._empty'>
True
<class 'inspect._empty'>
True
<class 'inspect._empty'>
True
<class 'inspect._empty'>
True
<class 'inspect._empty'>
True
<class 'inspect._empty'>
True
We will also use __mro__ to get the method resolution order of a class, because __bases__ only returns the immediate parent classes.
ChildClass.__bases__, GPT2LMHeadModel.__bases__[0].__bases__[0].__bases__
((__main__.ParentClass,),
(torch.nn.modules.module.Module,
transformers.modeling_utils.ModuleUtilsMixin,
transformers.generation.utils.GenerationMixin,
transformers.utils.hub.PushToHubMixin,
transformers.integrations.peft.PeftAdapterMixin))
list(reversed(inspect.getmro(GPT2LMHeadModel)))
[object,
transformers.integrations.peft.PeftAdapterMixin,
transformers.utils.hub.PushToHubMixin,
transformers.generation.utils.GenerationMixin,
transformers.modeling_utils.ModuleUtilsMixin,
torch.nn.modules.module.Module,
transformers.modeling_utils.PreTrainedModel,
transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel,
transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel]
def get_base_classes(cls: Type[Any], include_self: bool = False) -> Set[Type[Any]]:
"""
Get the base classes of a class and all its base classes.
"""
return set(cls.__mro__[0:-1] if include_self else cls.__mro__[1:-1])
pprint(get_base_classes(GPT2LMHeadModel, include_self=True))
{ │ <class 'transformers.modeling_utils.ModuleUtilsMixin'>, │ <class 'transformers.generation.utils.GenerationMixin'>, │ <class 'transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel'>, │ <class 'transformers.utils.hub.PushToHubMixin'>, │ <class 'torch.nn.modules.module.Module'>, │ <class 'transformers.integrations.peft.PeftAdapterMixin'>, │ <class 'transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel'>, │ <class 'transformers.modeling_utils.PreTrainedModel'> }
def get_default(param: Parameter) -> Any:
"""Return the parameter's default value or None if not specified."""
return param.default if param.default is not param.empty else None
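A quick check of get_default (illustrative only) on a parameter that declares no default:
param = inspect.signature(ChildClass.__init__).parameters['instance_attr']
assert get_default(param) is None  # no default declared, so get_default falls back to None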
def get_field_annotations(func_or_method: Callable[..., Any]) -> Tuple[List[Tuple[str, Any, Any]], Dict[str, Any]]:
if not inspect.isroutine(func_or_method):
raise ValueError("Expected a function or method")
required_fields = []
optional_fields = []
annotations = {}
try:
sig: Signature = inspect.signature(func_or_method)
type_hints: Dict[str, Any] = get_type_hints(func_or_method)
except ValueError:
raise ValueError("Object does not support signature or type hints extraction.") from None
for name, param in sig.parameters.items():
if name == "self":
continue
type_hint = type_hints.get(name, Any)
annotations[name] = type_hint
if param.default is param.empty:
required_fields.append((name, type_hint, Ellipsis))
else:
default_value = param.default
optional_fields.append((name, type_hint, default_value))
fields = required_fields + optional_fields
return fields, annotations
# TODO: Tuple[str, Any, Any] should be Tuple[str, Any, ellipsis]
def get_constructor_field_annotations(
cls: Type[Any], include_bases: bool = True
) -> Tuple[List[Tuple[str, Any, Any]], Dict[str, Any]]:
fields = []
annotations = {}
classes_to_inspect = [cls] + list(get_base_classes(cls, include_self=False)) if include_bases else [cls]
for c in reversed(classes_to_inspect): # Reverse to respect MRO
if hasattr(c, "__init__"):
class_fields, class_annotations = get_field_annotations(c.__init__)
# Update fields and annotations with those from the current class,
# avoiding duplicates.
for field in class_fields:
if field[0] not in annotations:
fields.append(field) # noqa: PERF401
annotations.update(class_annotations)
return fields, annotations
fields, annotations = get_constructor_field_annotations(TrainingArguments, include_bases=False)
for field in fields:
assert len(field) == 3
print(f"{field[0]}, {field[1]}, {field[2]}")
assert get_field_annotations(TrainingArguments.__init__) == (fields, annotations)
output_dir, <class 'str'>, Ellipsis
overwrite_output_dir, <class 'bool'>, False
do_train, <class 'bool'>, False
do_eval, <class 'bool'>, False
do_predict, <class 'bool'>, False
evaluation_strategy, typing.Union[transformers.trainer_utils.IntervalStrategy, str], no
prediction_loss_only, <class 'bool'>, False
per_device_train_batch_size, <class 'int'>, 8
per_device_eval_batch_size, <class 'int'>, 8
per_gpu_train_batch_size, typing.Optional[int], None
per_gpu_eval_batch_size, typing.Optional[int], None
gradient_accumulation_steps, <class 'int'>, 1
eval_accumulation_steps, typing.Optional[int], None
eval_delay, typing.Optional[float], 0
learning_rate, <class 'float'>, 5e-05
weight_decay, <class 'float'>, 0.0
adam_beta1, <class 'float'>, 0.9
adam_beta2, <class 'float'>, 0.999
adam_epsilon, <class 'float'>, 1e-08
max_grad_norm, <class 'float'>, 1.0
num_train_epochs, <class 'float'>, 3.0
max_steps, <class 'int'>, -1
lr_scheduler_type, typing.Union[transformers.trainer_utils.SchedulerType, str], linear
lr_scheduler_kwargs, typing.Optional[typing.Dict], <factory>
warmup_ratio, <class 'float'>, 0.0
warmup_steps, <class 'int'>, 0
log_level, typing.Optional[str], passive
log_level_replica, typing.Optional[str], warning
log_on_each_node, <class 'bool'>, True
logging_dir, typing.Optional[str], None
logging_strategy, typing.Union[transformers.trainer_utils.IntervalStrategy, str], steps
logging_first_step, <class 'bool'>, False
logging_steps, <class 'float'>, 500
logging_nan_inf_filter, <class 'bool'>, True
save_strategy, typing.Union[transformers.trainer_utils.IntervalStrategy, str], steps
save_steps, <class 'float'>, 500
save_total_limit, typing.Optional[int], None
save_safetensors, typing.Optional[bool], True
save_on_each_node, <class 'bool'>, False
save_only_model, <class 'bool'>, False
no_cuda, <class 'bool'>, False
use_cpu, <class 'bool'>, False
use_mps_device, <class 'bool'>, False
seed, <class 'int'>, 42
data_seed, typing.Optional[int], None
jit_mode_eval, <class 'bool'>, False
use_ipex, <class 'bool'>, False
bf16, <class 'bool'>, False
fp16, <class 'bool'>, False
fp16_opt_level, <class 'str'>, O1
half_precision_backend, <class 'str'>, auto
bf16_full_eval, <class 'bool'>, False
fp16_full_eval, <class 'bool'>, False
tf32, typing.Optional[bool], None
local_rank, <class 'int'>, -1
ddp_backend, typing.Optional[str], None
tpu_num_cores, typing.Optional[int], None
tpu_metrics_debug, <class 'bool'>, False
debug, typing.Union[str, typing.List[transformers.debug_utils.DebugOption]],
dataloader_drop_last, <class 'bool'>, False
eval_steps, typing.Optional[float], None
dataloader_num_workers, <class 'int'>, 0
dataloader_prefetch_factor, typing.Optional[int], None
past_index, <class 'int'>, -1
run_name, typing.Optional[str], None
disable_tqdm, typing.Optional[bool], None
remove_unused_columns, typing.Optional[bool], True
label_names, typing.Optional[typing.List[str]], None
load_best_model_at_end, typing.Optional[bool], False
metric_for_best_model, typing.Optional[str], None
greater_is_better, typing.Optional[bool], None
ignore_data_skip, <class 'bool'>, False
fsdp, typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType],
fsdp_min_num_params, <class 'int'>, 0
fsdp_config, typing.Union[dict, str, NoneType], None
fsdp_transformer_layer_cls_to_wrap, typing.Optional[str], None
accelerator_config, typing.Optional[str], None
deepspeed, typing.Optional[str], None
label_smoothing_factor, <class 'float'>, 0.0
optim, typing.Union[transformers.training_args.OptimizerNames, str], adamw_torch
optim_args, typing.Optional[str], None
adafactor, <class 'bool'>, False
group_by_length, <class 'bool'>, False
length_column_name, typing.Optional[str], length
report_to, typing.Optional[typing.List[str]], None
ddp_find_unused_parameters, typing.Optional[bool], None
ddp_bucket_cap_mb, typing.Optional[int], None
ddp_broadcast_buffers, typing.Optional[bool], None
dataloader_pin_memory, <class 'bool'>, True
dataloader_persistent_workers, <class 'bool'>, False
skip_memory_metrics, <class 'bool'>, True
use_legacy_prediction_loop, <class 'bool'>, False
push_to_hub, <class 'bool'>, False
resume_from_checkpoint, typing.Optional[str], None
hub_model_id, typing.Optional[str], None
hub_strategy, typing.Union[transformers.trainer_utils.HubStrategy, str], every_save
hub_token, typing.Optional[str], None
hub_private_repo, <class 'bool'>, False
hub_always_push, <class 'bool'>, False
gradient_checkpointing, <class 'bool'>, False
gradient_checkpointing_kwargs, typing.Optional[dict], None
include_inputs_for_metrics, <class 'bool'>, False
fp16_backend, <class 'str'>, auto
push_to_hub_model_id, typing.Optional[str], None
push_to_hub_organization, typing.Optional[str], None
push_to_hub_token, typing.Optional[str], None
mp_parameters, <class 'str'>,
auto_find_batch_size, <class 'bool'>, False
full_determinism, <class 'bool'>, False
torchdynamo, typing.Optional[str], None
ray_scope, typing.Optional[str], last
ddp_timeout, typing.Optional[int], 1800
torch_compile, <class 'bool'>, False
torch_compile_backend, typing.Optional[str], None
torch_compile_mode, typing.Optional[str], None
dispatch_batches, typing.Optional[bool], None
split_batches, typing.Optional[bool], None
include_tokens_per_second, typing.Optional[bool], False
include_num_input_tokens_seen, typing.Optional[bool], False
neftune_noise_alpha, typing.Optional[float], None
Warning: this approach does not play too well with dataclass and pydantic classes because they come with more complex bells and whistles. However, thanks to the perks of dataclass and pydantic, we can simply use their own introspection, such as the model_fields property, to get all fields and their types.
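For instance, TrainingArguments is a dataclass, so dataclasses.fields already exposes names, types, defaults, and default factories; pydantic models expose model_fields (assuming pydantic v2 here; v1 uses __fields__):
import dataclasses

# First few TrainingArguments fields, straight from the dataclass machinery.
for f in dataclasses.fields(TrainingArguments)[:5]:
    print(f.name, f.type, f.default, f.default_factory)

# A toy pydantic model for comparison.
class ToyConfig(BaseModel):
    lr: float = 1e-3

pprint(ToyConfig.model_fields)  # pydantic v2 attribute; use __fields__ on v1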
As we can see from above, we did not handle lr_scheduler_kwargs well:
lr_scheduler_kwargs, typing.Optional[typing.Dict], <factory>
where <factory> shows up as the default value of the parameter. It is actually the repr of the default_factory of the dataclass field, which produces a default dict (or similar) when called. One way to recover the real default is sketched below.
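If we need the real default for such parameters, one option (a sketch of our own, not what the generated config does) is to fall back to the dataclass machinery and call the default_factory ourselves:
from dataclasses import MISSING, fields as dataclass_fields

def resolve_default(cls: Type[Any], name: str) -> Any:
    """Return the effective default of a dataclass field, calling default_factory if needed."""
    for f in dataclass_fields(cls):
        if f.name == name:
            if f.default is not MISSING:
                return f.default
            if f.default_factory is not MISSING:
                return f.default_factory()
            return None
    raise KeyError(name)

print(resolve_default(TrainingArguments, 'lr_scheduler_kwargs'))  # likely {} since the field uses default_factory=dict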
def type_hint_to_str(type_hint: Any) -> str:
"""
Convert a type hint into its string representation.
"""
if hasattr(type_hint, '__name__'):
return type_hint.__name__
elif hasattr(type_hint, '_name') and type_hint._name is not None:
return str(type_hint._name)
elif type(type_hint) == _GenericAlias: # For Python 3.8+
# Handles complex types, e.g., List[int], Union[str, int]
origin = type_hint_to_str(type_hint.__origin__)
args = ', '.join(type_hint_to_str(arg) for arg in type_hint.__args__)
return f"{origin}[{args}]"
else:
# Fallback for unhandled types
return str(type_hint)
def create_config_class_str(fields: List[Tuple[str, Any, Any]]) -> str:
lines = ["class Config:"]
if not fields:
lines.append(" ...")
else:
init_params = ["self"]
init_body = []
for name, type_hint, default in fields:
type_hint_str = type_hint_to_str(type_hint)
if default is Ellipsis: # Required argument
param_str = f"{name}: {type_hint_str}"
elif default is field:
param_str = f"{name}: {type_hint_str} = field(default_factory=dict)"
else:
default_repr = repr(default) if default is not None else 'None'
param_str = f"{name}: {type_hint_str} = {default_repr}"
init_params.append(param_str)
init_body.append(f" self.{name} = {name}")
lines.append(f" def __init__({', '.join(init_params)}):")
lines.extend(init_body)
return '\n'.join(lines)
config_class_str = create_config_class_str(fields)
print(config_class_str)
class Config:
def __init__(self, output_dir: str, overwrite_output_dir: bool = False, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, per_gpu_train_batch_size: typing.Optional[int] = None, per_gpu_eval_batch_size: typing.Optional[int] = None, gradient_accumulation_steps: int = 1, eval_accumulation_steps: typing.Optional[int] = None, eval_delay: typing.Optional[float] = 0, learning_rate: float = 5e-05, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear', lr_scheduler_kwargs: typing.Optional[typing.Dict] = <factory>, warmup_ratio: float = 0.0, warmup_steps: int = 0, log_level: typing.Optional[str] = 'passive', log_level_replica: typing.Optional[str] = 'warning', log_on_each_node: bool = True, logging_dir: typing.Optional[str] = None, logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', logging_first_step: bool = False, logging_steps: float = 500, logging_nan_inf_filter: bool = True, save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', save_steps: float = 500, save_total_limit: typing.Optional[int] = None, save_safetensors: typing.Optional[bool] = True, save_on_each_node: bool = False, save_only_model: bool = False, no_cuda: bool = False, use_cpu: bool = False, use_mps_device: bool = False, seed: int = 42, data_seed: typing.Optional[int] = None, jit_mode_eval: bool = False, use_ipex: bool = False, bf16: bool = False, fp16: bool = False, fp16_opt_level: str = 'O1', half_precision_backend: str = 'auto', bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: typing.Optional[bool] = None, local_rank: int = -1, ddp_backend: typing.Optional[str] = None, tpu_num_cores: typing.Optional[int] = None, tpu_metrics_debug: bool = False, debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '', dataloader_drop_last: bool = False, eval_steps: typing.Optional[float] = None, dataloader_num_workers: int = 0, dataloader_prefetch_factor: typing.Optional[int] = None, past_index: int = -1, run_name: typing.Optional[str] = None, disable_tqdm: typing.Optional[bool] = None, remove_unused_columns: typing.Optional[bool] = True, label_names: typing.Optional[typing.List[str]] = None, load_best_model_at_end: typing.Optional[bool] = False, metric_for_best_model: typing.Optional[str] = None, greater_is_better: typing.Optional[bool] = None, ignore_data_skip: bool = False, fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '', fsdp_min_num_params: int = 0, fsdp_config: typing.Union[dict, str, NoneType] = None, fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None, accelerator_config: typing.Optional[str] = None, deepspeed: typing.Optional[str] = None, label_smoothing_factor: float = 0.0, optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch', optim_args: typing.Optional[str] = None, adafactor: bool = False, group_by_length: bool = False, length_column_name: typing.Optional[str] = 'length', report_to: typing.Optional[typing.List[str]] = None, ddp_find_unused_parameters: typing.Optional[bool] = None, 
ddp_bucket_cap_mb: typing.Optional[int] = None, ddp_broadcast_buffers: typing.Optional[bool] = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, use_legacy_prediction_loop: bool = False, push_to_hub: bool = False, resume_from_checkpoint: typing.Optional[str] = None, hub_model_id: typing.Optional[str] = None, hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save', hub_token: typing.Optional[str] = None, hub_private_repo: bool = False, hub_always_push: bool = False, gradient_checkpointing: bool = False, gradient_checkpointing_kwargs: typing.Optional[dict] = None, include_inputs_for_metrics: bool = False, fp16_backend: str = 'auto', push_to_hub_model_id: typing.Optional[str] = None, push_to_hub_organization: typing.Optional[str] = None, push_to_hub_token: typing.Optional[str] = None, mp_parameters: str = '', auto_find_batch_size: bool = False, full_determinism: bool = False, torchdynamo: typing.Optional[str] = None, ray_scope: typing.Optional[str] = 'last', ddp_timeout: typing.Optional[int] = 1800, torch_compile: bool = False, torch_compile_backend: typing.Optional[str] = None, torch_compile_mode: typing.Optional[str] = None, dispatch_batches: typing.Optional[bool] = None, split_batches: typing.Optional[bool] = None, include_tokens_per_second: typing.Optional[bool] = False, include_num_input_tokens_seen: typing.Optional[bool] = False, neftune_noise_alpha: typing.Optional[float] = None):
self.output_dir = output_dir
self.overwrite_output_dir = overwrite_output_dir
self.do_train = do_train
self.do_eval = do_eval
self.do_predict = do_predict
self.evaluation_strategy = evaluation_strategy
self.prediction_loss_only = prediction_loss_only
self.per_device_train_batch_size = per_device_train_batch_size
self.per_device_eval_batch_size = per_device_eval_batch_size
self.per_gpu_train_batch_size = per_gpu_train_batch_size
self.per_gpu_eval_batch_size = per_gpu_eval_batch_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.eval_accumulation_steps = eval_accumulation_steps
self.eval_delay = eval_delay
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.adam_beta1 = adam_beta1
self.adam_beta2 = adam_beta2
self.adam_epsilon = adam_epsilon
self.max_grad_norm = max_grad_norm
self.num_train_epochs = num_train_epochs
self.max_steps = max_steps
self.lr_scheduler_type = lr_scheduler_type
self.lr_scheduler_kwargs = lr_scheduler_kwargs
self.warmup_ratio = warmup_ratio
self.warmup_steps = warmup_steps
self.log_level = log_level
self.log_level_replica = log_level_replica
self.log_on_each_node = log_on_each_node
self.logging_dir = logging_dir
self.logging_strategy = logging_strategy
self.logging_first_step = logging_first_step
self.logging_steps = logging_steps
self.logging_nan_inf_filter = logging_nan_inf_filter
self.save_strategy = save_strategy
self.save_steps = save_steps
self.save_total_limit = save_total_limit
self.save_safetensors = save_safetensors
self.save_on_each_node = save_on_each_node
self.save_only_model = save_only_model
self.no_cuda = no_cuda
self.use_cpu = use_cpu
self.use_mps_device = use_mps_device
self.seed = seed
self.data_seed = data_seed
self.jit_mode_eval = jit_mode_eval
self.use_ipex = use_ipex
self.bf16 = bf16
self.fp16 = fp16
self.fp16_opt_level = fp16_opt_level
self.half_precision_backend = half_precision_backend
self.bf16_full_eval = bf16_full_eval
self.fp16_full_eval = fp16_full_eval
self.tf32 = tf32
self.local_rank = local_rank
self.ddp_backend = ddp_backend
self.tpu_num_cores = tpu_num_cores
self.tpu_metrics_debug = tpu_metrics_debug
self.debug = debug
self.dataloader_drop_last = dataloader_drop_last
self.eval_steps = eval_steps
self.dataloader_num_workers = dataloader_num_workers
self.dataloader_prefetch_factor = dataloader_prefetch_factor
self.past_index = past_index
self.run_name = run_name
self.disable_tqdm = disable_tqdm
self.remove_unused_columns = remove_unused_columns
self.label_names = label_names
self.load_best_model_at_end = load_best_model_at_end
self.metric_for_best_model = metric_for_best_model
self.greater_is_better = greater_is_better
self.ignore_data_skip = ignore_data_skip
self.fsdp = fsdp
self.fsdp_min_num_params = fsdp_min_num_params
self.fsdp_config = fsdp_config
self.fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap
self.accelerator_config = accelerator_config
self.deepspeed = deepspeed
self.label_smoothing_factor = label_smoothing_factor
self.optim = optim
self.optim_args = optim_args
self.adafactor = adafactor
self.group_by_length = group_by_length
self.length_column_name = length_column_name
self.report_to = report_to
self.ddp_find_unused_parameters = ddp_find_unused_parameters
self.ddp_bucket_cap_mb = ddp_bucket_cap_mb
self.ddp_broadcast_buffers = ddp_broadcast_buffers
self.dataloader_pin_memory = dataloader_pin_memory
self.dataloader_persistent_workers = dataloader_persistent_workers
self.skip_memory_metrics = skip_memory_metrics
self.use_legacy_prediction_loop = use_legacy_prediction_loop
self.push_to_hub = push_to_hub
self.resume_from_checkpoint = resume_from_checkpoint
self.hub_model_id = hub_model_id
self.hub_strategy = hub_strategy
self.hub_token = hub_token
self.hub_private_repo = hub_private_repo
self.hub_always_push = hub_always_push
self.gradient_checkpointing = gradient_checkpointing
self.gradient_checkpointing_kwargs = gradient_checkpointing_kwargs
self.include_inputs_for_metrics = include_inputs_for_metrics
self.fp16_backend = fp16_backend
self.push_to_hub_model_id = push_to_hub_model_id
self.push_to_hub_organization = push_to_hub_organization
self.push_to_hub_token = push_to_hub_token
self.mp_parameters = mp_parameters
self.auto_find_batch_size = auto_find_batch_size
self.full_determinism = full_determinism
self.torchdynamo = torchdynamo
self.ray_scope = ray_scope
self.ddp_timeout = ddp_timeout
self.torch_compile = torch_compile
self.torch_compile_backend = torch_compile_backend
self.torch_compile_mode = torch_compile_mode
self.dispatch_batches = dispatch_batches
self.split_batches = split_batches
self.include_tokens_per_second = include_tokens_per_second
self.include_num_input_tokens_seen = include_num_input_tokens_seen
self.neftune_noise_alpha = neftune_noise_alpha
Using this as-is will yield a SyntaxError because of the <factory> issue highlighted above. We can instead use it on a "normal" class such as Trainer.
fields, annotations = get_constructor_field_annotations(Trainer, include_bases=False)
config_class_str = create_config_class_str(fields)
print(config_class_str)
class Config:
def __init__(self, model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None, args: typing.Optional[transformers.training_args.TrainingArguments] = None, data_collator: typing.Optional[DataCollator] = None, train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None, eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None, tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None, compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict]] = None, callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None, optimizers: Tuple = (None, None), preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None):
self.model = model
self.args = args
self.data_collator = data_collator
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.model_init = model_init
self.compute_metrics = compute_metrics
self.callbacks = callbacks
self.optimizers = optimizers
self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
import transformers
import typing
import torch
from transformers import DataCollator
NoneType = type(None)
config_class_str = create_config_class_str(fields)
# Execute the generated class definition string
namespace = {}
exec(config_class_str, globals(), namespace)
# Extract the newly created class from the namespace
ConfigClass = namespace['Config']
inspect.signature(ConfigClass.__init__)
<Signature (self, model: Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, NoneType] = None, args: Optional[transformers.training_args.TrainingArguments] = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[torch.utils.data.dataset.Dataset] = None, eval_dataset: Union[torch.utils.data.dataset.Dataset, Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None, tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], transformers.modeling_utils.PreTrainedModel]] = None, compute_metrics: Optional[Callable[[transformers.trainer_utils.EvalPrediction], Dict]] = None, callbacks: Optional[List[transformers.trainer_callback.TrainerCallback]] = None, optimizers: Tuple = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None)>
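Finally, an exec-free alternative (a rough sketch of our own) is to feed the same (name, type, default) triples into make_dataclass, which was imported at the top; this works here because Trainer's constructor lists required fields before optional ones and has no default_factory-style defaults:
spec = [
    (name, type_hint) if default is Ellipsis else (name, type_hint, default)
    for name, type_hint, default in fields
]
TrainerConfig = make_dataclass('TrainerConfig', spec)
pprint(inspect.signature(TrainerConfig.__init__))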