espnet2.schedulers.abs_scheduler.AbsScheduler
class espnet2.schedulers.abs_scheduler.AbsScheduler
Bases: ABC
Abstract base class for defining learning rate schedulers.
This class provides a blueprint for implementing various types of learning rate schedulers in PyTorch. Schedulers help to adjust the learning rate during training based on specific strategies.
Classes inheriting from AbsScheduler should implement the following abstract methods:
- step: Updates the learning rate based on the current epoch or other criteria.
- state_dict: Returns the state of the scheduler as a dictionary.
- load_state_dict: Loads the scheduler state from a given dictionary.
Example usage:

```python
import torch

class CustomScheduler(AbsScheduler):
    def __init__(self, optimizer):
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10, gamma=0.1
        )

    def step(self, epoch: int = None):
        self.scheduler.step(epoch)

    def state_dict(self):
        return self.scheduler.state_dict()

    def load_state_dict(self, state):
        self.scheduler.load_state_dict(state)
```
- Parameters: None
- Returns: None
- Yields: None
- Raises: NotImplementedError – If any of the abstract methods are not implemented in a subclass.
abstract load_state_dict(state)
Loads the state of the learning rate scheduler from a given state dictionary.
This method is intended to restore the internal state of the scheduler from a previously saved state. This is particularly useful when resuming training from a checkpoint, allowing the scheduler to continue its operation as expected.
- Parameters: state (dict) – A state dictionary containing the parameters to load into the scheduler. This should be a dictionary that was previously saved by the state_dict method.
- Raises: ValueError – If the provided state does not match the expected format or contains invalid keys.
######### Examples
```python
# Example of saving and loading state in a custom scheduler
scheduler = MyCustomScheduler()

# Save the state
state = scheduler.state_dict()

# Load the state
scheduler.load_state_dict(state)
```
NOTE
This method should be implemented by subclasses to define how the state is restored, ensuring that all necessary attributes are correctly updated.
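To make the checkpoint-resume scenario concrete, here is a minimal sketch that reuses the StepLR-backed `CustomScheduler` from the class-level example; the checkpoint file name and dictionary layout are illustrative assumptions, not a format prescribed by ESPnet.

```python
import torch

# Hypothetical checkpointing sketch; the file name and checkpoint layout are
# illustrative, not a format prescribed by ESPnet.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = CustomScheduler(optimizer)  # StepLR-backed subclass from the class example

# At checkpoint time: persist the scheduler state with the rest of training state.
torch.save({"scheduler": scheduler.state_dict()}, "checkpoint.pth")

# When resuming: restore the scheduler so it continues where it left off.
state = torch.load("checkpoint.pth")
scheduler.load_state_dict(state["scheduler"])
```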
abstract state_dict()
Abstract method to retrieve the state of the scheduler.
This method should return a dictionary containing the current state of the scheduler. The returned state can be used to save the scheduler’s state and later load it using the load_state_dict method.
- Returns: A dictionary containing the state of the scheduler.
- Return type: dict
######### Examples
```python
# Example of how to use state_dict in a custom scheduler
class CustomScheduler(AbsScheduler):
    def __init__(self):
        self.state = {}

    def step(self, epoch: int = None):
        if epoch is not None:
            self.state["epoch"] = epoch

    def state_dict(self):
        return self.state

    def load_state_dict(self, state):
        self.state = state

scheduler = CustomScheduler()
print(scheduler.state_dict())  # Output: {}
```
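For a subclass that simply delegates to a built-in PyTorch scheduler, as in the class-level example, state_dict() just forwards PyTorch's own dictionary. The sketch below reuses that hypothetical `CustomScheduler` and prints the result; the exact keys are determined by PyTorch's `StepLR`, not by this class.

```python
import torch

# Minimal optimizer so the StepLR-backed CustomScheduler from the class-level
# example can be constructed.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = CustomScheduler(optimizer)

# The returned dictionary is PyTorch's own StepLR state, with entries such as
# "step_size", "gamma", and "last_epoch".
print(scheduler.state_dict())
```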
abstract step(epoch: int | None = None)
Updates the learning rate based on the current epoch or other criteria.
Subclasses must implement this method to define how the learning rate is adjusted during training; it is typically called once per epoch from the training loop. The optional epoch argument allows the caller to set the schedule to a specific epoch.
- Parameters: epoch (int, optional) – The current epoch. If None, the method should handle it accordingly. Defaults to None.
- Returns: None
- Raises: NotImplementedError – If the derived class does not implement the step method.
######### Examples
```python
class CustomScheduler(AbsScheduler):
    def __init__(self, optimizer):
        self.optimizer = optimizer

    def step(self, epoch=None):
        # Custom logic to adjust the learning rate
        pass

    def state_dict(self):
        # Return the state of the scheduler
        pass

    def load_state_dict(self, state):
        # Load the state of the scheduler
        pass
```
NOTE
This is an abstract class and should not be instantiated directly. Derived classes must implement the abstract methods.
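As a rough sketch of the calling pattern this signature allows, the loop below reuses the StepLR-backed `CustomScheduler` from the class-level example and calls step() once per epoch; the model, optimizer settings, and epoch counts are placeholders, not anything prescribed by ESPnet. Note that passing an explicit epoch to PyTorch's built-in schedulers is deprecated in recent releases, even though the abstract signature permits it.

```python
import torch

# Placeholder model/optimizer; CustomScheduler is the StepLR-backed subclass above.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = CustomScheduler(optimizer)

for epoch in range(3):
    optimizer.step()   # stands in for one epoch of training
    scheduler.step()   # implicit form: the scheduler tracks the epoch itself
    print(epoch, optimizer.param_groups[0]["lr"])

# Explicit form: pin the schedule to a given epoch (deprecated for torch's
# built-in schedulers, but allowed by the abstract signature).
scheduler.step(epoch=20)
print(optimizer.param_groups[0]["lr"])
```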