espnet2.enh.diffusion.sdes.OUVESDE
class espnet2.enh.diffusion.sdes.OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000, **ignored_kwargs)
Bases: SDE
Construct an Ornstein-Uhlenbeck Variance Exploding Stochastic Differential Equation (SDE).
This SDE is characterized by the following dynamics:
dx = theta * (y - x) dt + sigma(t) dw
where:
sigma(t) = sigma_min * (sigma_max/sigma_min)^t * sqrt(2 * log(sigma_max/sigma_min))
The “steady-state mean” y is not provided at construction; it must be passed as an argument to the methods that require it (e.g., sde or marginal_prob).
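The dynamics can be made concrete with a small scalar sketch. It is not the ESPnet implementation: the helper names `ouve_drift` and `ouve_sigma` are hypothetical, the drift is written in the mean-reverting form theta * (y - x) (so x relaxes towards y), and plain Python floats stand in for the batched tensors the class operates on.

```python
import math


def ouve_sigma(t, sigma_min=0.05, sigma_max=0.5):
    """Hypothetical helper: diffusion coefficient sigma(t) for a scalar time t."""
    logsig = math.log(sigma_max / sigma_min)
    # Geometric schedule from sigma_min (t = 0) to sigma_max (t = 1),
    # multiplied by sqrt(2 * log(sigma_max / sigma_min)).
    return sigma_min * (sigma_max / sigma_min) ** t * math.sqrt(2.0 * logsig)


def ouve_drift(x, y, theta=1.5):
    """Hypothetical helper: mean-reverting drift theta * (y - x), pushing x towards y."""
    return theta * (y - x)


print(ouve_drift(0.0, 1.0))  # 1.5
print(ouve_sigma(0.0), ouve_sigma(1.0))  # roughly 0.1073 and 1.073
```

With the defaults, sigma(1) / sigma(0) = sigma_max / sigma_min = 10; this geometric growth of the noise level is roughly what the "variance exploding" label refers to.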
theta
Stiffness parameter.
- Type: float
sigma_min
Minimum value for sigma.
- Type: float
sigma_max
Maximum value for sigma.
- Type: float
N
Number of discretization steps.
- Type: int
logsig
Logarithm of the ratio between sigma_max and sigma_min.
- Type: float
Parameters:
- theta (float) – Stiffness parameter.
- sigma_min (float) – Smallest sigma.
- sigma_max (float) – Largest sigma.
- N (int) – Number of discretization steps.
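N fixes the number of discretization steps; over the time interval [0, T] a natural step size is dt = T / N. The NumPy sketch below is an illustration only, not ESPnet's code path: it simulates independent scalar paths of the forward SDE with the Euler-Maruyama scheme, assuming the mean-reverting drift theta * (y - x). The function `simulate_ouve_em` and its arguments are hypothetical.

```python
import numpy as np


def simulate_ouve_em(x0, y, theta=1.5, sigma_min=0.05, sigma_max=0.5,
                     N=1000, T=1.0, n_paths=10_000, seed=0):
    """Hypothetical helper: Euler-Maruyama for dx = theta (y - x) dt + sigma(t) dw.

    Returns the states of `n_paths` independent scalar paths at time T.
    """
    rng = np.random.default_rng(seed)
    logsig = np.log(sigma_max / sigma_min)
    dt = T / N  # N discretization steps over [0, T]
    x = np.full(n_paths, float(x0))
    for i in range(N):
        t = i * dt  # left endpoint of the current step
        sigma = sigma_min * (sigma_max / sigma_min) ** t * np.sqrt(2.0 * logsig)
        dw = np.sqrt(dt) * rng.standard_normal(n_paths)  # Brownian increments
        x = x + theta * (y - x) * dt + sigma * dw
    return x


x_T = simulate_ouve_em(x0=0.0, y=1.0)
print(x_T.mean(), x_T.std())  # Monte Carlo estimates, expected near 0.777 and 0.389
```

Larger N reduces the discretization error of the scheme at the cost of more steps.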
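Example
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE
>>> ouvesde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
>>> x = torch.zeros(1, 1, 4, 4)  # batched 4-D state, batch dimension first
>>> y = torch.ones(1, 1, 4, 4)   # steady-state mean, same shape as x
>>> t = torch.tensor([0.5])      # one time per batch element
>>> drift, diffusion = ouvesde.sde(x, t, y)
>>> mean, std = ouvesde.marginal_prob(x, t, y)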
- Raises: NotImplementedError – If prior_logp is called, since it is not implemented for this SDE.
property T
End time of the SDE (T = 1 for OUVESDE).
Module note (espnet2.enh.diffusion.sdes): abstract SDE classes, Reverse SDE, and VE/VP SDEs, taken and adapted from https://github.com/yang-song/score_sde_pytorch and https://github.com/sp-uhh/sgmse.
The module provides abstract classes and implementations of stochastic differential equations (SDEs) for score-based generative models, including Ornstein-Uhlenbeck processes with variance exploding and variance preserving properties.
Classes:
- SDE: Abstract base class for stochastic differential equations, designed for mini-batch processing.
- OUVESDE: Ornstein-Uhlenbeck Variance Exploding SDE.
- OUVPSDE: Ornstein-Uhlenbeck Variance Preserving SDE.
Usage: The classes can serve as base classes for specific SDE implementations that define the drift and diffusion functions. Instantiate them and call their methods to simulate SDE behavior or to compute marginal probabilities.
copy()
Return a new OUVESDE instance with the same theta, sigma_min, sigma_max, and N.
marginal_prob(x0, t, y)
Compute the marginal (perturbation kernel) distribution of the SDE at time t.
Given the initial state x0 and the steady-state mean y, the marginal distribution of x(t) is Gaussian with
mean(t) = exp(-theta t) x0 + (1 - exp(-theta t)) y
std(t) = sqrt(sigma_min^2 * logsig * (exp(2 logsig t) - exp(-2 theta t)) / (theta + logsig))
where logsig = log(sigma_max/sigma_min). The mean moves from x0 at t = 0 towards y as t grows.
- Parameters:
- x0 – Initial state, a batched 4-D tensor (batch dimension first).
- t – Times at which to evaluate the marginal distribution, a 1-D tensor with one entry per batch element.
- y – Steady-state mean, a tensor with the same shape as x0.
- Returns: A tuple (mean, std) containing
- mean (tensor): The mean of the marginal distribution, with the same shape as x0.
- std (tensor): The standard deviation of the marginal distribution, one value per batch element.
- Return type: tuple
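For reference, a derivation of the marginal statistics, assuming the mean-reverting form dx = theta (y - x) dt + sigma(t) dw with sigma(t) as in the class description, and writing λ = log(sigma_max/sigma_min) (the logsig attribute). Because the SDE is linear with a time-dependent but state-independent diffusion, x(t) given x0 is Gaussian, and its variance is a deterministic integral:

```latex
\begin{aligned}
x(t) &= e^{-\theta t} x_0 + \left(1 - e^{-\theta t}\right) y
        + \int_0^t e^{-\theta (t-s)}\, \sigma(s)\, dw(s), \\
\operatorname{Var}[x(t)] &= \int_0^t e^{-2\theta (t-s)}\, \sigma(s)^2\, ds
  = 2\lambda\, \sigma_{\min}^2\, e^{-2\theta t} \int_0^t e^{2(\theta+\lambda) s}\, ds \\
&= \frac{\sigma_{\min}^2\, \lambda \left(e^{2\lambda t} - e^{-2\theta t}\right)}{\theta + \lambda},
\qquad \sigma(s)^2 = 2\lambda\, \sigma_{\min}^2\, e^{2\lambda s},
\quad \lambda = \log(\sigma_{\max}/\sigma_{\min}).
\end{aligned}
```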
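Example
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE
>>> sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5)
>>> x0 = torch.zeros(2, 1, 4, 4)  # batch of 2 initial states
>>> y = torch.ones(2, 1, 4, 4)
>>> t = torch.full((2,), 0.5)  # one time per batch element
>>> mean, std = sde.marginal_prob(x0=x0, t=t, y=y)
>>> print(mean.shape)
torch.Size([2, 1, 4, 4])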
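The marginal mean and standard deviation can be written in closed form and cross-checked numerically. This pure-Python sketch uses hypothetical helpers (`ouve_marginal`, `ouve_variance_numeric`) on scalars rather than the batched tensors marginal_prob expects, and assumes the drift theta * (y - x). It compares the closed-form variance with a midpoint-rule evaluation of the integral of exp(-2 theta (t - s)) sigma(s)^2 over [0, t].

```python
import math


def ouve_marginal(x0, y, t, theta=1.5, sigma_min=0.05, sigma_max=0.5):
    """Hypothetical helper: closed-form mean and std of x(t) given x(0) = x0."""
    logsig = math.log(sigma_max / sigma_min)
    decay = math.exp(-theta * t)
    mean = decay * x0 + (1.0 - decay) * y  # interpolates from x0 towards y
    var = (sigma_min**2 * logsig
           * (math.exp(2.0 * logsig * t) - math.exp(-2.0 * theta * t))
           / (theta + logsig))
    return mean, math.sqrt(var)


def ouve_variance_numeric(t, theta=1.5, sigma_min=0.05, sigma_max=0.5, steps=10_000):
    """Hypothetical helper: midpoint rule for int_0^t exp(-2 theta (t - s)) sigma(s)^2 ds."""
    logsig = math.log(sigma_max / sigma_min)
    h = t / steps
    total = 0.0
    for k in range(steps):
        s = (k + 0.5) * h
        sigma_s = sigma_min * (sigma_max / sigma_min) ** s * math.sqrt(2.0 * logsig)
        total += math.exp(-2.0 * theta * (t - s)) * sigma_s**2 * h
    return total


mean, std = ouve_marginal(x0=0.0, y=1.0, t=0.5)
print(mean, std**2, ouve_variance_numeric(0.5))  # the last two should agree closely
```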
prior_logp(z)
Compute log-density of the prior distribution.
This method is useful for computing the log-likelihood via the probability flow ODE.
- Parameters: z – Latent code for which the log-density is to be computed.
- Returns: Log probability density corresponding to the input latent code.
- Raises: NotImplementedError – This method is not yet implemented for the Ornstein-Uhlenbeck Variance Exploding SDE.
Example
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE
>>> sde = OUVESDE()
>>> z = torch.tensor([0.5, 0.2])
>>> log_density = sde.prior_logp(z)
Traceback (most recent call last):
...
NotImplementedError: prior_logp for OU SDE not yet implemented!
prior_sampling(shape, y)
Generate a sample from the prior distribution $p_T(x \mid y)$.
The sample is obtained by adding Gaussian noise to y, scaled by the marginal standard deviation at time T: x_T = y + std(T) * z with z ~ N(0, I). One sample is drawn per batch element.
- Parameters:
- shape – Desired shape of the output sample. If it does not match the shape of y, a warning is issued and the shape of y is used instead.
- y – The steady-state mean around which the sample is generated, a 4-D tensor with the batch dimension first.
- Returns: A tensor with the same shape as y containing a sample from the prior distribution.
- Warns: UserWarning – If shape does not match the shape of y (the target shape is then ignored).
Example
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE
>>> ouvesde = OUVESDE()
>>> y = torch.zeros((10, 3, 32, 32))
>>> sample = ouvesde.prior_sampling((10, 3, 32, 32), y)
>>> print(sample.shape)
torch.Size([10, 3, 32, 32])
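A NumPy sketch of the same sampling rule, under stated assumptions: `ouve_prior_sample` is a hypothetical helper, not part of ESPnet; it uses a single scalar end time T = 1 for every entry (the class instead evaluates one standard deviation per batch element), and it mimics the documented shape-mismatch behavior with `warnings.warn`.

```python
import warnings

import numpy as np


def ouve_prior_sample(shape, y, theta=1.5, sigma_min=0.05, sigma_max=0.5, T=1.0, seed=None):
    """Hypothetical helper: draw x_T = y + std(T) * z with z ~ N(0, I)."""
    y = np.asarray(y, dtype=float)
    if tuple(shape) != y.shape:
        # Mirror the documented behavior: warn and use the shape of y instead.
        warnings.warn(f"target shape {tuple(shape)} does not match y.shape {y.shape}; using y.shape")
    logsig = np.log(sigma_max / sigma_min)
    # Marginal variance of the OU-VE SDE at time T (identical for all entries here).
    var_T = (sigma_min**2 * logsig
             * (np.exp(2.0 * logsig * T) - np.exp(-2.0 * theta * T))
             / (theta + logsig))
    rng = np.random.default_rng(seed)
    return y + np.sqrt(var_T) * rng.standard_normal(y.shape)


sample = ouve_prior_sample((10, 3, 32, 32), np.zeros((10, 3, 32, 32)), seed=0)
print(sample.shape)  # (10, 3, 32, 32)
```

Note that the prior is centered on y because the initial state x0 is unknown at sampling time. The exact marginal mean at T = 1 still retains a fraction exp(-theta) of x0 (about 0.22 for theta = 1.5), so sampling from this prior is an approximation of the true marginal.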
sde(x, t, y)
Compute the drift and diffusion of the forward SDE at time t.
- Parameters:
- x – Current state (tensor).
- t – Time, a 1-D tensor with one entry per batch element.
- y – Steady-state mean, a tensor with the same shape as x.
- Returns: A tuple (drift, diffusion): the drift term of the SDE evaluated at (x, y) and the diffusion coefficient sigma(t).
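Example
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE
>>> sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
>>> x = torch.zeros(2, 1, 4, 4)  # batched 4-D state
>>> y = torch.ones(2, 1, 4, 4)   # steady-state mean, same shape as x
>>> t = torch.full((2,), 0.5)    # one time per batch element
>>> # Drift and diffusion of the forward SDE at time t
>>> drift, diffusion = sde.sde(x, t, y)
>>> # Sample from the prior distribution
>>> sample = sde.prior_sampling(shape=y.shape, y=y)
>>> # Compute the marginal mean and standard deviation
>>> mean, std = sde.marginal_prob(x0=x, t=t, y=y)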