espnet2.enh.diffusion.sampling.correctors.LangevinCorrector
class espnet2.enh.diffusion.sampling.correctors.LangevinCorrector(sde, score_fn, snr, n_steps)
Bases: Corrector
Langevin Corrector for score-based diffusion models.
This class implements the Langevin corrector, which score-based generative models use to iteratively refine samples with Langevin dynamics. Each step moves the state along the score returned by score_fn (an estimate of the gradient of the log-density) and then adds Gaussian noise so that the sampler keeps exploring the sample space.
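The iteration can be summarized as follows. This is a sketch written after the Langevin corrector of Song et al. (2021); the SNR-based step-size rule is an assumption, and ESPnet's implementation may differ in details such as how the norms are averaged over a batch. Here s_θ stands for score_fn, r for snr, and x̄ for the noise-free state returned as x_mean; the steps are repeated n_steps times:

```latex
\begin{aligned}
\mathbf{g} &= s_\theta(\mathbf{x}, t), \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),\\
\epsilon &= 2\left(r\,\frac{\lVert \mathbf{z} \rVert_2}{\lVert \mathbf{g} \rVert_2}\right)^{2},\\
\bar{\mathbf{x}} &= \mathbf{x} + \epsilon\,\mathbf{g},\\
\mathbf{x} &\leftarrow \bar{\mathbf{x}} + \sqrt{2\epsilon}\,\mathbf{z}.
\end{aligned}
```

With this choice of ε, the norm of the drift term ε g divided by the norm of the noise term √(2ε) z equals r exactly, which is why snr controls the size of each update.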
score_fn
A function that computes the score at a given state and time.
- Type: callable
snr
The target signal-to-noise ratio that sets the step size of the Langevin updates.
- Type: float
n_steps
The number of update steps to perform in each call.
- Type: int
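The role of snr can be illustrated with a small helper. This is a minimal sketch assuming the SNR-based step-size rule of Song et al. (2021); the function name langevin_step_size is hypothetical and not part of ESPnet.

```python
import torch


def langevin_step_size(grad: torch.Tensor, noise: torch.Tensor, snr: float) -> torch.Tensor:
    """Hypothetical helper: SNR-based step size for one Langevin iteration.

    With eps = 2 * (snr * ||noise|| / ||grad||)^2, the drift eps * grad and the
    injected noise sqrt(2 * eps) * noise have (batch-averaged) norms whose
    ratio is exactly snr.
    """
    batch = grad.shape[0]
    # Per-example L2 norms of the flattened tensors, averaged over the batch.
    grad_norm = torch.linalg.vector_norm(grad.reshape(batch, -1), dim=-1).mean()
    noise_norm = torch.linalg.vector_norm(noise.reshape(batch, -1), dim=-1).mean()
    return 2 * (snr * noise_norm / grad_norm) ** 2


grad = torch.ones(2, 4)  # per-example L2 norm 2
noise = torch.full((2, 4), 2.0)  # per-example L2 norm 4
eps = langevin_step_size(grad, noise, snr=0.5)
print(eps.item())  # 2.0
```

A larger snr therefore gives larger steps relative to the injected noise; a small value yields cautious, noise-dominated refinement.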
Parameters:
- sde – The stochastic differential equation object.
- score_fn – The score function used to compute gradients.
- snr – The signal-to-noise ratio.
- n_steps – The number of steps for the Langevin dynamics update.
Returns: A tuple containing:
- x (torch.Tensor): The updated state after applying the Langevin dynamics.
- x_mean (torch.Tensor): The mean state without random noise, useful for denoising.
Return type: tuple
####### Examples
>>> import torch
>>> sde = ... # Define your SDE here
>>> score_fn = ... # Define your score function here
>>> corrector = LangevinCorrector(sde, score_fn, snr=1.0, n_steps=10)
>>> x_init = torch.randn(1, 3, 64, 64) # Example initial state
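>>> t = torch.full((1,), 0.5)  # Example time step, one value per batch element
>>> # For OU-type SDEs, also pass the conditioning signal: update_fn(x_init, t, y)
>>> x_updated, x_mean = corrector.update_fn(x_init, t)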
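Because the example above leaves sde and score_fn undefined, the following runnable toy may help. It is a minimal sketch, assuming the step-size rule of Song et al. (2021); ToyLangevinCorrector and gaussian_score are hypothetical stand-ins that only mirror the constructor and update_fn interface, not the ESPnet implementation. The target is a standard normal, whose score is simply -x.

```python
import torch


class ToyLangevinCorrector:
    """Hypothetical stand-in that mirrors the LangevinCorrector interface.

    This is not the ESPnet class: the SNR-based step-size rule follows the
    Langevin corrector of Song et al. (2021) and is an assumption here.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        self.sde = sde  # kept for interface parity; unused by this sketch
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    def update_fn(self, x, t, *args):
        x_mean = x
        batch = x.shape[0]
        for _ in range(self.n_steps):
            grad = self.score_fn(x, t, *args)  # score estimate at (x, t)
            noise = torch.randn_like(x)
            # Batch-averaged L2 norms of the flattened score and noise tensors.
            grad_norm = torch.linalg.vector_norm(grad.reshape(batch, -1), dim=-1).mean()
            noise_norm = torch.linalg.vector_norm(noise.reshape(batch, -1), dim=-1).mean()
            step_size = 2 * (self.snr * noise_norm / grad_norm) ** 2
            x_mean = x + step_size * grad  # noise-free move along the score
            x = x_mean + torch.sqrt(2 * step_size) * noise  # Langevin noise injection
        return x, x_mean


def gaussian_score(x, t):
    """Hypothetical score of a standard normal target: grad log N(x; 0, I) = -x."""
    return -x


torch.manual_seed(0)
corrector = ToyLangevinCorrector(sde=None, score_fn=gaussian_score, snr=0.16, n_steps=500)
x_init = 5 * torch.randn(1, 3, 64, 64)  # start far from the target distribution
t = torch.full((1,), 0.5)  # one time value per batch element
x_updated, x_mean = corrector.update_fn(x_init, t)
print(f"std before: {x_init.std().item():.2f}, after: {x_updated.std().item():.2f}")
```

Starting from samples with standard deviation about 5, repeated corrector steps pull the state toward the unit-variance target; with a fixed snr the stationary spread settles slightly above 1 rather than exactly at it.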
NOTE
This corrector assumes that score_fn(x, t, *args) accepts the arguments passed to update_fn and returns a tensor with the same shape as x.
- Raises:
- NotImplementedError – If the score function or SDE is not compatible with the Langevin dynamics.
update_fn(x, t, *args)
One update of the corrector.
This method performs one corrector update on the current state x at time step t. The update consists of n_steps Langevin iterations; each iteration evaluates the score function (itself an estimate of the gradient of the log-density), draws Gaussian noise, and computes the new state from the score step plus the scaled noise.
- Parameters:
- x (torch.Tensor) – A PyTorch tensor representing the current state.
- t (torch.Tensor) – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments forwarded to score_fn, in particular y for OU processes.
- Returns: A tuple containing:
- x (torch.Tensor): A PyTorch tensor of the next state.
- x_mean (torch.Tensor): A PyTorch tensor representing the next state without random noise. Useful for denoising.
- Return type: tuple
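To show how *args reaches the score function and how x and x_mean differ, here is a minimal standalone sketch. It assumes the step-size rule of Song et al. (2021); langevin_update and conditional_score are hypothetical names, and the toy score of N(y, σ²I) merely stands in for a trained score model conditioned on y.

```python
import torch


def langevin_update(x, t, *args, score_fn, snr, n_steps):
    """Hypothetical standalone sketch of one corrector call.

    Runs n_steps Langevin iterations; extra positional arguments (e.g. the
    observation y for OU-type SDEs) are forwarded unchanged to score_fn.
    Returns the final state and its noise-free counterpart from the last step.
    """
    x_mean = x
    batch = x.shape[0]
    for _ in range(n_steps):
        grad = score_fn(x, t, *args)
        noise = torch.randn_like(x)
        grad_norm = torch.linalg.vector_norm(grad.reshape(batch, -1), dim=-1).mean()
        noise_norm = torch.linalg.vector_norm(noise.reshape(batch, -1), dim=-1).mean()
        step_size = 2 * (snr * noise_norm / grad_norm) ** 2
        x_mean = x + step_size * grad  # deterministic part (returned as x_mean)
        x = x_mean + torch.sqrt(2 * step_size) * noise  # stochastic part
    return x, x_mean


def conditional_score(x, t, y, sigma=0.1):
    """Hypothetical score of N(y, sigma^2 I); a trained model would also use t."""
    return -(x - y) / sigma**2


torch.manual_seed(0)
y = torch.randn(2, 1, 32, 32)  # stand-in for the conditioning signal
x0 = y + 0.3 * torch.randn_like(y)  # initial state scattered around y
t = torch.full((2,), 0.5)
x, x_mean = langevin_update(x0, t, y, score_fn=conditional_score, snr=0.16, n_steps=300)
print(f"residual std: {(x - y).std().item():.3f}, noise-free: {(x_mean - y).std().item():.3f}")
```

Since x is x_mean plus freshly injected noise, x_mean stays closer to the target, which is why it is the natural output when a denoised estimate is wanted.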
####### Examples
>>> corrector = LangevinCorrector(sde, score_fn, snr, n_steps)
>>> next_state, denoised_state = corrector.update_fn(current_state, time_step)
NOTE
This method implements the update_fn declared by the Corrector base class; subclasses of LangevinCorrector may override it to provide different update logic.