espnet2.enh.layers.dcunet.DCUNet
class espnet2.enh.layers.dcunet.DCUNet(dcunet_architecture: str = 'DilDCUNet-v2', dcunet_time_embedding: str = 'gfp', dcunet_temb_layers_global: int = 2, dcunet_temb_layers_local: int = 1, dcunet_temb_activation: str = 'silu', dcunet_time_embedding_complex: bool = False, dcunet_fix_length: str = 'pad', dcunet_mask_bound: str = 'none', dcunet_norm_type: str = 'bN', dcunet_activation: str = 'relu', embed_dim: int = 128, **kwargs)
Bases: Module
DCUNet model for speech enhancement.
This class implements the DCUNet architecture as proposed in S. Welker et al., “Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain”. The architecture is designed for enhancing speech signals in the complex domain using deep learning techniques.
architecture
Name of the DCUNet architecture to use.
- Type: str
fix_length_mode
Method to handle input length; can be ‘pad’, ‘trim’, or None.
- Type: str
norm_type
Type of normalization to apply; can be ‘bN’ or ‘CbN’.
- Type: str
activation
Activation function to use in the model.
- Type: str
input_channels
Number of input channels, typically 2 for complex inputs.
- Type: int
time_embedding
Type of time embedding to use, such as ‘gfp’ or ‘ds’.
- Type: str
time_embedding_complex
Indicates if the time embedding is complex-valued.
- Type: bool
temb_layers_global
Number of global time embedding layers.
- Type: int
temb_layers_local
Number of local time embedding layers.
- Type: int
temb_activation
Activation function for time embedding layers.
- Type: str
embed
Sequential container for time embedding layers.
- Type: nn.Sequential
encoders
List of encoder blocks.
- Type: nn.ModuleList
decoders
List of decoder blocks.
- Type: nn.ModuleList
output_layer
Output layer of the network.
- Type: nn.Module
Parameters:
- dcunet_architecture (str) – The architecture of the DCUNet to use.
- dcunet_time_embedding (str) – The type of time embedding to use.
- dcunet_temb_layers_global (int) – Number of global time embedding layers.
- dcunet_temb_layers_local (int) – Number of local time embedding layers.
- dcunet_temb_activation (str) – Activation function for time embedding.
- dcunet_time_embedding_complex (bool) – Whether to use complex time embedding.
- dcunet_fix_length (str) – Method to handle input length (‘pad’, ‘trim’, or ‘none’).
- dcunet_mask_bound (str) – Mask bounding strategy (‘none’ or others).
- dcunet_norm_type (str) – Normalization type (‘bN’ or ‘CbN’).
- dcunet_activation (str) – Activation function for the network.
- embed_dim (int) – Dimension of the embedding layer.
- **kwargs – Additional keyword arguments for customization.
Returns: Output tensor of the model.
Return type: Tensor
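All options above have usable defaults, so `DCUNet()` works out of the box; the following sketch simply spells out the documented defaults, any subset of which may be overridden:

```python
from espnet2.enh.layers.dcunet import DCUNet

# Explicit instantiation with the documented default values.
net = DCUNet(
    dcunet_architecture="DilDCUNet-v2",
    dcunet_time_embedding="gfp",
    dcunet_temb_layers_global=2,
    dcunet_temb_layers_local=1,
    dcunet_temb_activation="silu",
    dcunet_time_embedding_complex=False,
    dcunet_fix_length="pad",
    dcunet_mask_bound="none",
    dcunet_norm_type="bN",
    dcunet_activation="relu",
    embed_dim=128,
)
```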
Examples:
>>> net = DCUNet()
>>> dnn_input = torch.randn(4, 2, 257, 256) + 1j * torch.randn(4, 2, 257, 256)
>>> score = net(dnn_input, torch.randn(4))
>>> print(score.shape)
torch.Size([4, 1, 257, 256])
NOTE: Input shape is expected to be (batch, channels, nfreqs, time), where nfreqs - 1 must be divisible by the product of the frequency strides of the encoders, and time - 1 must be divisible by the product of their time strides.
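To make the constraint concrete, here is a sketch of building a compatible complex input from waveforms; the specific values (n_fft = 512, hop = 128, signal length 32640) are illustrative assumptions chosen so that nfreqs - 1 = 256 and the frame count matches the example above:

```python
import torch

# n_fft = 512 gives n_fft // 2 + 1 = 257 frequency bins, so nfreqs - 1 = 256
# is divisible by any power-of-two product of encoder frequency strides.
n_fft, hop = 512, 128
wav = torch.randn(4, 32640)  # (batch, samples); length chosen so frames = 256
window = torch.hann_window(n_fft)
spec = torch.stft(wav, n_fft, hop_length=hop, window=window, return_complex=True)
print(spec.shape)  # torch.Size([4, 257, 256])

# DCUNet expects (batch, channels, freq, time); with input_channels = 2, two
# complex spectrograms (e.g. noisy input and current estimate) are stacked.
dnn_input = torch.stack([spec, spec], dim=1)  # torch.Size([4, 2, 257, 256])
```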
- Raises: NotImplementedError – If the mask bounding method is not implemented.
fix_input_dims(x)
Adjust the input dimensions to be compatible with DCUNet.
This method pads or trims the input tensor x so that its shape is compatible with the DCUNet architecture. Specifically, it checks that freq - 1 and time - 1 are divisible by the products of the encoders' frequency and time strides, respectively. An incompatible frequency dimension cannot be fixed and raises a TypeError, while an incompatible time dimension is padded or trimmed according to the fix_length_mode attribute.
- Parameters: x (Tensor) – The input tensor of shape (batch, channels, freq, time).
- Returns: The adjusted input tensor with compatible dimensions.
- Return type: Tensor
- Raises:
- TypeError – If the input shape is not compatible with the expected frequency and time dimensions.
- ValueError – If an unknown fix_length_mode is specified.
Examples:
>>> import torch
>>> dcu_net = DCUNet()
>>> input_tensor = torch.randn(4, 2, 257, 250)  # Time axis may need padding
>>> adjusted_tensor = dcu_net.fix_input_dims(input_tensor)
>>> print(adjusted_tensor.shape)  # Time axis padded to the next compatible length
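The pad/trim behavior can be pictured with this minimal sketch; fix_dims_sketch and time_prod are hypothetical names, with time_prod standing in for the product of the encoders' time strides:

```python
import torch
import torch.nn.functional as F

def fix_dims_sketch(x: torch.Tensor, time_prod: int, mode: str = "pad") -> torch.Tensor:
    # Make (time - 1) divisible by `time_prod` by adjusting the last axis.
    remainder = (x.shape[-1] - 1) % time_prod
    if remainder == 0:
        return x
    if mode == "pad":
        # Zero-pad the time axis up to the next compatible length.
        return F.pad(x, [0, time_prod - remainder])
    if mode == "trim":
        # Drop trailing frames down to the previous compatible length.
        return x[..., : x.shape[-1] - remainder]
    raise ValueError(f"Unknown fix_length_mode: {mode}")
```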
fix_output_dims(out, x)
Adjusts the output dimensions to match the original input shape.
This method fixes the shape of the output tensor out to the original shape of the input tensor x by padding or cropping, based on the specified length mode. It ensures that the output dimensions are compatible with the expected output shape for further processing.
- Parameters:
- out (Tensor) – The output tensor from the model; its time dimension may differ from that of the input tensor.
- x (Tensor) – The original input tensor whose shape will be used to adjust the output tensor.
- Returns: The adjusted output tensor, with its shape modified to match the original input tensor’s shape.
- Return type: Tensor
NOTE: The method pads the output along the time axis if it is shorter than the input and crops it if it is longer. The specific behavior is controlled by the fix_length_mode attribute of the class.
Examples:
>>> net = DCUNet()
>>> input_tensor = torch.randn(4, 2, 257, 256)
>>> output_tensor = torch.randn(4, 2, 257, 250)
>>> fixed_output = net.fix_output_dims(output_tensor, input_tensor)
>>> fixed_output.shape
torch.Size([4, 2, 257, 256]) # Shape matches the input tensor
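Conceptually this is a single pad (or crop) along the last axis; a minimal sketch, where fix_output_sketch is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def fix_output_sketch(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Pad (diff > 0) or crop (diff < 0) the time axis to the input's length;
    # F.pad accepts negative amounts, which crop instead of pad.
    diff = x.shape[-1] - out.shape[-1]
    return F.pad(out, [0, diff])
```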
forward(spec, t) → Tensor
Process input through the DCUNet architecture.
The input shape is expected to be (batch, channels, nfreqs, time), where nfreqs - 1 must be divisible by the product of the frequency strides of the encoders, and time - 1 must be divisible by the product of their time strides.
- Parameters:
- spec (Tensor) – A complex spectrogram tensor of shape (batch, channels, nfreqs, time).
- t (Tensor) – Conditioning input for the time embedding, with one value per batch element, i.e. shape (batch,).
- Returns: The output tensor, a complex spectrogram of shape (batch, 1, nfreqs, time).
- Return type: Tensor
Examples:
>>> net = DCUNet()
>>> dnn_input = torch.randn(4, 2, 257, 256) + 1j * torch.randn(4, 2, 257, 256)
>>> score = net(dnn_input, torch.randn(4))
>>> print(score.shape)
torch.Size([4, 1, 257, 256])
NOTE: The input tensor must have a shape that complies with the model's stride requirements; if the shape is invalid, a TypeError is raised.
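The shape check can be exercised directly; this sketch assumes the default architecture, whose frequency stride product cannot divide 257 (a prime), so the call must fail:

```python
import torch
from espnet2.enh.layers.dcunet import DCUNet

net = DCUNet()
# 258 - 1 = 257 is prime, so no stride product > 1 divides it.
bad_input = torch.randn(4, 2, 258, 256) + 1j * torch.randn(4, 2, 258, 256)
try:
    net(bad_input, torch.randn(4))
except TypeError as err:
    print(err)  # Reports the incompatible frequency dimension
```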