espnet2.enh.diffusion.sampling.predictors.NonePredictor
class espnet2.enh.diffusion.sampling.predictors.NonePredictor(*args, **kwargs)
Bases: Predictor
An empty predictor that does nothing.
This class serves as a placeholder for situations where no prediction step is required. It inherits from the Predictor abstract base class and implements update_fn to return the input state x unchanged, as both elements of the returned tuple.
- Parameters:
- *args – Variable length argument list (accepted and ignored).
- **kwargs – Arbitrary keyword arguments (accepted and ignored).
- Returns (from update_fn):
- x: The next state, which is the input tensor returned unchanged.
- x: The next-state mean, which is also the input tensor returned unchanged.
- Return type: A tuple of two tensors, (x, x)
####### Examples
>>> import torch
>>> from espnet2.enh.diffusion.sampling.predictors import NonePredictor
>>> predictor = NonePredictor()
>>> x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
>>> t = torch.tensor(0)
>>> next_state, _ = predictor.update_fn(x, t)
>>> print(next_state)
tensor([[1., 2.],
        [3., 4.]])
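To make the contract concrete without depending on ESPnet or PyTorch, here is a minimal, self-contained sketch of the predictor interface as described on this page. The `Predictor` and `NonePredictor` classes below are simplified stand-ins written for illustration, not imports from ESPnet; any constructor arguments of the real base class are not modeled. Only the `update_fn(x, t, *args)` signature and the "return the input as both outputs" behavior come from this page.

```python
import abc
from typing import Any, Tuple


class Predictor(abc.ABC):
    """Simplified stand-in for the predictor base class (illustration only)."""

    @abc.abstractmethod
    def update_fn(self, x: Any, t: Any, *args: Any) -> Tuple[Any, Any]:
        """Return (next_state, next_state_mean) for one sampling step."""


class NonePredictor(Predictor):
    """Illustrative placeholder predictor: ignores all constructor arguments."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass  # nothing to configure

    def update_fn(self, x: Any, t: Any, *args: Any) -> Tuple[Any, Any]:
        # No prediction step: the state is returned unchanged as both the
        # next state and its noise-free mean.
        return x, x


predictor = NonePredictor("ignored", sde=None)
state = [[1.0, 2.0], [3.0, 4.0]]
next_state, next_mean = predictor.update_fn(state, 0, "extra y argument")
print(next_state is state, next_mean is state)  # True True
```

Because `update_fn` is abstract in the stand-in base class, instantiating `Predictor` directly raises `TypeError`, while the concrete `NonePredictor` can be created with any arguments.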
update_fn(x, t, *args)
One update of the predictor.
For a general predictor, this method computes the next state from the current state and time step, and the exact update depends on the predictor class. NonePredictor performs no update: it returns the input x as both the next state and its mean.
- Parameters:
- x – A PyTorch tensor representing the current state.
- t – A PyTorch tensor representing the current time step.
- *args – Optional additional arguments, in particular y for OU (Ornstein-Uhlenbeck) processes. NonePredictor ignores them.
- Returns:
- x: A PyTorch tensor of the next state.
- x_mean: A PyTorch tensor representing the next state without random noise, useful for denoising. For NonePredictor, both x and x_mean are the input x, unchanged.
- Return type: A tuple of two tensors, (x, x_mean)
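The difference between `x` and `x_mean` is easiest to see in a stochastic predictor. The sketch below is a simplified, scalar Euler-Maruyama-style reverse step written for illustration; it is not ESPnet's `EulerMaruyamaPredictor`. The `drift` and `diffusion` callables, the negative step size `dt` (stepping backward in time), and the function name `toy_euler_maruyama_step` are all hypothetical. `x_mean` is the deterministic part of the update, and `x` adds scaled Gaussian noise on top of it; `NonePredictor` skips both and returns its input in each position.

```python
import math
import random
from typing import Callable, Tuple


def toy_euler_maruyama_step(
    x: float,
    t: float,
    drift: Callable[[float, float], float],  # hypothetical f(x, t)
    diffusion: Callable[[float], float],  # hypothetical g(t)
    dt: float,  # negative for a reverse-time step (assumption)
    rng: random.Random,
) -> Tuple[float, float]:
    """Illustrative scalar step (not ESPnet's EulerMaruyamaPredictor).

    Returns (x_next, x_mean).
    """
    # Deterministic part of the update: the noise-free "mean".
    x_mean = x + drift(x, t) * dt
    # Stochastic part: Gaussian noise scaled by g(t) * sqrt(|dt|).
    z = rng.gauss(0.0, 1.0)
    x_next = x_mean + diffusion(t) * math.sqrt(abs(dt)) * z
    return x_next, x_mean


rng = random.Random(0)
x_next, x_mean = toy_euler_maruyama_step(
    x=1.0, t=0.5, drift=lambda x, t: -x, diffusion=lambda t: 0.1, dt=-0.01, rng=rng
)
# x_mean = 1.0 + (-1.0) * (-0.01); the random draw only affects x_next.
print(x_mean)  # 1.01
```

Returning `x_mean` at the final step yields a sample without the last injection of noise, which is why the docstring calls it useful for denoising.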
####### Examples
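>>> import torch
>>> from espnet2.enh.diffusion.sampling.predictors import NonePredictor
>>> predictor = NonePredictor()
>>> x_current = torch.tensor([[0.0, 0.0]])
>>> t_current = torch.tensor([1.0])
>>> next_state, next_mean = predictor.update_fn(x_current, t_current)
>>> torch.equal(next_state, x_current)
True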
NOTE
The actual behavior of this method will depend on the specific implementation of the predictor class (e.g., EulerMaruyamaPredictor or ReverseDiffusionPredictor).
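In a predictor-corrector sampler, a predictor and a corrector are applied in turn at each time step, and the final `x_mean` can be returned as the denoised output. The loop below is a hypothetical, framework-free sketch of that pattern: the `pc_sample` function, the `HalvingCorrector` class, the corrector-then-predictor order, and the uniform time grid are illustrative assumptions, not ESPnet APIs. The local `NonePredictor` mirrors the documented behavior. Plugging it in turns the loop into a corrector-only sampler, one situation in which no prediction step is required.

```python
from typing import Any, List, Tuple


class NonePredictor:
    """Stand-in mirroring the documented behavior: return the input unchanged."""

    def update_fn(self, x: Any, t: Any, *args: Any) -> Tuple[Any, Any]:
        return x, x


class HalvingCorrector:
    """Hypothetical corrector: moves the state halfway toward a target y."""

    def update_fn(self, x: float, t: float, y: float) -> Tuple[float, float]:
        x_mean = x + 0.5 * (y - x)
        return x_mean, x_mean  # deterministic, so x equals x_mean


def pc_sample(predictor, corrector, x: float, y: float, num_steps: int) -> float:
    """Hypothetical predictor-corrector loop over a uniform reverse time grid."""
    # Time runs from 1.0 down toward 0.0 in num_steps equal steps.
    timesteps: List[float] = [1.0 - i / num_steps for i in range(num_steps)]
    x_mean = x
    for t in timesteps:
        x, x_mean = corrector.update_fn(x, t, y)  # corrector step
        x, x_mean = predictor.update_fn(x, t, y)  # predictor step (identity here)
    return x_mean  # noise-free estimate after the last step


result = pc_sample(NonePredictor(), HalvingCorrector(), x=0.0, y=8.0, num_steps=3)
print(result)  # 7.0
```

Because the predictor is the identity here, the trajectory 0.0, 4.0, 6.0, 7.0 is determined entirely by the corrector; a stochastic predictor would change both the trajectory and the returned `x_mean`.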