espnet2.enh.diffusion.sampling.predictors.Predictor
class espnet2.enh.diffusion.sampling.predictors.Predictor(sde, score_fn, probability_flow=False)
Bases: ABC
The abstract class for a predictor algorithm.
This class is the base for predictor algorithms that can be implemented for different types of stochastic differential equations (SDEs). A predictor combines the SDE model with a score function to advance the current sample by one step of the reverse-time process.
- Attributes:
- sde – The stochastic differential equation model.
- rsde – The reverse-time SDE, constructed from sde and the score function.
- score_fn – The function used to compute the score.
- probability_flow – A boolean indicating whether the probability flow is used.
- Parameters:
- sde – The SDE model to be used.
- score_fn – The score function that will be used for the predictions.
- probability_flow – (Optional) A boolean indicating whether to use probability flow (default: False).
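The constructor wiring can be pictured with the minimal sketch below. It is not ESPnet's source: ToySDE, ToyReverseSDE, PredictorSketch, IdentityPredictor, and zero_score are hypothetical names, and the sketch assumes the SDE object exposes a reverse(score_fn) method whose result is stored in rsde. It also shows that, because update_fn is abstract, only concrete subclasses can be instantiated.

```python
import abc


class ToyReverseSDE:
    """Hypothetical reverse-time SDE: just remembers what it was built from."""

    def __init__(self, forward_sde, score_fn):
        self.forward_sde = forward_sde
        self.score_fn = score_fn


class ToySDE:
    """Hypothetical forward SDE exposing the assumed reverse(score_fn) hook."""

    def reverse(self, score_fn):
        return ToyReverseSDE(self, score_fn)


class PredictorSketch(abc.ABC):
    """Sketch of the constructor wiring (assumption, not ESPnet's source)."""

    def __init__(self, sde, score_fn, probability_flow=False):
        self.sde = sde
        self.rsde = sde.reverse(score_fn)  # reverse-time SDE derived from the score
        self.score_fn = score_fn
        self.probability_flow = probability_flow

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Return (x_next, x_next_mean)."""


class IdentityPredictor(PredictorSketch):
    """Trivial concrete subclass: the state is returned unchanged."""

    def update_fn(self, x, t, *args):
        return x, x


def zero_score(x, t, *args):
    """Placeholder score function that always returns 0."""
    return 0.0


sde = ToySDE()
predictor = IdentityPredictor(sde, zero_score)
print(type(predictor.rsde).__name__, predictor.probability_flow)  # ToyReverseSDE False
```

Calling PredictorSketch(sde, zero_score) directly raises TypeError, which is how abc enforces that every predictor supplies its own update_fn.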
update_fn(x, t, *args)
Abstract method that performs one predictor step from the current state x at time t and returns the next state together with its noise-free mean.
debug_update_fn(x, t, *args)
Raises a NotImplementedError unless a subclass overrides it.
- Raises: NotImplementedError – If debug_update_fn is called on a predictor that does not implement it.
######### Examples
>>> predictor = SomeConcretePredictor(sde, score_fn)
>>> next_state, mean_state = predictor.update_fn(current_state, time)
debug_update_fn(x, t, *args)
Debug update function for the Predictor class.
This method is intended for debugging and should be implemented in subclasses of Predictor that need it. The base implementation raises a NotImplementedError, indicating that no debug update function is defined for the given predictor.
- Parameters:
- x – A PyTorch tensor representing the current state.
- t – A PyTorch tensor representing the current time step.
- *args – Additional arguments used by specific implementations.
- Raises: NotImplementedError – If the predictor does not implement a debug update function.
######### Examples
Example usage (raises NotImplementedError if the predictor does not override this method):
>>> import torch
>>> predictor = SomePredictorClass(…)
>>> x, t = torch.randn(1, 3), torch.tensor(0.0)
>>> predictor.debug_update_fn(x, t)
NOTE
Subclasses should override this method to provide a meaningful implementation for debugging.
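A minimal sketch of the override pattern follows. PredictorSketch, ShrinkPredictor, and NoDebugPredictor are hypothetical classes, the exact error message is an assumption, and the extra "step" diagnostic returned by the debug variant is purely illustrative; plain floats stand in for PyTorch tensors.

```python
import abc


class PredictorSketch(abc.ABC):
    """Hypothetical stand-in for Predictor, reduced to the two methods."""

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Return (x_next, x_next_mean)."""

    def debug_update_fn(self, x, t, *args):
        # Default behavior: no debug variant is provided.
        raise NotImplementedError(
            f"Debug update function not implemented for predictor {self}."
        )


class ShrinkPredictor(PredictorSketch):
    """Toy predictor that halves the state and overrides the debug hook."""

    def update_fn(self, x, t, *args):
        x_mean = 0.5 * x
        return x_mean, x_mean  # deterministic toy: no noise term

    def debug_update_fn(self, x, t, *args):
        x_next, x_mean = self.update_fn(x, t, *args)
        # Hypothetical extra output: the size of the step just taken.
        return x_next, x_mean, {"step": x_next - x}


class NoDebugPredictor(PredictorSketch):
    """Toy predictor that keeps the default debug_update_fn."""

    def update_fn(self, x, t, *args):
        return x, x


print(ShrinkPredictor().debug_update_fn(4.0, 0.5))  # (2.0, 2.0, {'step': -2.0})
```

Calling NoDebugPredictor().debug_update_fn(1.0, 0.0) raises NotImplementedError, matching the documented default.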
abstract update_fn(x, t, *args)
Perform one update step of the predictor.
This method computes one update step of the predictor algorithm. Given the current state x and time step t, it returns the next state together with its mean, i.e., the next state without the injected random noise, which is useful for denoising.
- Parameters:
- x (torch.Tensor) – A PyTorch tensor representing the current state.
- t (torch.Tensor) – A PyTorch tensor representing the current time step.
- *args – Additional arguments; in particular, y for Ornstein-Uhlenbeck (OU) processes.
- Returns: A tuple (x, x_mean) containing:
- x (torch.Tensor): A PyTorch tensor of the next state.
- x_mean (torch.Tensor): A PyTorch tensor representing the next state without random noise, useful for denoising.
- Return type: tuple
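For intuition about the two return values, here is the Euler-Maruyama step on the reverse-time SDE, one common choice for a concrete predictor (not necessarily the scheme every subclass uses). Assumed notation: f(x, t) and g(t) are the forward SDE's drift and diffusion, s_θ(x, t) is the learned score, N is the number of reverse steps, and z is standard Gaussian noise. x_mean is the noise-free part of the step; x_next adds the stochastic term.

$$
x_{\text{mean}} = x + \left[f(x, t) - g(t)^2\, s_\theta(x, t)\right]\Delta t,
\qquad \Delta t = -\frac{1}{N}
$$

$$
x_{\text{next}} = x_{\text{mean}} + g(t)\sqrt{-\Delta t}\; z,
\qquad z \sim \mathcal{N}(0, I)
$$

Under the probability-flow formulation the drift becomes f(x, t) - ½ g(t)² s_θ(x, t) and the noise term is dropped, so x_next equals x_mean.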
######### Examples
Example usage, assuming predictor is an instance of a concrete Predictor subclass:
>>> import torch
>>> x_current = torch.tensor([[0.0], [1.0]])
>>> t_current = torch.tensor([0.1])
>>> next_state, next_state_mean = predictor.update_fn(x_current, t_current)
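To make the example above concrete, the sketch below implements update_fn as an Euler-Maruyama step on a toy scalar OU-style SDE, passing y through *args. Everything here is hypothetical (PredictorSketch, ToyOUSDE, EulerMaruyamaSketch), plain floats stand in for batched PyTorch tensors, and the way probability_flow is honored is an assumption; ESPnet's actual predictors may differ.

```python
import abc
import math
import random


class PredictorSketch(abc.ABC):
    """Hypothetical stand-in for Predictor, reduced to what this example needs."""

    def __init__(self, sde, score_fn, probability_flow=False):
        self.sde = sde
        self.score_fn = score_fn
        self.probability_flow = probability_flow

    @abc.abstractmethod
    def update_fn(self, x, t, *args):
        """Return (x_next, x_next_mean)."""


class ToyOUSDE:
    """Hypothetical scalar OU-style SDE: drift pulls x toward y."""

    def __init__(self, theta=1.5, sigma=0.5, N=30):
        self.theta = theta  # stiffness of the pull toward y
        self.sigma = sigma  # constant diffusion coefficient g(t)
        self.N = N          # number of reverse steps

    def sde(self, x, t, y):
        """Return the forward drift f(x, y, t) and diffusion g(t)."""
        return self.theta * (y - x), self.sigma


class EulerMaruyamaSketch(PredictorSketch):
    """Euler-Maruyama step on the reverse-time SDE; y is passed through *args."""

    def update_fn(self, x, t, *args):
        y = args[0]  # OU conditioning signal, passed positionally
        dt = -1.0 / self.sde.N  # negative step: time runs backward
        f, g = self.sde.sde(x, t, y)
        s = self.score_fn(x, t, y)
        if self.probability_flow:
            # Probability-flow ODE (assumed handling): half score term, no noise.
            x_mean = x + (f - 0.5 * g**2 * s) * dt
            return x_mean, x_mean
        x_mean = x + (f - g**2 * s) * dt  # noise-free part of the step
        z = random.gauss(0.0, 1.0)
        x_next = x_mean + g * math.sqrt(-dt) * z  # add the stochastic term
        return x_next, x_mean


predictor = EulerMaruyamaSketch(ToyOUSDE(), score_fn=lambda x, t, y: 0.0)
x_next, x_mean = predictor.update_fn(1.0, 1.0, 0.2)  # state, time, y
```

Returning x_mean alongside x_next lets a sampler output the noise-free estimate at the final step instead of a sample that still carries injected noise.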
NOTE
This method is abstract and must be implemented in subclasses of the Predictor class.