espnet2.enh.diffusion.sampling.correctors.AnnealedLangevinDynamics
class espnet2.enh.diffusion.sampling.correctors.AnnealedLangevinDynamics(sde, score_fn, snr, n_steps)
Bases: Corrector
The original annealed Langevin dynamics sampler from NCSN/NCSNv2, implemented here as a corrector.
This class applies annealed Langevin dynamics as the corrector step of the diffusion sampling process. It accepts only the Ornstein-Uhlenbeck (OU) SDE class sdes.OUVESDE and uses the score function to refine the state estimate iteratively.
sde
The stochastic differential equation (SDE) used for the process.
score_fn
The function that estimates the score of the data.
snr
The target signal-to-noise ratio, which controls the size of each Langevin step.
n_steps
The number of steps to take in the Langevin update.
- Parameters:
- sde – An instance of a stochastic differential equation (SDE) class, specifically expected to be an OU process.
- score_fn – A callable that computes the score (gradient of the log probability) of the current state.
- snr – A float representing the desired signal-to-noise ratio.
- n_steps – An integer representing the number of Langevin steps to perform in the update.
- Raises: NotImplementedError – If the provided SDE is not an instance of sdes.OUVESDE.
- Returns:
- x: The updated state tensor after applying the Langevin dynamics.
- x_mean: The denoised state tensor (the mean state, without random noise), useful for further processing.
- Return type: A tuple of two PyTorch tensors
####### Examples
# Assuming sde, score_fn, initial_state, and time_step are defined
ald_corrector = AnnealedLangevinDynamics(sde, score_fn, snr=1.0, n_steps=10)
updated_state, denoised_state = ald_corrector.update_fn(initial_state, time_step)
NOTE
The algorithm relies on the SDE’s ability to compute marginal probabilities and assumes that the score function is properly defined for the given state and time step.
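As a rough illustration of what "properly defined" means for a score function, the sketch below compares a closed-form score with the gradient that autograd computes from the log-density. It uses a plain isotropic Gaussian rather than ESPnet's SDE classes; `gaussian_log_prob` and `analytic_score` are hypothetical helper names introduced only for this example.

```python
import math

import torch


def gaussian_log_prob(x, mean, std):
    """Log-density of an isotropic Gaussian N(mean, std^2 I), summed over elements."""
    var = std**2
    return (-0.5 * (x - mean) ** 2 / var - 0.5 * math.log(2 * math.pi * var)).sum()


def analytic_score(x, mean, std):
    """Closed-form score of the same Gaussian: grad_x log p(x) = -(x - mean) / std^2."""
    return -(x - mean) / std**2


torch.manual_seed(0)
mean = torch.randn(2, 3)
std = 0.3  # toy noise level, chosen arbitrarily
x = torch.randn(2, 3, requires_grad=True)

# Gradient of the log-density with respect to x, computed by autograd.
(autograd_score,) = torch.autograd.grad(gaussian_log_prob(x, mean, std), x)
closed_form = analytic_score(x.detach(), mean, std)

# The two agree up to floating-point error and have the same shape as x.
print(torch.allclose(autograd_score, closed_form, atol=1e-5))
```

A learned score_fn is expected to approximate this kind of gradient for the perturbed data distribution at time t, returning a tensor shaped like its input.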
update_fn(x, t, *args)
One update of the corrector.
This method refines the current state x at time step t by running n_steps Langevin iterations. Each iteration moves x along the direction given by the score function and then adds Gaussian noise, with the step size governed by the signal-to-noise ratio (snr).
- Parameters:
- x – A PyTorch tensor representing the current state.
- t – A PyTorch tensor representing the current time step.
- *args – Possibly additional arguments, in particular y for OU processes.
- Returns:
- x: A PyTorch tensor of the next state after the update.
- x_mean: A PyTorch tensor representing the next state without random noise, useful for denoising.
- Return type: A tuple of two PyTorch tensors, (x, x_mean)
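The update described above can be sketched as a standalone function. This is a minimal sketch, not the ESPnet implementation: it assumes the NCSN-style step-size rule `step_size = 2 * (snr * std)^2`, and `annealed_langevin_step`, `toy_std`, and `toy_score` are hypothetical names standing in for the corrector, the SDE's marginal standard deviation, and the score model.

```python
import torch


def annealed_langevin_step(x, t, y, score_fn, marginal_std, snr, n_steps):
    """Sketch of an annealed Langevin corrector update (assumed NCSN-style rule)."""
    std = marginal_std(t)  # per-sample noise level, shape (batch,)
    # Assumption: step size follows step_size = 2 * (snr * std)^2.
    step_size = 2.0 * (snr * std) ** 2
    # Reshape per-sample scalars so they broadcast over the other dims of x.
    shape = (-1,) + (1,) * (x.dim() - 1)
    x_mean = x
    for _ in range(n_steps):
        grad = score_fn(x, t, y)
        noise = torch.randn_like(x)  # standard normal noise
        x_mean = x + step_size.view(shape) * grad  # noise-free move along the score
        x = x_mean + torch.sqrt(2.0 * step_size).view(shape) * noise
    return x, x_mean


# Toy setup (hypothetical): perturbation kernel N(y, toy_std(t)^2).
def toy_std(t):
    return 0.1 + 0.9 * t


def toy_score(x, t, y):
    std = toy_std(t).view(-1, 1, 1, 1)
    return -(x - y) / std**2


torch.manual_seed(0)
y = torch.zeros(2, 1, 4, 4)
t = torch.full((2,), 0.5)
x = y + toy_std(t).view(-1, 1, 1, 1) * torch.randn_like(y)
x_next, x_mean_next = annealed_langevin_step(
    x, t, y, toy_score, toy_std, snr=0.5, n_steps=3
)
print(x_next.shape, x_mean_next.shape)
```

Under these toy assumptions, x_mean is the state after the last score step and x is the same state with fresh noise added, which mirrors the (x, x_mean) pair documented above.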
####### Examples
Example usage:
# Assuming score_fn is defined and returns a tensor of gradients for the
# current state x at time t, sde is an instance of the SDE class, and
# corrector is an AnnealedLangevinDynamics instance built from them.
x_next, x_mean_next = corrector.update_fn(x_current, t_current)
NOTE
The update step is influenced by the score_fn, which is expected to compute the gradient of the log probability of the data. The noise added during the update is sampled from a standard normal distribution.
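To give a feel for how the standard normal noise might be scaled at different noise levels, the arithmetic below assumes the NCSN-style rule `step_size = 2 * (snr * std)^2`, under which the noise is multiplied by `sqrt(2 * step_size)`. The rule and the helper name `langevin_scales` are assumptions made for illustration, not taken from the ESPnet source.

```python
import math


def langevin_scales(snr, std):
    """Return (step_size, noise_scale) under the assumed rule step_size = 2 * (snr * std)^2."""
    step_size = 2.0 * (snr * std) ** 2
    noise_scale = math.sqrt(2.0 * step_size)  # algebraically equal to 2 * snr * std
    return step_size, noise_scale


# Smaller marginal std (later in reverse sampling) gives smaller steps and less noise.
for std in (1.0, 0.5, 0.1):
    step_size, noise_scale = langevin_scales(snr=0.5, std=std)
    print(f"std={std:.1f}  step_size={step_size:.4f}  noise_scale={noise_scale:.2f}")
```

Under this rule the injected noise shrinks in proportion to the marginal standard deviation, which is what makes the dynamics "annealed".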
- Raises: ValueError – If the input tensors x and t are not of the expected shape or type.