espnet2.enh.layers.dcunet.BatchNorm
class espnet2.enh.layers.dcunet.BatchNorm(num_features: int, eps: float = 1e-05, momentum: float | None = 0.1, affine: bool = True, track_running_stats: bool = True, device=None, dtype=None)
Bases: _BatchNorm
Batch Normalization layer for complex-valued inputs.
This class extends PyTorch's standard _BatchNorm to handle complex input tensors. It validates that the input is 3D or 4D before applying batch normalization, raising an error otherwise.
num_features
Number of features in the input.
- Type: int
eps
A value added to the denominator for numerical stability.
- Type: float
momentum
The value used for the running_mean and running_var computation.
- Type: float or None
affine
If True, this module has learnable parameters.
- Type: bool
track_running_stats
If True, tracks running statistics.
- Type: bool
Parameters:
- num_features (int) – Number of features in the input.
- eps (float , optional) – A value added to the denominator for numerical stability (default: 1e-5).
- momentum (float , optional) – The value used for the running_mean and running_var computation (default: 0.1).
- affine (bool , optional) – If True, this module has learnable parameters (default: True).
- track_running_stats (bool , optional) – If True, tracks running statistics (default: True).
Raises: ValueError – If the input tensor is not 3D or 4D.
Examples
>>> batch_norm = BatchNorm(num_features=64)
>>> input_tensor = torch.randn(10, 64, 32, 32) # 4D input
>>> output_tensor = batch_norm(input_tensor)
NOTE
The input tensor must have a shape of (N, C, H, W) or (N, C, L) where N is the batch size, C is the number of channels, H and W are the height and width of the input for 4D tensors, or L is the length for 3D tensors.
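The shape requirement above can be illustrated with a minimal, torch-free sketch. The helper `check_input_dim` below is hypothetical (it mirrors the 3D/4D validation described in this page, not ESPnet's actual implementation), and it operates on plain shape tuples for simplicity:

```python
def check_input_dim(shape):
    """Raise ValueError unless the shape is 3D (N, C, L) or 4D (N, C, H, W).

    Hypothetical stand-in for the dimensionality check this BatchNorm
    layer performs on its input tensor.
    """
    if len(shape) not in (3, 4):
        raise ValueError(
            f"expected 3D or 4D input, got {len(shape)}D (shape {shape})"
        )


# A 4D image-like input (N, C, H, W) passes silently.
check_input_dim((10, 64, 32, 32))

# A 3D sequence-like input (N, C, L) also passes.
check_input_dim((10, 64, 256))

# A 2D input is rejected, matching the ValueError described above.
try:
    check_input_dim((10, 64))
except ValueError as e:
    print(e)  # expected 3D or 4D input, got 2D (shape (10, 64))
```

The same check runs inside the layer's forward pass, so shape errors surface before any normalization statistics are computed.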
Initialize internal Module state, shared by both nn.Module and ScriptModule.