espnet2.enh.layers.dcunet.OnReIm
class espnet2.enh.layers.dcunet.OnReIm(module_cls, *args, **kwargs)
Bases: Module
Module to apply a real-valued operation to the real and imaginary parts.
This class creates two separate instances of the same module class and applies one to the real part and the other to the imaginary part of a complex-valued input tensor. It is useful for real-valued operations that must be applied separately to the real and imaginary parts of complex numbers, with the results recombined into a complex representation.
re_module
Module applied to the real part of the input.
- Type: nn.Module
im_module
Module applied to the imaginary part of the input.
- Type: nn.Module
Parameters:
- module_cls (callable) – A class or function that returns a PyTorch module or functional. It is called twice with *args and **kwargs to create two separate modules, one for the real part and one for the imaginary part.
- *args – Variable length argument list passed to the module_cls.
- **kwargs – Arbitrary keyword arguments passed to the module_cls.
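The mechanism described above can be sketched in plain PyTorch. This is a minimal sketch, not the ESPnet source: the class name `OnReImSketch` is hypothetical, and recombining the two outputs with `torch.view_as_complex` is an assumption about how the real and imaginary results are merged.

```python
import torch
from torch import nn


class OnReImSketch(nn.Module):
    """Hypothetical re-implementation: one module instance per complex component."""

    def __init__(self, module_cls, *args, **kwargs):
        super().__init__()
        # module_cls is called twice with the same arguments, so the two
        # instances share their configuration but have independent parameters.
        self.re_module = module_cls(*args, **kwargs)
        self.im_module = module_cls(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        re = self.re_module(x.real)  # real part -> first module
        im = self.im_module(x.imag)  # imaginary part -> second module
        # Stack as a trailing size-2 dimension and reinterpret it as complex.
        return torch.view_as_complex(torch.stack([re, im], dim=-1))


torch.manual_seed(0)
layer = OnReImSketch(nn.Linear, 10, 5)
x = torch.randn(2, 10) + 1j * torch.randn(2, 10)
y = layer(x)
print(y.shape, y.dtype)
```

Passing the class plus its constructor arguments, rather than two ready-made modules, keeps both branches identically configured while giving each its own weights.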
####### Examples
>>> import torch
>>> from torch import nn
>>> from espnet2.enh.layers.dcunet import OnReIm
>>> # nn.Linear(10, 5) is instantiated twice: once for the real part, once for the imaginary part
>>> model = OnReIm(nn.Linear, 10, 5)
>>> complex_input = torch.randn(2, 10) + 1j * torch.randn(2, 10)
>>> output = model(complex_input)
>>> output.shape
torch.Size([2, 5])
>>> output.is_complex()
True
- Raises: RuntimeError – If the input tensor is not complex-valued, since PyTorch does not define `.imag` for real tensors.
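forward(x)
Apply `re_module` to the real part and `im_module` to the imaginary part of `x`, and combine the two outputs into a complex-valued tensor.

espnet2.enh.layers.dcunet.DCUNet
DCUNet model for complex spectrogram processing.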
This implementation is based on the architecture described in the paper “Speech Enhancement with Score-Based Generative Models in the Complex STFT Domain”. It operates on complex-valued STFT spectrograms and is conditioned on the diffusion time step through a time embedding.
- Parameters:
- dcunet_architecture (str) – The architecture to use for the DCUNet model. Default is “DilDCUNet-v2”.
- dcunet_time_embedding (str) – The type of time embedding to use. Options are “gfp” for Gaussian Fourier Projection or “ds” for Diffusion Step embedding. Default is “gfp”.
- dcunet_temb_layers_global (int) – Number of global time embedding layers. Default is 2.
- dcunet_temb_layers_local (int) – Number of local time embedding layers. Default is 1.
- dcunet_temb_activation (str) – Activation function for time embedding layers. Default is “silu”.
- dcunet_time_embedding_complex (bool) – Whether to use complex time embedding. Default is False.
- dcunet_fix_length (str) – How to handle inputs whose length does not fit the encoder strides. Options are “pad” or “trim”, which pad or trim the input to a valid length. Default is “pad”.
- dcunet_mask_bound (str) – Mask bounding option. Default is “none”.
- dcunet_norm_type (str) – Type of normalization to apply. Default is “bN”.
- dcunet_activation (str) – Activation function for the model. Default is “relu”.
- embed_dim (int) – Dimensionality of the embedding. Default is 128.
- Returns: None
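The two `dcunet_time_embedding` options can be illustrated as follows. This is a hedged sketch: the function names, the `scale` and `max_period` values, and the exact frequency layout are assumptions modeled on the Gaussian Fourier projection and sinusoidal step embeddings common in score-based diffusion models, not ESPnet's code.

```python
import math

import torch


def gaussian_fourier_projection(t, embed_dim=128, scale=16.0, seed=0):
    """Hypothetical 'gfp' sketch: fixed random frequencies W ~ N(0, scale^2)."""
    gen = torch.Generator().manual_seed(seed)
    w = torch.randn(embed_dim // 2, generator=gen) * scale  # sampled once, not trained
    proj = t[:, None] * w[None, :] * 2 * math.pi
    return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)


def diffusion_step_embedding(t, embed_dim=128, max_period=10000.0):
    """Hypothetical 'ds' sketch: deterministic geometric frequencies."""
    half = embed_dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


t = torch.rand(4)  # one diffusion time per batch element
print(gaussian_fourier_projection(t).shape)
print(diffusion_step_embedding(t).shape)
```

Both map a scalar time per batch element to an `embed_dim`-sized vector; they differ in whether the frequencies are random (“gfp”) or a fixed geometric sequence (“ds”).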
####### Examples
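>>> import torch
>>> from espnet2.enh.layers.dcunet import DCUNet
>>> model = DCUNet()
>>> input_tensor = torch.randn(4, 2, 257, 256) + 1j * torch.randn(4, 2, 257, 256)
>>> # second argument: one diffusion time step per batch element
>>> output = model(input_tensor, torch.rand(4))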
>>> print(output.shape)
torch.Size([4, 1, 257, 256])
NOTE
This model expects the last two dimensions of the input tensor to be (nfreqs, time), where nfreqs - 1 is divisible by the product of the encoders' frequency strides and time - 1 is divisible by the product of the encoders' time strides.
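The shape constraint can be checked, and the time axis padded or trimmed, with simple integer arithmetic. This is a hypothetical sketch: the stride values are made up, and `fix_lengths` only mimics what the “pad” and “trim” options are described as doing; it is not ESPnet's implementation.

```python
from math import prod


def fix_lengths(nfreqs, time, freq_strides, time_strides, mode="pad"):
    """Return a valid time length (hypothetical helper, not the ESPnet API)."""
    f_prod, t_prod = prod(freq_strides), prod(time_strides)
    # In this sketch the frequency axis is only checked, never padded.
    if (nfreqs - 1) % f_prod:
        raise TypeError(f"nfreqs - 1 = {nfreqs - 1} is not divisible by {f_prod}")
    remainder = (time - 1) % t_prod
    if remainder == 0:
        return time
    if mode == "pad":
        return time + (t_prod - remainder)  # round time - 1 up to a multiple
    if mode == "trim":
        return time - remainder  # round time - 1 down to a multiple
    raise ValueError(f"unknown mode: {mode}")


# Hypothetical encoder strides: 257 - 1 = 256 is divisible by 2 * 2 * 2 * 2.
freq_strides, time_strides = [2, 2, 2, 2], [2, 2, 1, 1]
print(fix_lengths(257, 256, freq_strides, time_strides, "pad"))   # 257
print(fix_lengths(257, 256, freq_strides, time_strides, "trim"))  # 253
```

With a time-stride product of 4, a length of 256 gives time - 1 = 255, which leaves a remainder of 3: padding adds one frame (257), trimming removes three (253).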
- Raises:
- TypeError – If the input tensor's dimensions are incompatible with the encoder strides.
- NotImplementedError – If an unsupported architecture or mask bounding is specified.