espnet2.enh.diffusion.sampling.correctors.Corrector
class espnet2.enh.diffusion.sampling.correctors.Corrector(sde, score_fn, snr, n_steps)
Bases: ABC
The abstract class for a corrector algorithm.
This class serves as a base for implementing various corrector algorithms used in diffusion models. It requires subclasses to implement the update_fn method, which performs a single update step of the corrector algorithm.
- Attributes:
- rsde – The reverse-time SDE, built from sde and score_fn.
- score_fn – A function that estimates the score (the gradient of the log probability density) of the data.
- snr – Signal-to-noise ratio used in the update step.
- n_steps – Number of steps to perform in each update.
- Parameters:
- sde – The stochastic differential equation to be used.
- score_fn – A callable function to estimate the score.
- snr – A float representing the signal-to-noise ratio.
- n_steps – An integer indicating the number of steps for the update.
- Raises: TypeError – If a subclass that does not override update_fn is instantiated (enforced by abc.ABC, since update_fn is abstract).
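The constructor contract described above can be sketched as follows. This is a minimal, self-contained approximation, not the ESPnet source: ToySDE, IdentityCorrector and dummy_score are hypothetical stand-ins, and the only assumption made about a real SDE is that it exposes a reverse(score_fn) hook that builds the reverse SDE from a score function. The sketch also demonstrates a property of abc.ABC: a class whose abstract update_fn is not implemented cannot be instantiated, and Python raises TypeError at construction time.

```python
import abc


class ToySDE:
    """Hypothetical stand-in SDE: only the assumed reverse(score_fn) hook is modeled."""

    def reverse(self, score_fn):
        # A real SDE would build a reverse-time SDE object here; this toy
        # returns a tagged tuple so the wiring can be inspected.
        return ("reverse", score_fn)


class Corrector(abc.ABC):
    """Sketch of the documented base class (not the ESPnet source)."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.rsde = sde.reverse(score_fn)  # reverse SDE built from sde and the score
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Return (x, x_mean) after one corrector update."""


class IdentityCorrector(Corrector):
    """Hypothetical concrete subclass whose update leaves the state unchanged."""

    def update_fn(self, x, t, *args):
        return x, x  # no noise is added, so x and x_mean coincide


def dummy_score(x, t):
    """Hypothetical score function that always returns zero."""
    return 0.0
```

Usage: IdentityCorrector(ToySDE(), dummy_score, snr=0.16, n_steps=1) stores all four attributes, whereas calling Corrector(...) directly fails with TypeError because update_fn is still abstract.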
####### Examples
Example subclass implementation
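from espnet2.enh.diffusion.sampling.correctors import Corrector

class MyCorrector(Corrector):
    def update_fn(self, x, t, *args):
        # Custom update logic here; must return (next state, noise-free next state)
        return x, x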
abstract update_fn(x, t, *args)
One update of the corrector.
This method performs a single update step of the corrector algorithm, adjusting the current state based on the provided inputs and the specific corrector implementation.
- Parameters:
- x (torch.Tensor) – A PyTorch tensor representing the current state.
- t (torch.Tensor) – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments, in particular y for Ornstein-Uhlenbeck (OU) processes.
- Returns: A tuple (x, x_mean) containing:
- x (torch.Tensor): A PyTorch tensor of the next state.
- x_mean (torch.Tensor): A PyTorch tensor of the next state without random noise. Useful for denoising.
- Return type: tuple
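To make the (x, x_mean) contract concrete, here is a hedged sketch of a Langevin-type corrector in the style of the score-based SDE literature (Song et al.). Each of the n_steps iterations moves x along the score with a step size derived from snr, and x_mean is the state just before fresh Gaussian noise is added. This is not claimed to be ESPnet's own corrector: the local Corrector stand-in drops the sde argument to stay self-contained, and the step-size rule step = 2 * (snr * ||noise|| / ||score||)^2 (batch-averaged L2 norms) is the assumed variant.

```python
import abc

import torch


class Corrector(abc.ABC):
    """Minimal stand-in for the documented base class (sde argument omitted)."""

    def __init__(self, score_fn, snr, n_steps):
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        ...


class LangevinCorrector(Corrector):
    """Langevin-dynamics corrector sketch; returns (x, x_mean)."""

    def update_fn(self, x, t, *args):
        x_mean = x
        for _ in range(self.n_steps):
            grad = self.score_fn(x, t, *args)
            noise = torch.randn_like(x)
            # Per-sample L2 norms, averaged over the batch.
            grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
            noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
            # Step size chosen so the score-to-noise magnitude ratio tracks snr.
            step_size = 2 * (self.snr * noise_norm / grad_norm) ** 2
            x_mean = x + step_size * grad  # deterministic move along the score
            x = x_mean + torch.sqrt(2 * step_size) * noise  # add Langevin noise
        return x, x_mean
```

Because x and x_mean differ only by the noise term of the last iteration, keeping x_mean at the final sampling step yields a noise-free estimate, which is why the return value is described as useful for denoising.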
####### Examples
Example usage of update_fn
import torch
current_state = torch.randn(1, 3, 32, 32)  # example current state
current_time = torch.tensor(0.5)  # example current time step
# corrector_instance is an instance of a concrete Corrector subclass
next_state, denoised_state = corrector_instance.update_fn(current_state, current_time)
NOTE
The specific behavior of the update step depends on the concrete implementation of the Corrector subclass.
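As an illustration of how a sampler might drive update_fn, the sketch below calls the corrector once per reverse-time step and, when denoise=True, returns the noise-free x_mean from the final step, in the spirit of the denoise option in Song et al.'s predictor-corrector sampler. Everything here is hypothetical: ShrinkCorrector is a toy subclass (not a real corrector algorithm), corrector_only_sampler is an illustrative loop without a predictor, and the per-batch time vector is an assumed convention.

```python
import abc

import torch


class Corrector(abc.ABC):
    """Minimal stand-in for the documented base class."""

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        ...


class ShrinkCorrector(Corrector):
    """Hypothetical toy corrector: pull x toward 0, then add a little noise."""

    def update_fn(self, x, t, *args):
        x_mean = 0.9 * x
        x = x_mean + 0.01 * torch.randn_like(x)
        return x, x_mean


def corrector_only_sampler(corrector, x, timesteps, denoise=True):
    """Hypothetical loop: one corrector update per time step."""
    x_mean = x
    for t in timesteps:
        vec_t = torch.full((x.shape[0],), float(t))  # one time value per batch item
        x, x_mean = corrector.update_fn(x, vec_t)
    # Return the noise-free estimate from the last step when denoising.
    return x_mean if denoise else x
```

Swapping in a different subclass changes the result of the same loop, which is the point of the note above: the update step is defined entirely by the concrete Corrector.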