espnet2.enh.diffusion.sampling.predictors.ReverseDiffusionPredictor
class espnet2.enh.diffusion.sampling.predictors.ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
Bases: Predictor
Predictor for reverse diffusion sampling.
This class implements the reverse diffusion process using a score-based generative model. It performs sampling by reversing the diffusion process with a learned score function.
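For reference, the standard reverse-diffusion sampler for score-based SDE models can be written as follows. This is a sketch that assumes this class follows that formulation. Here $s_\theta$ stands for the learned `score_fn`, $\mathbf{f}_i$ and $G_i$ are the discretized drift and diffusion of the forward SDE, $\bar{\mathbf{w}}$ is a reverse-time Wiener process, and $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$.

$$
\mathrm{d}\mathbf{x} = \left[\mathbf{f}(\mathbf{x}, t) - g(t)^2 \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\right] \mathrm{d}t + g(t)\, \mathrm{d}\bar{\mathbf{w}}
$$

$$
\mathbf{x}_i = \mathbf{x}_{i+1} - \mathbf{f}_{i+1}(\mathbf{x}_{i+1}) + G_{i+1}^2\, s_\theta(\mathbf{x}_{i+1}, t_{i+1}) + G_{i+1}\, \mathbf{z}
$$

The noise-free part of the second line, $\mathbf{x}_{i+1} - \mathbf{f}_{i+1}(\mathbf{x}_{i+1}) + G_{i+1}^2\, s_\theta(\mathbf{x}_{i+1}, t_{i+1})$, presumably corresponds to the `x_mean` value returned alongside `x`.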
sde
The stochastic differential equation object used for sampling.
score_fn
The score function used to guide the reverse diffusion.
probability_flow
A boolean indicating whether to use the probability flow formulation. If True, sampling follows a deterministic (noise-free) approach.
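To illustrate what the `probability_flow` flag changes, the sketch below discretizes the reverse dynamics of a toy scalar SDE. It assumes the common score-SDE convention in which the probability-flow variant halves the score term and drops the noise term. The functions `toy_reverse_discretize` and `toy_score` and the toy forward SDE are hypothetical and are not part of ESPnet.

```python
import math


def toy_reverse_discretize(x, t, score_fn, dt=0.01, probability_flow=False):
    """Hypothetical reverse-time discretization for a toy scalar SDE.

    Assumed forward SDE: dx = -0.5 * x dt + 1.0 dw, discretized with step dt,
    so the forward drift increment is f = -0.5 * x * dt and G = sqrt(dt).
    """
    f = -0.5 * x * dt          # forward drift increment f_i(x)
    G = math.sqrt(dt)          # forward diffusion increment G_i
    score = score_fn(x, t)     # learned score s_theta(x, t) in the real model
    # Assumed convention: probability flow halves the score term ...
    weight = 0.5 if probability_flow else 1.0
    rev_f = f - (G ** 2) * score * weight
    # ... and removes the noise term, making the step deterministic.
    rev_G = 0.0 if probability_flow else G
    return rev_f, rev_G


def toy_score(x, t):
    # Score of a standard normal density: d/dx log N(x; 0, 1) = -x
    return -x


for flag in (False, True):
    rev_f, rev_G = toy_reverse_discretize(0.8, 0.5, toy_score, probability_flow=flag)
    print(f"probability_flow={flag}: rev_f={rev_f:.4f}, rev_G={rev_G:.4f}")
```

With `probability_flow=True` the returned noise scale is zero, so a predictor built on this discretization would produce the same state on every run.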
Parameters:
- sde – An instance of a stochastic differential equation.
- score_fn – A callable that estimates the score.
- probability_flow (bool) – A flag to indicate the use of probability flow (default is False).
Returns:
- x – A PyTorch tensor representing the next state in the diffusion process.
- x_mean – A PyTorch tensor representing the mean of the next state without added noise, useful for denoising.
Return type: tuple[torch.Tensor, torch.Tensor]
Examples
>>> predictor = ReverseDiffusionPredictor(sde, score_fn)
>>> x, x_mean = predictor.update_fn(current_state, current_time)
NOTE
This predictor assumes that the inputs x and t are properly formatted tensors; any additional positional arguments can be passed through *args.
- Raises: NotImplementedError – If the update function is called without implementing the necessary functionality.
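The following is a minimal scalar sketch of a reverse-diffusion predictor step. It assumes the reverse SDE object exposes a `discretize(x, t, *args)` method that returns a drift increment `f` and a noise scale `g`, as in common score-based SDE reference code; `ConstantReverseSDE` and `ReverseDiffusionPredictorSketch` are hypothetical stand-ins that use Python floats instead of PyTorch tensors.

```python
import random


class ConstantReverseSDE:
    """Hypothetical stand-in for the reverse SDE (``rsde``).

    discretize() returns a fixed drift increment f and noise scale g so the
    predictor arithmetic is easy to follow; extra *args are just recorded.
    """

    def __init__(self, f=0.1, g=0.2):
        self.f, self.g = f, g
        self.last_args = ()

    def discretize(self, x, t, *args):
        self.last_args = args  # record whatever extra inputs were forwarded
        return self.f, self.g


class ReverseDiffusionPredictorSketch:
    """Illustrative scalar version of a reverse-diffusion predictor step."""

    def __init__(self, rsde, rng=None):
        self.rsde = rsde
        self.rng = rng or random.Random(0)  # seeded for reproducibility

    def update_fn(self, x, t, *args):
        f, g = self.rsde.discretize(x, t, *args)  # assumed: reverse drift and noise scale
        z = self.rng.gauss(0.0, 1.0)              # standard normal noise sample
        x_mean = x - f                            # deterministic part of the step
        x_next = x_mean + g * z                   # add scaled noise
        return x_next, x_mean


rsde = ConstantReverseSDE()
predictor = ReverseDiffusionPredictorSketch(rsde)
x_next, x_mean = predictor.update_fn(1.0, 0.5, "conditioning")
print(x_next, x_mean, rsde.last_args)
```

Returning both values lets a caller keep feeding the noisy `x_next` into later steps while still having the noise-free `x_mean` available.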
update_fn(x, t, *args)
Perform one update step of the reverse diffusion predictor.
This method implements the update function for reverse diffusion using the specified stochastic differential equation (SDE) and score function. The reverse diffusion process aims to recover the original data from a noisy version by following the learned score function.
sde
A stochastic differential equation object.
rsde
A reverse stochastic differential equation object.
score_fn
A function that computes the score (gradient of log probability).
probability_flow
A boolean indicating whether to use probability flow in the update.
- Parameters:
- x – The current state tensor.
- t – The current time step.
- *args – Additional arguments passed through to the reverse SDE.
- Returns:
- x – A PyTorch tensor representing the next state after the update.
- x_mean – A PyTorch tensor representing the next state without random noise, useful for denoising.
- Return type: tuple[torch.Tensor, torch.Tensor]
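To show how the two return values are typically used, here is a hypothetical sampling loop built only from predictor steps: the noisy `x` is fed back at every step, and the noise-free `x_mean` from the last step can be returned when denoising. `run_predictor` and `stub_update_fn` are illustrative names, not ESPnet APIs, and real samplers (for example, predictor-corrector schemes) may be structured differently.

```python
def run_predictor(update_fn, x_init, timesteps, *args, denoise=True):
    """Hypothetical sampling loop built only from predictor steps.

    Each step feeds the noisy state x back in; if denoise is True the
    noise-free mean from the last step is returned instead of x.
    """
    x, x_mean = x_init, x_init
    for t in timesteps:
        x, x_mean = update_fn(x, t, *args)
    return x_mean if denoise else x


def stub_update_fn(x, t, *args):
    # Stand-in step: shrink toward 0 and add a fixed "noise" offset of 0.05.
    x_mean = 0.5 * x
    return x_mean + 0.05, x_mean


timesteps = [1.0, 0.5, 0.0]  # assumed: time runs from the end of the forward process back toward 0
print(run_predictor(stub_update_fn, 1.0, timesteps))
print(run_predictor(stub_update_fn, 1.0, timesteps, denoise=False))
```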
Examples
>>> predictor = ReverseDiffusionPredictor(sde, score_fn)
>>> x_next, x_mean = predictor.update_fn(x_current, t_current)
NOTE
This implementation assumes that the input tensor x is compatible with the dimensions expected by the SDE and score function.
- Raises: ValueError – If the input tensor x or time step t are not of the expected shape or type.
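A caller may want to validate shapes before calling `update_fn`. The sketch below does this on plain shape tuples. It assumes a 4-D layout for x, `(batch, channels, freq, frames)`, and one time value per batch item in t. Both layouts and the helper name `check_predictor_inputs` are hypothetical and chosen only for illustration.

```python
def check_predictor_inputs(x_shape, t_shape):
    """Hypothetical caller-side check before calling update_fn.

    Assumes x is a 4-D batch (batch, channels, freq, frames) and t holds one
    time value per batch item; both layouts are assumptions for illustration.
    """
    if len(x_shape) != 4:
        raise ValueError(f"expected 4-D x, got shape {tuple(x_shape)}")
    if tuple(t_shape) != (x_shape[0],):
        raise ValueError(f"expected t of shape ({x_shape[0]},), got {tuple(t_shape)}")


check_predictor_inputs((2, 1, 256, 128), (2,))  # passes silently
```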