espnet2.enh.diffusion.sdes.SDE
class espnet2.enh.diffusion.sdes.SDE(N)
Bases: ABC
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
This module provides an abstract base class for Stochastic Differential Equations (SDEs) and implementations for specific types of SDEs, such as Ornstein-Uhlenbeck Variance Exploding (OUVESDE) and Ornstein-Uhlenbeck Variance Preserving (OUVPSDE) SDEs. The SDE class includes methods for defining the dynamics of the SDE, marginal probabilities, and sampling from the prior distribution.
Taken and adapted from https://github.com/yang-song/score_sde_pytorch and https://github.com/sp-uhh/sgmse.
Classes:
- SDE: Abstract base class for SDEs.
- OUVESDE: Implementation of the Ornstein-Uhlenbeck Variance Exploding SDE.
- OUVPSDE: Implementation of the Ornstein-Uhlenbeck Variance Preserving SDE.
Usage:
To create a specific SDE, instantiate one of the subclasses (e.g., OUVESDE or OUVPSDE) and use its methods for sampling and computing probabilities.
N
Number of discretization time steps.
Type: int
Parameters: N (int) – Number of discretization time steps.
Examples
>>> ouvesde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
>>> x = torch.randn(10, 3, 32, 32) # Example input tensor
>>> t = torch.tensor(0.5) # Example time step
>>> y = torch.randn(10, 3, 32, 32) # Example steady-state mean
>>> drift, diffusion = ouvesde.sde(x, t, y)
>>> mean, std = ouvesde.marginal_prob(x, t, y)
>>> sample = ouvesde.prior_sampling((10, 3, 32, 32), y)
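To define a new SDE, every abstract member documented below (T, copy, sde, marginal_prob, prior_sampling, prior_logp) must be overridden. The following is a minimal, hypothetical skeleton with placeholder bodies; the name MyCustomSDE and all of its return values are illustrative assumptions, not part of ESPnet.

import torch

from espnet2.enh.diffusion.sdes import SDE


class MyCustomSDE(SDE):
    # Illustrative skeleton; every abstract member below must be overridden.

    @property
    def T(self):
        return 1.0  # end time of the SDE

    def copy(self):
        return MyCustomSDE(N=self.N)

    def sde(self, x, t, *args):
        # placeholder drift f(x, t) and diffusion g(t)
        return -x, torch.ones_like(t)

    def marginal_prob(self, x, t, *args):
        # placeholder Gaussian marginal; a real SDE derives this analytically
        return x, torch.ones_like(t)

    def prior_sampling(self, shape, *args):
        return torch.randn(*shape)

    def prior_logp(self, z):
        # unnormalized standard-normal log-density, summed per sample
        return -0.5 * torch.sum(z ** 2, dim=tuple(range(1, z.ndim)))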
Construct an SDE.
- Parameters: N (int) – Number of discretization time steps.
abstract property T
End time of the SDE.
- Returns: End time of the SDE.
- Return type: float
- Raises: NotImplementedError – If the property is not implemented in a subclass.
Examples
sde = MySDEClass(N=1000)
end_time = sde.T
drift, diffusion = sde.sde(x, t, *args)
mean, std = sde.marginal_prob(x, t, *args)
sample = sde.prior_sampling(shape, *args)
log_density = sde.prior_logp(z)
abstract copy()
Return a copy of the SDE instance.
- Returns: A new SDE object with the same configuration (e.g., the same number of discretization time steps N).
- Return type: SDE
Examples
# Example usage of creating an instance of an SDE subclass
ouvesde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
drift, diffusion = ouvesde.sde(x, t, y)
discretize(x, t, *args)
Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.
This method is useful for reverse diffusion sampling and probability flow sampling. It defaults to the Euler-Maruyama discretization method.
- Parameters:
- x (torch.Tensor) – A tensor representing the state at time t.
- t (torch.Tensor) – A tensor representing the time step (from 0 to self.T).
- *args – Additional arguments that may be required by the specific SDE implementation.
- Returns: A tuple containing:
  - f (torch.Tensor): The drift term scaled by the time step.
  - G (torch.Tensor): The diffusion term scaled by the square root of the time step.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples
>>> sde = MySDEClass(N=100) # Replace with actual SDE class
>>> x = torch.tensor([0.0])
>>> t = torch.tensor(0.5)
>>> f, G = sde.discretize(x, t)
>>> print(f, G)
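The default behavior described above can be pictured as a plain Euler-Maruyama step over a uniform grid of N points. The sketch below is a hedged illustration of that idea (the helper name euler_maruyama_discretize and the assumption T = 1 are mine), not the exact ESPnet implementation.

def euler_maruyama_discretize(sde, x, t, *args):
    # Sketch: f_i = drift * dt and G_i = diffusion * sqrt(dt) over N uniform steps.
    dt = 1.0 / sde.N                         # uniform step size (assumes T = 1 for simplicity)
    drift, diffusion = sde.sde(x, t, *args)
    f = drift * dt
    G = diffusion * (dt ** 0.5)
    return f, G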
abstract marginal_prob(x, t, *args)
Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$.
- Parameters:
- x – A tensor representing the initial state of the system.
- t – A float representing the time step at which to evaluate the marginal distribution (from 0 to self.T).
- y – A tensor representing the steady-state mean that influences the marginal distribution.
- Returns: A tuple containing:
  - mean: The expected value of the state at time t.
  - std: The standard deviation of the state at time t.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples
>>> sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
>>> x0 = torch.zeros(10, 3, 32, 32)
>>> y = torch.ones(10, 3, 32, 32)
>>> t = 0.5 * torch.ones(10)
>>> mean, std = sde.marginal_prob(x0, t, y)
>>> print(mean.shape, std.shape)
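As an illustration of the kind of quantity a subclass returns here, the sketch below computes the analytic marginal of a simple Ornstein-Uhlenbeck process dx = theta * (y - x) dt + sigma dw. The helper name ou_marginal_prob, the values of theta and sigma, and the 4-D broadcasting are illustrative assumptions; this is a simplification, not the actual OUVESDE formulas.

import torch

def ou_marginal_prob(x0, t, y, theta=1.5, sigma=0.5):
    # Analytic marginal of dx = theta * (y - x) dt + sigma dw, started at x0.
    decay = torch.exp(-theta * t)                    # e^{-theta * t}, one value per batch element
    decay = decay.view(-1, *([1] * (x0.ndim - 1)))   # broadcast against (B, C, H, W) inputs
    mean = y + decay * (x0 - y)                      # E[x_t | x_0]
    std = sigma * torch.sqrt((1.0 - torch.exp(-2.0 * theta * t)) / (2.0 * theta))
    return mean, std

# Usage with batched 4-D tensors:
# mean, std = ou_marginal_prob(torch.zeros(10, 3, 32, 32), 0.5 * torch.ones(10),
#                              torch.ones(10, 3, 32, 32))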
abstract prior_logp(z)
Compute log-density of the prior distribution.
This function is useful for computing the log-likelihood via the probability flow ODE. It should be implemented in subclasses of SDE to provide the log-density of the prior distribution given a latent code.
- Parameters:z – A tensor representing the latent code for which the log-density is to be computed.
- Returns: A tensor representing the log-density of the prior distribution evaluated at the given latent code.
- Return type: torch.Tensor
- Raises: NotImplementedError – If this method is not implemented in a subclass.
Examples
Example usage in a subclass:

class MySDE(SDE):
    def prior_logp(self, z):
        # Custom implementation for computing log-density
        return -0.5 * torch.sum(z ** 2)

sde = MySDE(N=1000)
latent_code = torch.randn(10, 3)  # Example latent code
log_density = sde.prior_logp(latent_code)
print(log_density)  # Output the log-density
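If the prior is an isotropic Gaussian, the log-density has a closed form: log N(z; 0, std^2 I) = -n/2 * log(2 * pi * std^2) - ||z||^2 / (2 * std^2), with n the dimensionality of one sample. The sketch below (the helper name gaussian_prior_logp and its std argument are assumptions) shows that computation per batch element.

import numpy as np
import torch

def gaussian_prior_logp(z, std=1.0):
    # Log-density of an isotropic Gaussian prior N(0, std^2 I), one value per batch element.
    n = float(np.prod(z.shape[1:]))                  # dimensionality of a single sample
    sq = torch.sum(z ** 2, dim=tuple(range(1, z.ndim)))
    return -0.5 * n * np.log(2 * np.pi * std ** 2) - sq / (2 * std ** 2)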
abstract prior_sampling(shape, *args)
Generate one sample from the prior distribution.
This method samples from the prior distribution, denoted as $p_T(x|args)$, with a specified output shape. It allows for generating latent codes from the learned distribution at the end time T.
- Parameters:
- shape (tuple) – The desired shape of the generated sample.
- *args – Additional arguments specific to the SDE implementation.
- Returns: A sample drawn from the prior distribution with the specified shape.
- Return type: torch.Tensor
- Raises: UserWarning – If the provided shape does not match the shape of the given y.
Examples
>>> sde = OUVESDE()
>>> sample = sde.prior_sampling((10, 3, 32, 32), y=torch.randn(10, 3, 32, 32))
>>> print(sample.shape)
torch.Size([10, 3, 32, 32])
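A typical implementation draws Gaussian noise centered on the steady-state mean y, with the standard deviation of the marginal at t = T. The sketch below is a hedged illustration of that pattern (the helper name prior_sampling_around_y and the broadcasting are assumptions), not the exact OUVESDE code.

import warnings

import torch

def prior_sampling_around_y(sde, shape, y):
    # Sample x_T ~ N(y, std_T^2 I), using the SDE's marginal std at the end time T.
    if tuple(shape) != tuple(y.shape):
        warnings.warn("Target shape does not match the shape of y; using y's shape instead.")
    t_T = sde.T * torch.ones(y.shape[0], device=y.device)
    _, std = sde.marginal_prob(y, t_T, y)            # std of p_T(x | y)
    std = std.view(-1, *([1] * (y.ndim - 1)))        # broadcast over non-batch dims
    return y + torch.randn_like(y) * std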
reverse(score_model, probability_flow=False)
Create the reverse-time SDE/ODE.
- Parameters:
  - score_model – A callable that estimates the score (the gradient of the log marginal density) given the current state, the time step, and any additional arguments.
  - probability_flow (bool) – If True, return the probability flow ODE instead of the reverse-time SDE. Defaults to False.
- Returns: An instance of the reverse-time SDE/ODE.
- Return type: RSDE
Examples
# Create a reverse-time SDE with a score model
reverse_sde = sde_instance.reverse(score_model, probability_flow=True)
# Use the reverse SDE to generate samples
samples = reverse_sde.prior_sampling(shape=(100, 3, 32, 32), y=some_tensor)
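The reverse-time construction follows the standard score-based formulation: the reverse drift is f(x, t) - g(t)^2 * score(x, t), halved for the probability flow ODE, whose diffusion is zero. The sketch below illustrates that computation; the helper name reverse_drift_and_diffusion, the score_model call signature, and the 4-D broadcasting are assumptions rather than the exact RSDE implementation.

import torch

def reverse_drift_and_diffusion(sde, score_model, x, t, *args, probability_flow=False):
    # Sketch: reverse drift = f - g^2 * score (halved for the probability flow ODE).
    drift, diffusion = sde.sde(x, t, *args)
    score = score_model(x, t, *args)                      # estimate of grad_x log p_t(x)
    g2 = diffusion.view(-1, *([1] * (x.ndim - 1))) ** 2   # broadcast per-batch g(t)^2
    scale = 0.5 if probability_flow else 1.0
    rev_drift = drift - scale * g2 * score
    rev_diffusion = torch.zeros_like(diffusion) if probability_flow else diffusion
    return rev_drift, rev_diffusion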
abstract sde(x, t, *args)
Return the drift and diffusion terms of the SDE.
Evaluates the drift f(x, t) and diffusion g(t) of the forward SDE for a mini-batch of inputs at time t.
- Parameters:
  - x (torch.Tensor) – A tensor representing the state at time t.
  - t (torch.Tensor) – A tensor representing the time step (from 0 to self.T).
  - *args – Additional arguments required by the specific SDE implementation (e.g., the steady-state mean y).
- Returns: A tuple (drift, diffusion) with the drift and diffusion terms of the SDE.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Notes
The “steady-state mean” y is not provided at construction but must be supplied as an argument to methods that require it (e.g., sde or marginal_prob).
Examples
# Create an instance of OUVESDE
ouvesde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
# Sample from the prior distribution
x_T = ouvesde.prior_sampling(shape=(10, 3, 32, 32), y=torch.zeros((10, 3, 32, 32)))
# Create an instance of OUVPSDE
ouvpsde = OUVPSDE(beta_min=0.1, beta_max=0.5, stiffness=1, N=1000)
# Obtain marginal mean and standard deviation
mean, std = ouvpsde.marginal_prob(x0=torch.zeros((10, 3, 32, 32)),
                                  t=0.5 * torch.ones(10),
                                  y=torch.ones((10, 3, 32, 32)))
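To make the returned drift and diffusion concrete, the sketch below evaluates them for a simple Ornstein-Uhlenbeck process pulled toward the steady-state mean y (the same toy process used in the marginal_prob sketch above). The helper name ou_sde and the values of theta and sigma are illustrative assumptions, not the OUVESDE coefficients.

import torch

def ou_sde(x, t, y, theta=1.5, sigma=0.5):
    # Drift and diffusion of dx = theta * (y - x) dt + sigma dw for batched inputs.
    drift = theta * (y - x)                          # pulls the state toward the steady-state mean y
    diffusion = sigma * torch.ones_like(t)           # constant g(t), one value per batch element
    return drift, diffusion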