espnet2.enh.diffusion.sampling.predictors.EulerMaruyamaPredictor
class espnet2.enh.diffusion.sampling.predictors.EulerMaruyamaPredictor(sde, score_fn, probability_flow=False)
Bases: Predictor
Euler-Maruyama predictor for stochastic differential equations.
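This class implements the Euler-Maruyama method, the stochastic analogue of the explicit Euler method for simulating stochastic differential equations (SDEs). Each step advances the current state by the drift multiplied by the step size, plus the diffusion coefficient multiplied by a Gaussian increment whose standard deviation is the square root of the step size.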
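To illustrate the scheme on its own, independent of ESPnet, here is a minimal sketch that simulates a scalar SDE dx = f(x, t) dt + g(t) dW with plain Python floats. The function name `euler_maruyama` and its arguments are hypothetical; ESPnet works on PyTorch tensors and integrates the reverse-time SDE rather than the forward one.

```python
import math
import random
from typing import Callable


def euler_maruyama(
    f: Callable[[float, float], float],
    g: Callable[[float], float],
    x0: float,
    t0: float,
    t1: float,
    n_steps: int,
    rng: random.Random,
) -> float:
    """Simulate dx = f(x, t) dt + g(t) dW from t0 to t1 (t1 > t0) in n_steps equal steps."""
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        z = rng.gauss(0.0, 1.0)  # standard normal draw
        # Drift increment f*dt plus Brownian increment with standard deviation g*sqrt(dt).
        x = x + f(x, t) * dt + g(t) * math.sqrt(dt) * z
        t += dt
    return x
```

For example, `euler_maruyama(lambda x, t: -x, lambda t: 0.5, 1.0, 0.0, 1.0, 100, random.Random(0))` simulates an Ornstein-Uhlenbeck process; setting `g` to zero recovers the ordinary explicit Euler method.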
- Parameters:
- sde – The stochastic differential equation to be solved.
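- score_fn – The score function, i.e. a function returning an estimate of the gradient of the log-density of the diffused data, ∇ₓ log pₜ(x).
- probability_flow (bool) – If True, integrate the deterministic probability flow ODE instead of the reverse SDE; defaults to False.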
- Attributes:
- sde – The stochastic differential equation.
- rsde – The reverse-time SDE corresponding to the forward SDE.
- score_fn – The score function.
- probability_flow (bool) – Whether the probability flow ODE is used instead of the reverse SDE.
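Assuming `rsde` follows the standard score-based construction, with s_θ(x, t) ≈ ∇ₓ log pₜ(x) denoting the output of `score_fn`, the forward SDE, the reverse SDE, and the probability flow ODE selected by `probability_flow=True` relate as sketched below. This is the general formulation only; ESPnet's concrete SDE classes define their own f and g (for example, OU processes whose drift depends on y).

```latex
\begin{aligned}
\mathrm{d}x &= f(x, t)\,\mathrm{d}t + g(t)\,\mathrm{d}w
  && \text{(forward SDE)} \\
\mathrm{d}x &= \left[f(x, t) - g(t)^2\, s_\theta(x, t)\right]\mathrm{d}t + g(t)\,\mathrm{d}\bar{w}
  && \text{(reverse SDE, time running backward)} \\
\mathrm{d}x &= \left[f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)\right]\mathrm{d}t
  && \text{(probability flow ODE, no noise term)}
\end{aligned}
```

Here \bar{w} is a reverse-time Wiener process. With the probability flow ODE the diffusion term vanishes, so an Euler-Maruyama step reduces to a deterministic Euler step.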
- Returns: A tuple (x, x_mean) produced by update_fn:
  - x: A PyTorch tensor of the next state.
  - x_mean: A PyTorch tensor representing the next state without random noise, useful for denoising.
- Return type: tuple
####### Examples
>>> predictor = EulerMaruyamaPredictor(sde, score_fn)
>>> x_next, x_mean = predictor.update_fn(x_current, t_current)
NOTE
This predictor is typically used in the context of sampling from a diffusion process.
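A typical sampling loop applies the predictor repeatedly while stepping time backward, then returns the noise-free mean from the last step. The sketch below assumes a scalar state, a hypothetical `rsde(x, t)` callable that returns the reverse drift and diffusion, and a uniform time grid tᵢ = 1 - i/N matching a step size dt = -1/N; ESPnet's actual samplers operate on tensors and may use a different time grid.

```python
import math
import random
from typing import Callable, Tuple

# Hypothetical reverse SDE: maps (x, t) to (reverse drift f, diffusion g).
ReverseSDE = Callable[[float, float], Tuple[float, float]]


def sample_reverse(
    rsde: ReverseSDE, x_init: float, n_steps: int, rng: random.Random
) -> float:
    """Run n_steps reverse-time Euler-Maruyama updates from t = 1 toward t = 0."""
    dt = -1.0 / n_steps
    x = x_mean = x_init
    for i in range(n_steps):
        t = 1.0 - i / n_steps  # uniform grid: 1, 1 - 1/N, ..., 1/N
        f, g = rsde(x, t)
        x_mean = x + f * dt  # noise-free predictor step
        x = x_mean + g * math.sqrt(-dt) * rng.gauss(0.0, 1.0)  # add Gaussian noise
    # Return the final noise-free mean as the denoised estimate.
    return x_mean
```

Returning `x_mean` rather than `x` after the last step avoids leaving one final injection of noise in the output.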
- Raises: NotImplementedError – If the update function is not implemented.
update_fn(x, t, *args)
One update of the predictor.
This method computes one update step for the predictor algorithm. It takes the current state and time step as input and returns the updated state along with the mean state, which is useful for denoising.
- Parameters:
- x (torch.Tensor) – A PyTorch tensor representing the current state.
- t (torch.Tensor) – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments, in particular y for OU processes.
- Returns: A tuple containing:
  - x (torch.Tensor): A PyTorch tensor of the next state.
  - x_mean (torch.Tensor): A PyTorch tensor representing the next state without random noise, useful for denoising.
- Return type: tuple
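The relationship between the two returned values can be sketched for a single step with Python floats, assuming the reverse SDE exposes its number of discretization steps N and returns its drift f and diffusion g at (x, t). The names `euler_maruyama_update` and `ReverseSDE` are hypothetical; this mirrors a standard reverse-time Euler-Maruyama step rather than reproducing ESPnet's code verbatim.

```python
import math
import random
from typing import Callable, Tuple

# Hypothetical stand-in for the reverse SDE: returns (drift, diffusion) at (x, t).
ReverseSDE = Callable[[float, float], Tuple[float, float]]


def euler_maruyama_update(
    rsde: ReverseSDE, x: float, t: float, n_steps: int, rng: random.Random
) -> Tuple[float, float]:
    """One reverse-time Euler-Maruyama step; returns (x_next, x_mean)."""
    dt = -1.0 / n_steps  # negative: time runs from 1 toward 0
    z = rng.gauss(0.0, 1.0)  # standard normal noise
    f, g = rsde(x, t)  # reverse drift and diffusion at the current state
    x_mean = x + f * dt  # deterministic (noise-free) part of the step
    x_next = x_mean + g * math.sqrt(-dt) * z  # Brownian increment, std g*sqrt(|dt|)
    return x_next, x_mean
```

When the diffusion g is zero (as with the probability flow ODE), `x_next` and `x_mean` coincide.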
####### Examples
>>> x_current = torch.tensor([[0.0], [1.0]])
>>> t_current = torch.tensor([0.1])
>>> next_state, next_state_mean = predictor.update_fn(x_current, t_current)
NOTE
The corresponding method in the base Predictor class is abstract; EulerMaruyamaPredictor overrides it with the Euler-Maruyama update described above.