espnet2.enh.diffusion.sdes.OUVPSDE
class espnet2.enh.diffusion.sdes.OUVPSDE(beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs)
Bases: SDE
OUVPSDE class.
!!! SGMSE authors observed instabilities around t=0.2. !!!
Construct an Ornstein-Uhlenbeck Variance Preserving SDE:
dx = -1/2 * beta(t) * stiffness * (x - y) dt + sqrt(beta(t)) * dw
with
beta(t) = beta_min + t * (beta_max - beta_min)
Note that the “steady-state mean” y is not provided at construction, but must rather be given as an argument to the methods which require it (e.g., sde or marginal_prob).
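To make the definition concrete, here is a minimal Euler-Maruyama sketch of the SDE above for a scalar state, written in plain Python rather than against the ESPnet API. It assumes the drift pulls x toward the steady-state mean y and that the process runs from t = 0 to an end time T = 1 (an assumption for the sketch). The helper names beta, drift_and_diffusion and euler_maruyama are hypothetical.

```python
import math
import random


def beta(t: float, beta_min: float, beta_max: float) -> float:
    """Linear schedule beta(t) = beta_min + t * (beta_max - beta_min)."""
    return beta_min + t * (beta_max - beta_min)


def drift_and_diffusion(x: float, t: float, y: float, beta_min: float,
                        beta_max: float, stiffness: float = 1.0):
    """Drift and diffusion of the OU-VP SDE, with the drift pulling x toward y (assumed sign)."""
    b = beta(t, beta_min, beta_max)
    drift = -0.5 * b * stiffness * (x - y)
    diffusion = math.sqrt(b)
    return drift, diffusion


def euler_maruyama(x0: float, y: float, beta_min: float, beta_max: float,
                   stiffness: float = 1.0, N: int = 1000, T: float = 1.0,
                   seed: int = 0) -> float:
    """Simulate one scalar path from t = 0 to t = T using N uniform steps."""
    rng = random.Random(seed)
    dt = T / N
    x = x0
    for i in range(N):
        t = i * dt
        f, g = drift_and_diffusion(x, t, y, beta_min, beta_max, stiffness)
        # Euler-Maruyama update: x <- x + f * dt + g * sqrt(dt) * z, z ~ N(0, 1)
        x = x + f * dt + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)
    return x


# Usage: one path starting at x0 = 0 that drifts toward y = 1.
x_T = euler_maruyama(x0=0.0, y=1.0, beta_min=0.1, beta_max=1.0)
```

Averaging many such paths should approach y + (x0 - y) * exp(-stiffness / 2 * integral of beta over [0, T]), which is the mean-reverting behavior the drift encodes.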
beta_min
Smallest value for beta.
- Type: float
beta_max
Largest value for beta.
- Type: float
stiffness
Stiffness factor of the drift, default is 1.
- Type: float
N
Number of discretization steps.
- Type: int
Parameters:
- beta_min – Smallest beta value.
- beta_max – Largest beta value.
- stiffness – Stiffness factor of the drift. 1 by default.
- N – Number of discretization steps.
Returns: None
################# Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> ou_vpsde = OUVPSDE(beta_min=0.1, beta_max=0.5, stiffness=1, N=1000)
>>> x0 = torch.randn(10, 3, 32, 32)  # Example input
>>> y = torch.randn(10, 3, 32, 32)  # Steady-state mean
>>> t = torch.full((10,), 0.5)  # One time value per batch element
>>> mean, std = ou_vpsde.marginal_prob(x0, t, y)
- Raises: NotImplementedError – If prior_logp is called, since it is not implemented.
property T
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
This module defines abstract classes and methods for Stochastic Differential Equations (SDEs), including the Reverse SDE and Variance Exploding/Preserving SDEs. It is adapted from the works available at: https://github.com/yang-song/score_sde_pytorch and https://github.com/sp-uhh/sgmse.
Classes:
- SDE: Abstract class for Stochastic Differential Equations.
- OUVESDE: Ornstein-Uhlenbeck Variance Exploding SDE class.
- OUVPSDE: Ornstein-Uhlenbeck Variance Preserving SDE class.
Functions:
- batch_broadcast: Broadcasts a tensor over all dimensions of another tensor, except the batch dimension.
T
End time of the SDE.
- Parameters: None. The number of discretization time steps, N, is set in the constructor.
- Raises: NotImplementedError – Raised if a method that is not implemented is called.
################# Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> # Creating an instance of OUVPSDE
>>> ouvpsde = OUVPSDE(beta_min=0.1, beta_max=1.0, stiffness=1.0, N=1000)
>>> # Sampling from the prior distribution
>>> sample_shape = (10, 3, 64, 64)  # Example shape
>>> y = torch.zeros(sample_shape)  # Example steady-state mean
>>> sample = ouvpsde.prior_sampling(shape=sample_shape, y=y)
>>> # Getting marginal probabilities (t holds one time per batch element)
>>> mean, std = ouvpsde.marginal_prob(x0=torch.zeros(sample_shape), t=torch.full((10,), 0.5), y=y)
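The end time T and the constructor's N together fix the discretization used when the SDE is integrated numerically. A minimal sketch, assuming a uniform grid of N points running from T down to a small cutoff eps (a hypothetical parameter, commonly used to keep samplers away from t = 0); time_grid is a hypothetical helper, not part of the ESPnet API.

```python
def time_grid(T: float, N: int, eps: float = 0.0) -> list:
    """N uniformly spaced times from T down to eps (both endpoints included)."""
    if N < 2:
        raise ValueError("N must be at least 2 to include both endpoints")
    step = (T - eps) / (N - 1)
    return [T - i * step for i in range(N)]


# Usage: a coarse grid over [0, 1] for illustration.
grid = time_grid(T=1.0, N=5)
```

A reverse-time sampler would walk this list from its first entry (T) to its last, taking one update per consecutive pair of times.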
copy()
Create a copy of this SDE.
- Returns: A new OUVPSDE instance with the same beta_min, beta_max, stiffness, and N.
- Raises: NotImplementedError – If any of the abstract methods are not implemented by a subclass.
################# Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVESDE, OUVPSDE
>>> # Example usage of the OUVPSDE class
>>> sde = OUVPSDE(beta_min=0.1, beta_max=1.0, stiffness=1.0, N=1000)
>>> sde_copy = sde.copy()  # New instance with the same parameters
>>> x = torch.tensor([[0.0]])
>>> y = torch.tensor([[1.0]])
>>> t = torch.tensor([0.5])
>>> drift, diffusion = sde.sde(x, t, y)
>>> print(drift, diffusion)
>>> # Example usage of the OUVESDE class
>>> ouve_sde = OUVESDE(theta=1.5, sigma_min=0.05, sigma_max=0.5, N=1000)
>>> x0 = torch.tensor([[0.0]])
>>> mean, std = ouve_sde.marginal_prob(x0, t, y)
>>> print(mean, std)
marginal_prob(x0, t, y)
Calculate the marginal distribution of the SDE.
This method computes the parameters of the marginal distribution $p_t(x|args)$ of the SDE at time t, given the initial state x0: its mean and standard deviation.
- Parameters:
- x0 – A tensor representing the initial state at t = 0.
- t – A tensor of times, one per batch element.
- y – The steady-state mean.
- Returns: A tuple (mean, std), where
- mean: The mean of the marginal distribution.
- std: The standard deviation of the marginal distribution.
################# Examples
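>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> sde = OUVPSDE(beta_min=0.1, beta_max=0.5)
>>> x0, y = torch.zeros(4, 1, 8, 8), torch.ones(4, 1, 8, 8)
>>> mean, std = sde.marginal_prob(x0, torch.full((4,), 0.5), y)
>>> print(mean.shape, std.shape)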
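For a scalar state the marginal can be worked out by hand. Writing B(t) = beta_min * t + (beta_max - beta_min) * t^2 / 2 for the integral of beta over [0, t], the SDE gives mean y + (x0 - y) * exp(-stiffness * B(t) / 2) and variance (1 - exp(-stiffness * B(t))) / stiffness. The sketch below encodes that derivation, assuming the mean-reverting drift -1/2 * beta(t) * stiffness * (x - y). It is a hand-derived cross-check, not a transcription of the ESPnet marginal_prob, and integrated_beta and ou_vp_marginal are hypothetical names.

```python
import math


def integrated_beta(t: float, beta_min: float, beta_max: float) -> float:
    """B(t): integral of the linear beta schedule over [0, t]."""
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2


def ou_vp_marginal(x0: float, t: float, y: float, beta_min: float,
                   beta_max: float, stiffness: float = 1.0):
    """Mean and standard deviation of x(t) given x(0) = x0 (scalar case)."""
    B = integrated_beta(t, beta_min, beta_max)
    # The mean decays from x0 toward y at rate stiffness / 2 per unit of B.
    mean = y + (x0 - y) * math.exp(-0.5 * stiffness * B)
    # The variance grows from 0 toward 1 / stiffness.
    var = (1.0 - math.exp(-stiffness * B)) / stiffness
    return mean, math.sqrt(var)


mean, std = ou_vp_marginal(x0=0.0, t=1.0, y=1.0, beta_min=0.1, beta_max=1.0)
```

With stiffness = 1 the variance approaches 1 as B(t) grows, which is one way to read the "variance preserving" in the class name.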
prior_logp(z)
Compute log-density of the prior distribution.
This method would be used to compute the log-likelihood via the probability flow ODE. OUVPSDE does not implement it, so calling it raises NotImplementedError.
- Parameters: z – A tensor representing the latent code for which the log-density is to be computed.
- Returns: The log-density of the prior distribution evaluated at z.
- Raises: NotImplementedError – Always, because the method is not implemented for OUVPSDE.
################# Examples
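>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> sde = OUVPSDE(beta_min=0.1, beta_max=1.0)
>>> try:
...     sde.prior_logp(torch.randn(4, 1, 8, 8))
... except NotImplementedError:
...     print("prior_logp is not implemented for OUVPSDE")
prior_logp is not implemented for OUVPSDE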
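Since prior_logp is not implemented, here is a sketch of what it could compute if the prior at T were the Gaussian N(y, std_T^2 I), with std_T the marginal standard deviation at the end time. This is a hypothetical helper: unlike prior_logp(z) it takes y and std_T explicitly, works on flat Python lists instead of tensors, and the Gaussian-prior form is an assumption.

```python
import math


def gaussian_prior_logp(z, y, std_T: float) -> float:
    """Log-density of N(y, std_T**2 * I) at z, summed over all elements (hypothetical)."""
    if len(z) != len(y):
        raise ValueError("z and y must have the same length")
    var = std_T ** 2
    log_norm = -0.5 * math.log(2.0 * math.pi * var)
    # Each element contributes its normalizer plus the squared-distance penalty.
    return sum(log_norm - (zi - yi) ** 2 / (2.0 * var) for zi, yi in zip(z, y))


logp = gaussian_prior_logp([0.5, -0.2], [0.0, 0.0], std_T=1.0)
```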
prior_sampling(shape, y)
Generate one sample from the prior distribution $p_T(x|args)$ with the given shape.
The sample is drawn at the end time T of the SDE and depends on the steady-state mean y. If the provided shape does not match the shape of y, a warning is issued and the target shape is ignored.
- Parameters:
- shape – The desired shape of the generated sample.
- y – The steady-state mean to influence the sample generation.
- Returns: A tensor representing a sample drawn from the prior distribution.
- Warns: UserWarning – If the provided shape does not match the shape of y.
################# Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> sde = OUVPSDE(beta_min=0.1, beta_max=0.5)
>>> y = torch.randn(10, 3, 32, 32)  # Example steady-state mean
>>> sample = sde.prior_sampling((10, 3, 32, 32), y)
>>> print(sample.shape)
torch.Size([10, 3, 32, 32])
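A minimal sketch of the sampling step on flat lists, assuming the prior is a Gaussian around y with standard deviation std_T: draw z ~ N(0, I) and return y + std_T * z, warning (rather than raising) when the requested length disagrees with y, as described above. prior_sample and std_T are hypothetical names; the real method works on torch tensors and derives its spread from the SDE itself.

```python
import random
import warnings


def prior_sample(shape_len: int, y, std_T: float, rng=None):
    """Draw x_T = y + std_T * z with z ~ N(0, I); the length of y wins over shape_len."""
    if rng is None:
        rng = random.Random()
    if shape_len != len(y):
        # Mirrors the documented behavior: warn and ignore the requested shape.
        warnings.warn(
            f"Target length {shape_len} does not match length of y {len(y)}! "
            "Ignoring target length."
        )
    return [yi + std_T * rng.gauss(0.0, 1.0) for yi in y]


x_T = prior_sample(3, [0.0, 1.0, 2.0], std_T=0.5, rng=random.Random(0))
```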
sde(x, t, y)
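Compute the drift and diffusion coefficients of the SDE at state x and time t, given the steady-state mean y.
- Parameters:
- x – A tensor representing the current state.
- t – A tensor of times, one per batch element.
- y – The steady-state mean.
- Returns: A tuple (drift, diffusion) holding the drift and diffusion coefficients of the SDE defined in the class description.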
################# Examples
>>> import torch
>>> from espnet2.enh.diffusion.sdes import OUVPSDE
>>> # Example of using the OUVPSDE class
>>> sde = OUVPSDE(beta_min=0.1, beta_max=1.0, stiffness=1, N=1000)
>>> x0 = torch.randn(10, 3, 32, 32)  # Example input tensor
>>> y = torch.randn(10, 3, 32, 32)  # Example steady-state mean
>>> t = torch.full((10,), 0.5)  # One time per batch element
>>> drift, diffusion = sde.sde(x0, t, y)
>>> mean, std = sde.marginal_prob(x0, t, y)
>>> # Example of reverse-time SDE creation
>>> score_model = lambda x, t, *args: -x  # Dummy score model
>>> reverse_sde = sde.reverse(score_model)
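To illustrate what a reverse-time SDE built from sde() needs, here is a sketch of the standard reverse-time coefficients from the score-based SDE framework this module is adapted from: drift f(x, t) - g(t)^2 * score(x, t), integrated backward from T to 0, with the score term halved and the noise dropped for the probability flow ODE. It is written for scalars in plain Python, fixes beta_min = 0.1 and beta_max = 1.0, assumes the drift pulls x toward y, and uses hypothetical names; ESPnet's reverse() may arrange signs differently for its own integration direction.

```python
import math
from typing import Callable

# Hypothetical fixed schedule for the sketch.
BETA_MIN, BETA_MAX = 0.1, 1.0


def forward_coeffs(x: float, t: float, y: float, stiffness: float = 1.0):
    """Forward drift and diffusion, assuming the drift pulls x toward y."""
    b = BETA_MIN + t * (BETA_MAX - BETA_MIN)
    return -0.5 * b * stiffness * (x - y), math.sqrt(b)


def reverse_coeffs(x: float, t: float, y: float,
                   score: Callable[[float, float, float], float],
                   probability_flow: bool = False):
    """Reverse-time drift f - g^2 * score and diffusion g.

    For the probability flow ODE the score term is halved and the noise is dropped.
    """
    f, g = forward_coeffs(x, t, y)
    weight = 0.5 if probability_flow else 1.0
    drift = f - weight * g ** 2 * score(x, t, y)
    diffusion = 0.0 if probability_flow else g
    return drift, diffusion


# Usage with a dummy score similar to the one in the example.
drift, diffusion = reverse_coeffs(0.0, 0.5, 1.0, lambda x, t, y: -x)
```

With a zero score the reverse drift reduces to the forward drift, which makes the score term the only place where a trained model enters.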