11"""Module for the PINA Optimizer."""
22
33from abc import ABCMeta , abstractmethod
4+ from ..utils import check_consistency
45
56
67class Optimizer (metaclass = ABCMeta ):
@@ -9,15 +10,56 @@ class Optimizer(metaclass=ABCMeta):
910 should inherit form this class and implement the required methods.
1011 """
1112
13+ def __init__ (self ):
14+ """
15+ Initialization of the :class:`Optimizer` class.
16+ """
17+ self ._extra_optim_args = {}
18+ self ._optimizer_instance = None
19+
1220 @property
13- @abstractmethod
1421 def instance (self ):
1522 """
16- Abstract property to retrieve the optimizer instance.
23+ Get the optimizer instance.
24+
25+ :return: The optimizer instance.
26+ :rtype: torch.optim.Optimizer
27+ """
28+ return self ._optimizer_instance
29+
30+ @instance .setter
31+ def instance (self , value ):
32+ """
33+ Set the optimizer instance.
34+
35+ :param Any value: The optimizer instance.
1736 """
37+ self ._optimizer_instance = value
1838
1939 @abstractmethod
2040 def hook (self ):
2141 """
2242 Abstract method to define the hook logic for the optimizer.
2343 """
44+
45+ def get_optim_extra_args (self , trainer , batch , losses ):
46+ """
47+ Retrieve and set extra optimizer arguments from the optimizer instance.
48+
49+ This method calls the ``get_optim_extra_args`` method of the underlying
50+ optimizer instance (if it exists) and stores the resulting dictionary in
51+ :attr:`extra_optim_args` of the optimizer instance.
52+
53+ :param trainer: The training manager controlling the optimization loop.
54+ :type trainer: :class:`~pina.trainer.Trainer`
55+ :param dict batch: The current batch of data used for training.
56+ :param dict losses: Dictionary containing the computed loss values.
57+ :raises ValueError: If ``extra_get_optim_extra_argsargs_dict`` does not
58+ return a dictionary.
59+ """
60+ if hasattr (self .instance , "get_optim_extra_args" ):
61+ extra_args = self .instance .get_optim_extra_args (
62+ trainer = trainer , batch = batch , losses = losses
63+ )
64+ check_consistency (extra_args , dict )
65+ self .instance .extra_optim_args = extra_args
0 commit comments