espnet2.enh.layers.dcunet.ComplexBatchNorm
class espnet2.enh.layers.dcunet.ComplexBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
Bases: Module
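Complex Batch Normalization layer.
This layer normalizes complex-valued inputs channel by channel. It does not normalize the real and imaginary parts independently. Instead, it treats each channel as a two-dimensional (real, imaginary) variable and performs three steps:
- It subtracts the mean.
- It whitens the result with the inverse square root of the 2 × 2 covariance matrix of the real and imaginary parts.
- If affine=True, it applies a learnable 2 × 2 weight matrix and a complex bias.
As with real-valued batch normalization, this can improve training stability and convergence in models that process complex data.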
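The whitening step can be illustrated for a single channel with a framework-free sketch. It assumes the formulation implied by the Wrr/Wri/Wii and RVrr/RVri/RVii attributes, namely a 2 × 2 covariance over the real and imaginary parts. It uses plain Python complex numbers instead of torch tensors, and the function name whiten_channel is hypothetical, not part of ESPnet.

```python
import math


def whiten_channel(values, eps=1e-5):
    """Hypothetical helper: center and whiten the complex samples of one channel.

    Each complex number is treated as a 2-D vector (real, imag). The function
    computes the population covariance V = [[Vrr, Vri], [Vri, Vii]], adds eps
    to the diagonal, and multiplies every centered sample by U = V^(-1/2).
    """
    n = len(values)
    mean_r = sum(v.real for v in values) / n
    mean_i = sum(v.imag for v in values) / n
    xr = [v.real - mean_r for v in values]
    xi = [v.imag - mean_i for v in values]

    vrr = sum(a * a for a in xr) / n + eps
    vii = sum(b * b for b in xi) / n + eps
    vri = sum(a * b for a, b in zip(xr, xi)) / n

    # Closed-form inverse square root of a symmetric positive-definite 2x2 matrix:
    # sqrt(V) = (V + s*I) / t with s = sqrt(det V) and t = sqrt(trace V + 2s),
    # hence V^(-1/2) = (adj(V) + s*I) / (s * t).
    s = math.sqrt(vrr * vii - vri * vri)
    t = math.sqrt(vrr + vii + 2 * s)
    urr = (vii + s) / (s * t)
    uii = (vrr + s) / (s * t)
    uri = -vri / (s * t)

    # Apply U to each centered (real, imag) pair.
    return [complex(urr * a + uri * b, uri * a + uii * b) for a, b in zip(xr, xi)]
```

After this transform, the samples of a channel have approximately zero mean, unit variance in both components, and zero real/imaginary correlation. Independent per-part normalization would not remove that correlation.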
num_features
Number of features (channels) in the input.
- Type: int
eps
A small value added for numerical stability during division.
- Type: float
momentum
The momentum for the moving average of the running statistics.
- Type: float
affine
If True, this layer has learnable affine parameters.
- Type: bool
track_running_stats
If True, this layer tracks running statistics (mean and covariance) during training.
- Type: bool
Parameters:
- num_features (int) – Number of features (channels) in the input.
- eps (float, optional) – A small value added for numerical stability (default: 1e-5).
- momentum (float, optional) – Momentum for running statistics (default: 0.1).
- affine (bool, optional) – If True, add learnable parameters to the layer (default: True).
- track_running_stats (bool, optional) – If True, track running statistics (default: False).
Raises: AssertionError – If the real and imaginary parts of the input differ in shape, or if the size of dimension 1 does not equal num_features.
Example
>>> import torch
>>> batch_norm = ComplexBatchNorm(num_features=64)
>>> input_tensor = torch.randn(32, 64, 128) + 1j * torch.randn(32, 64, 128)
>>> output_tensor = batch_norm(input_tensor)
NOTE The layer expects a single complex-valued tensor (e.g. torch.complex64) of shape (batch, num_features, ...), not separate real and imaginary tensors. The output is a complex tensor of the same shape.
extra_repr()
Returns the extra representation string that torch.nn.Module includes when the layer is printed. The string lists num_features followed by the eps, momentum, affine and track_running_stats settings.
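The sketch below makes the printed form of extra_repr() concrete. It assumes extra_repr() renders the five constructor arguments in order as num_features, eps=..., momentum=..., affine=..., track_running_stats=..., which is the same convention used by torch.nn.BatchNorm1d. format_extra_repr is a hypothetical helper, not part of ESPnet.

```python
def format_extra_repr(num_features, eps=1e-5, momentum=0.1, affine=True,
                      track_running_stats=False):
    """Hypothetical stand-in for ComplexBatchNorm.extra_repr(): render the
    constructor settings in the comma-separated form used by torch.nn layers."""
    return (
        f"{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
        f"track_running_stats={track_running_stats}"
    )


# For a module without submodules, torch.nn.Module.__repr__ wraps the string
# as ClassName(<extra_repr>).
print(f"ComplexBatchNorm({format_extra_repr(64)})")
```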
forward(x)
Normalizes a complex-valued input tensor.
Statistics are computed per channel (dimension 1), pooling the batch dimension and all trailing dimensions. The method then processes the input in three steps:
- It centers the input.
- It whitens the input with the inverse square root of the 2 × 2 covariance matrix of its real and imaginary parts, with eps added to the diagonal.
- If affine=True, it transforms the result with the learnable weights Wrr, Wri and Wii and shifts it by the biases Br and Bi.
- Parameters: x (Tensor) – Complex tensor of shape (batch, num_features, ...).
- Returns: The normalized complex tensor, with the same shape as x.
- Return type: Tensor
- Raises: AssertionError – If the real and imaginary parts of x differ in shape, or if x.size(1) does not equal num_features.
Example
>>> import torch
>>> layer = ComplexBatchNorm(num_features=64)
>>> x = torch.randn(10, 64, 32, 32) + 1j * torch.randn(10, 64, 32, 32)
>>> layer(x).shape
torch.Size([10, 64, 32, 32])
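The per-channel pooling can be illustrated with a framework-free sketch. It assumes that dimension 1 is the channel axis and that every other axis is pooled when the statistics are computed. Inputs are given as nested Python lists of shape (batch, channels, time) instead of torch tensors. channel_statistics is a hypothetical helper that returns, for each channel, the mean and the entries Vrr, Vri and Vii of the 2 × 2 covariance before eps is added.

```python
def channel_statistics(x):
    """Hypothetical helper: per-channel mean and 2x2 covariance of a complex
    tensor given as nested lists of shape (batch, channels, time).

    The statistics for channel c pool every x[b][c][t], so they reduce over
    the batch axis and all trailing axes and leave one set per channel.
    """
    num_channels = len(x[0])
    stats = []
    for c in range(num_channels):
        samples = [v for item in x for v in item[c]]
        n = len(samples)
        mr = sum(v.real for v in samples) / n
        mi = sum(v.imag for v in samples) / n
        vrr = sum((v.real - mr) ** 2 for v in samples) / n
        vii = sum((v.imag - mi) ** 2 for v in samples) / n
        vri = sum((v.real - mr) * (v.imag - mi) for v in samples) / n
        stats.append({"mean": complex(mr, mi), "Vrr": vrr, "Vri": vri, "Vii": vii})
    return stats
```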
NOTE The layer can be used in both training and evaluation modes.
- In training mode, it normalizes the input with the current batch statistics and, if track_running_stats=True, updates the running statistics.
- In evaluation mode, it uses the running statistics only if track_running_stats=True. With the default track_running_stats=False, batch statistics are used in both modes.
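The mode-dependent choice of statistics and the running-average update can be sketched in plain Python. The sketch assumes the update has the form running + momentum * (batch - running). This is what torch.Tensor.lerp_ computes with weight=momentum, and it is equivalent to the update used by torch.nn.BatchNorm. Both helper names are hypothetical.

```python
def update_running(running, batch_value, momentum=0.1):
    """Hypothetical helper: exponential moving average of one statistic.

    Equivalent to (1 - momentum) * running + momentum * batch_value.
    """
    return running + momentum * (batch_value - running)


def select_statistics(training, track_running_stats, batch_stats, running_stats):
    """Hypothetical helper: pick the statistics used for normalization.

    Batch statistics are used while training, and also in evaluation mode
    when no running statistics are tracked.
    """
    use_batch_stats = training or not track_running_stats
    return batch_stats if use_batch_stats else running_stats


# Simulate three training steps for the running mean of the real part (RMr).
rmr = 0.0
for batch_mean in [1.0, 1.0, 1.0]:
    rmr = update_running(rmr, batch_mean, momentum=0.1)
```

After three batches with mean 1.0, the running mean has only moved to about 0.271. A small momentum therefore makes the running statistics change slowly.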
Additional attributes
The affine weights and biases below are used only when affine=True. The running-statistics buffers are maintained only when track_running_stats=True. Each weight, bias, mean and covariance tensor holds one value per channel, i.e. it has shape (num_features,).
Wrr
Real-real entry of the symmetric 2 × 2 affine weight matrix [[Wrr, Wri], [Wri, Wii]].
- Type: torch.nn.Parameter
Wri
Off-diagonal (real-imaginary) entry of the affine weight matrix, shared by both off-diagonal positions.
- Type: torch.nn.Parameter
Wii
Imaginary-imaginary entry of the affine weight matrix.
- Type: torch.nn.Parameter
Br
Bias for the real part.
- Type: torch.nn.Parameter
Bi
Bias for the imaginary part.
- Type: torch.nn.Parameter
RMr
Running mean for the real part.
- Type: torch.Tensor
RMi
Running mean for the imaginary part.
- Type: torch.Tensor
RVrr
Running variance of the real part.
- Type: torch.Tensor
RVri
Running covariance between the real and imaginary parts.
- Type: torch.Tensor
RVii
Running variance of the imaginary part.
- Type: torch.Tensor
num_batches_tracked
Number of batches processed.
- Type: torch.Tensor
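The affine step can be sketched per element under one assumption: Wrr, Wri and Wii form the symmetric matrix [[Wrr, Wri], [Wri, Wii]] acting on the (real, imaginary) pair of each whitened value, with Br and Bi added afterwards. apply_affine is a hypothetical helper that operates on a Python complex number rather than a tensor.

```python
def apply_affine(z, wrr, wri, wii, br, bi):
    """Hypothetical helper: multiply the (real, imag) vector of a whitened
    sample z by the symmetric matrix [[wrr, wri], [wri, wii]], then add the
    complex bias br + 1j * bi."""
    yr = wrr * z.real + wri * z.imag + br
    yi = wri * z.real + wii * z.imag + bi
    return complex(yr, yi)
```

With Wrr = Wii = 1, Wri = 0 and zero biases, the step leaves the whitened value unchanged. A non-zero Wri mixes the real and imaginary components.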
NOTE The input tensor should have a shape of (batch_size, num_features, …).
- Raises: AssertionError – If the input dimensions are incorrect.
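reset_running_stats()
Resets the running statistics when track_running_stats=True:
- RMr, RMi and RVri are set to 0.
- RVrr and RVii are set to 1.
- num_batches_tracked is set to 0.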
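The sketch below shows what freshly reset statistics mean for evaluation-mode normalization. It assumes the reset restores zero means and an identity running covariance, the complex analogue of resetting running_var to 1 in torch.nn.BatchNorm1d. INITIAL_RUNNING_STATS and whitening_matrix are hypothetical names used only for illustration.

```python
import math

# Hypothetical values assumed to be restored by reset_running_stats():
# zero running means and an identity running covariance.
INITIAL_RUNNING_STATS = {"RMr": 0.0, "RMi": 0.0, "RVrr": 1.0, "RVri": 0.0, "RVii": 1.0}


def whitening_matrix(vrr, vri, vii, eps=1e-5):
    """Return (Urr, Uri, Uii), the entries of (V + eps*I)^(-1/2) for the
    symmetric 2x2 covariance V = [[vrr, vri], [vri, vii]]."""
    vrr, vii = vrr + eps, vii + eps
    s = math.sqrt(vrr * vii - vri * vri)
    t = math.sqrt(vrr + vii + 2 * s)
    return (vii + s) / (s * t), -vri / (s * t), (vrr + s) / (s * t)


stats = INITIAL_RUNNING_STATS
urr, uri, uii = whitening_matrix(stats["RVrr"], stats["RVri"], stats["RVii"])
# With reset statistics, normalization only rescales each component by
# 1 / sqrt(1 + eps), so it is almost the identity.
```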
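reset_parameters()
Resets the running statistics via reset_running_stats(). When affine=True, it also re-initializes the affine parameters:
- Br and Bi are set to 0.
- Wrr and Wii are set to 1.
- Wri is drawn uniformly from [-0.9, 0.9], so the initial weight matrix is positive-definite.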
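Why the initial weight matrix stays positive-definite can be checked numerically. The check assumes the initialization Wrr = Wii = 1 and Wri drawn uniformly from [-0.9, 0.9]. Under that assumption, the matrix [[1, Wri], [Wri, 1]] has determinant 1 - Wri² ≥ 0.19. Both helpers are hypothetical and use Python's random module instead of torch.

```python
import random


def init_affine_weights(rng):
    """Hypothetical helper mirroring the assumed initialization:
    Wrr = Wii = 1, Wri ~ U(-0.9, 0.9), zero biases."""
    return {"Wrr": 1.0, "Wri": rng.uniform(-0.9, 0.9), "Wii": 1.0, "Br": 0.0, "Bi": 0.0}


def is_positive_definite(wrr, wri, wii):
    # A symmetric 2x2 matrix is positive-definite iff its leading entry
    # and its determinant are both positive.
    return wrr > 0 and wrr * wii - wri * wri > 0
```

Keeping |Wri| strictly below 1 is what rules out a singular or indefinite starting matrix. With |Wri| = 1, the determinant would reach 0.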