espnet2.enh.layers.ncsnpp.NCSNpp
class espnet2.enh.layers.ncsnpp.NCSNpp(scale_by_sigma=True, nonlinearity='swish', nf=128, ch_mult=(1, 1, 2, 2, 2, 2, 2), num_res_blocks=2, attn_resolutions=(16,), resamp_with_conv=True, conditional=True, fir=True, fir_kernel=[1, 3, 3, 1], skip_rescale=True, resblock_type='biggan', progressive='output_skip', progressive_input='input_skip', progressive_combine='sum', init_scale=0.0, fourier_scale=16, image_size=256, embedding_type='fourier', dropout=0.0, centered=True, **unused_kwargs)
Bases: Module
NCSN++ model, adapted from the score SDE implementation.
This class implements the NCSN++ model as described in the following repositories:
The model is a U-Net built from residual blocks, with self-attention at selected resolutions, and is intended as the score network in score-based generative modeling.
- Attributes:
- act – Activation function used in the model.
- nf – Base number of feature channels.
- num_res_blocks – Number of residual blocks at each resolution.
- attn_resolutions – Resolutions at which attention is applied.
- conditional – Whether the network is conditioned on the noise-level (time) embedding.
- centered – Whether the input data is assumed to be already centered.
- scale_by_sigma – If True, the network output is divided by the noise level sigma.
- all_modules – List of all submodules used in the forward pass.
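When scale_by_sigma is True, the network output is divided by the noise level of each example; in the score SDE reference code, the per-example sigma is reshaped to (B, 1, 1, 1) so that it broadcasts over channels, height, and width. A minimal sketch of that division on plain Python lists, assuming one flattened feature map per batch element (the helper name `scale_output_by_sigma` is hypothetical, not part of ESPnet):

```python
# Hypothetical stand-alone sketch (not ESPnet code) of what scale_by_sigma
# does to the network output, using nested lists instead of tensors.
from typing import List

Batch = List[List[float]]  # one flattened feature map per batch element


def scale_output_by_sigma(h: Batch, sigmas: List[float]) -> Batch:
    """Divide every value of batch element b by its own noise level sigmas[b]."""
    if len(h) != len(sigmas):
        raise ValueError("need exactly one sigma per batch element")
    return [[v / s for v in row] for row, s in zip(h, sigmas)]


h = [[1.0, 2.0], [3.0, 4.0]]
print(scale_output_by_sigma(h, [0.5, 2.0]))  # [[2.0, 4.0], [1.5, 2.0]]
```

Dividing by sigma makes the output magnitude grow as the noise level shrinks, which matches how the true score of Gaussian-perturbed data scales.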
- Parameters:
- scale_by_sigma (bool) – Whether to scale the output by sigma.
- nonlinearity (str) – Name of the activation function, e.g. “swish”.
- nf (int) – Base number of feature channels; level i uses nf * ch_mult[i] channels.
- ch_mult (tuple) – Channel multipliers, one per resolution level; its length sets the number of levels.
- num_res_blocks (int) – Number of residual blocks per resolution.
- attn_resolutions (tuple) – Resolutions where attention is applied.
- resamp_with_conv (bool) – If True, use convolution for resampling.
- conditional (bool) – If True, the model is conditional.
- fir (bool) – If True, use FIR filters for upsampling/downsampling.
- fir_kernel (list) – Coefficients (taps) of the FIR filter used for upsampling/downsampling.
- skip_rescale (bool) – If True, apply skip rescaling.
- resblock_type (str) – Type of residual block to use (“biggan” or “ddpm”).
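- progressive (str) – How the decoder builds the output across resolutions (“none”, “output_skip”, or “residual”).
- progressive_input (str) – How the input is fed to the encoder at each resolution (“none”, “input_skip”, or “residual”).
- progressive_combine (str) – How progressive features are merged with the main path (“sum” or “cat”).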
- init_scale (float) – Scale factor used for weight initialization.
- fourier_scale (int) – Scale for Fourier features.
- image_size (int) – Size of the input images.
- embedding_type (str) – Type of embedding to use (“fourier” or “positional”).
- dropout (float) – Dropout rate.
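- centered (bool) – If True, the input is assumed to be already centered (e.g. in [-1, 1]); if False, it is assumed to lie in [0, 1] and is mapped to [-1, 1] via 2x - 1.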
- **unused_kwargs – Additional arguments not used in the initialization.
- Raises: ValueError – If an unknown embedding or resblock type is specified.
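Assuming the score SDE convention carries over, level i of the U-Net runs at resolution image_size // 2**i with nf * ch_mult[i] channels, and attention is inserted at levels whose resolution appears in attn_resolutions. The hypothetical helper `level_layout` below (not an ESPnet function) tabulates this for the constructor defaults. It also shows the total downsampling factor: seven levels mean six halvings, i.e. 2**6 = 64, which is consistent with pad_spec padding to a multiple of 64.

```python
# Hypothetical helper (not part of ESPnet): derive the U-Net level layout
# from the NCSNpp constructor defaults, following the score SDE convention.
from typing import List, Sequence, Tuple


def level_layout(
    image_size: int = 256,
    nf: int = 128,
    ch_mult: Sequence[int] = (1, 1, 2, 2, 2, 2, 2),
    attn_resolutions: Sequence[int] = (16,),
) -> Tuple[List[Tuple[int, int, bool]], int]:
    """Return (resolution, channels, uses_attention) per level and the
    total downsampling factor between the first and the last level."""
    layout = []
    for i, mult in enumerate(ch_mult):
        resolution = image_size // (2 ** i)  # halved at every level
        layout.append((resolution, nf * mult, resolution in attn_resolutions))
    # One 2x downsampling step between each pair of consecutive levels.
    total_factor = 2 ** (len(ch_mult) - 1)
    return layout, total_factor


layout, factor = level_layout()
for resolution, channels, attn in layout:
    print(resolution, channels, attn)
print("total downsampling factor:", factor)  # total downsampling factor: 64
```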
######### Examples
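>>> import torch
>>> from espnet2.enh.layers.ncsnpp import NCSNpp
>>> model = NCSNpp(scale_by_sigma=True, nf=128)
>>> x = torch.randn(1, 2, 256, 256, dtype=torch.cfloat)  # two complex spectrograms
>>> time_cond = torch.full((1,), 0.5)  # one positive noise level per example
>>> output = model(x, time_cond)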
NOTE
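The input tensor x should be a complex tensor of shape (batch_size, 2, H, W), where H and W are the height and width of the input (e.g. frequency bins and time frames of a spectrogram). With the default ch_mult, H and W should be divisible by 64; pad_spec can pad the time axis accordingly.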
forward(x, time_cond)
Perform a forward pass through the NCSN++ model.
This method embeds time_cond, passes x through the U-Net encoder and decoder (feeding the time embedding to the residual blocks when conditional=True), and returns the network output, which in score-based modeling serves as an estimate of the score. It is used in the same way for training and inference.
- Parameters:
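- x (torch.Tensor) – A complex tensor of shape (B, 2, H, W), where B is the batch size and H and W are the height and width of the input. Each of the two channels holds one complex-valued input (in the SGMSE-style setup, the current process state and the noisy observation); the forward pass splits each channel into its real and imaginary parts, giving four real-valued channels.
- time_cond (torch.Tensor) – A 1-D tensor of shape (B,) with one time/noise-level value per batch element. With the default embedding_type=“fourier”, its logarithm is taken, so the values must be positive.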
- Returns: A complex tensor of shape (B, 1, H, W) representing the output of the NCSN++ model after processing the input.
- Return type: torch.Tensor
- Raises:ValueError – If the embedding type is unknown.
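With the default embedding_type="fourier", the score SDE reference code treats each entry of time_cond as a noise level sigma, takes its logarithm, and projects it onto fixed random frequencies (embedding size nf, frequency scale fourier_scale), returning the sines and cosines of the projections; this is assumed to carry over to this class. The sketch below reproduces the idea for one scalar with the standard library only. `make_fourier_weights` and `fourier_embedding` are hypothetical names, not ESPnet's API.

```python
# Minimal sketch of a Gaussian Fourier time embedding, modeled on the score SDE
# GaussianFourierProjection; the names here are illustrative, not ESPnet's API.
import math
import random
from typing import List


def make_fourier_weights(embedding_size: int, scale: float, seed: int = 0) -> List[float]:
    """Fixed random frequencies W ~ N(0, scale**2), drawn once at construction."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) * scale for _ in range(embedding_size)]


def fourier_embedding(sigma: float, weights: List[float]) -> List[float]:
    """Embed one positive noise level as [sin(2*pi*W*log sigma), cos(...)]."""
    if sigma <= 0:
        raise ValueError("sigma must be positive because its log is taken")
    proj = [math.log(sigma) * w * 2 * math.pi for w in weights]
    return [math.sin(p) for p in proj] + [math.cos(p) for p in proj]


weights = make_fourier_weights(embedding_size=128, scale=16.0)  # nf, fourier_scale
emb = fourier_embedding(0.5, weights)
print(len(emb))  # 256: a sine half and a cosine half, 128 values each
```

Because of the logarithm, a zero or negative entry in time_cond cannot be embedded, which is why time_cond must hold strictly positive values in this mode.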
######### Examples
>>> model = NCSNpp()
>>> input_tensor = torch.randn(8, 2, 256, 256, dtype=torch.cfloat)  # batch of 8
>>> time_cond = torch.rand(8) * 0.9 + 0.1  # positive noise levels, shape (8,)
>>> output = model(input_tensor, time_cond)
>>> output.shape
torch.Size([8, 1, 256, 256])
NOTE
Each of the two channels of x should contain a complex-valued spectrogram. The forward pass itself splits every channel into its real and imaginary parts, so no manual splitting is needed.
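Concretely, the forward pass rearranges the two complex input channels into four real channels in the order x[:, 0].real, x[:, 0].imag, x[:, 1].real, x[:, 1].imag, and its output layer produces two real channels that are recombined into a single complex channel (via torch.view_as_complex), giving the (B, 1, H, W) output. The per-bin sketch below mirrors that bookkeeping for one time-frequency bin; `pack_bin` and `unpack_output` are hypothetical helpers, not ESPnet functions.

```python
# Illustrative sketch (plain Python, not ESPnet code) of the channel layout
# used by NCSNpp.forward, shown for a single time-frequency bin.
from typing import List, Tuple


def pack_bin(x_val: complex, y_val: complex) -> List[float]:
    """Two complex inputs -> four real channels: x.real, x.imag, y.real, y.imag."""
    return [x_val.real, x_val.imag, y_val.real, y_val.imag]


def unpack_output(out_channels: Tuple[float, float]) -> complex:
    """Two real output channels (real, imag) -> one complex output value."""
    re, im = out_channels
    return complex(re, im)


print(pack_bin(1 + 2j, 3 - 4j))     # [1.0, 2.0, 3.0, -4.0]
print(unpack_output((0.5, -0.25)))  # (0.5-0.25j)
```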
pad_spec(Y)
Zero-pads the input tensor at the end of its time (last) dimension so that the length of that dimension is a multiple of 64. With the default ch_mult (seven levels, hence six 2x downsampling steps), 64 = 2^6 is the total downsampling factor, so padding to a multiple of 64 keeps the encoder and decoder feature maps aligned.
- Parameters: Y (torch.Tensor) – The input tensor to be padded, of shape (batch_size, channels, height, width); the last dimension (width) is the time axis.
- Returns: The padded tensor. All dimensions match Y except the last (time) dimension, which is rounded up to the next multiple of 64.
- Return type: torch.Tensor
######### Examples
>>> import torch
>>> from espnet2.enh.layers.ncsnpp import NCSNpp
>>> pad_spec = NCSNpp().pad_spec
>>> Y = torch.randn(8, 2, 256, 100)  # 100 time frames
>>> padded_Y = pad_spec(Y)
>>> padded_Y.shape  # time axis padded to the next multiple of 64
torch.Size([8, 2, 256, 128])
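The padding arithmetic can be written compactly as num_pad = (-T) % 64, which is 0 when T is already a multiple of 64. The sketch below applies that rule to a list-of-rows "spectrogram" (one row per frequency bin, one entry per frame), assuming zeros are appended after the last frame; `pad_amount` and `pad_rows` are hypothetical helpers, not ESPnet functions.

```python
# Stand-alone sketch of the pad_spec rule using plain lists: zero-pad the
# last (time) axis on the right until its length is a multiple of 64.
from typing import List


def pad_amount(num_frames: int, multiple: int = 64) -> int:
    """Number of zero frames needed to round num_frames up to a multiple."""
    return (-num_frames) % multiple


def pad_rows(rows: List[List[float]], multiple: int = 64) -> List[List[float]]:
    """Right-pad every row (one frequency bin over time) with zeros."""
    return [row + [0.0] * pad_amount(len(row), multiple) for row in rows]


for t in (16, 64, 65, 200):
    print(t, "->", t + pad_amount(t))  # prints 16 -> 64, 64 -> 64, 65 -> 128, 200 -> 256
```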