espnet2.gan_codec.shared.encoder.seanet.get_norm_module
espnet2.gan_codec.shared.encoder.seanet.get_norm_module(module: Module, causal: bool = False, norm: str = 'none', **norm_kwargs) → Module
Return the appropriate normalization module for a convolution. If causal is True, the returned module is guaranteed to be causal; if the requested normalization does not support causal evaluation, an error is raised instead.
- Parameters:
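- module (nn.Module) – The convolution module whose output will be normalized. For “layer_norm” and “time_group_norm”, its out_channels determines the size of the returned normalization layer.
- causal (bool) – Whether the returned module must be causal, i.e. its output at a given time step may depend only on current and past inputs. Defaults to False.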
- norm (str) – The type of normalization to apply. Must be one of the values in CONV_NORMALIZATIONS (defaults to “none”):
- “none”: No normalization.
- “weight_norm”: Weight normalization.
- “spectral_norm”: Spectral normalization.
- “time_layer_norm”: Time layer normalization.
- “layer_norm”: Layer normalization.
- “time_group_norm”: Time group normalization.
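- **norm_kwargs – Additional keyword arguments forwarded to the constructor of the returned normalization layer.
- Returns: A normalization module to apply to the convolution's output: ConvLayerNorm for “layer_norm”, nn.GroupNorm for “time_group_norm”, and nn.Identity for the other options. The input module itself is not modified.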
- Return type: nn.Module
- Raises:
- AssertionError – If the specified normalization type is not in CONV_NORMALIZATIONS.
- ValueError – If causal is True and norm is “time_group_norm”, because GroupNorm does not support causal evaluation.
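The dispatch rules described above can be summarized in a small, dependency-free sketch. `pick_norm_layer` is a hypothetical helper, not part of ESPnet: it returns the *name* of the layer instead of constructing it (so it runs without PyTorch), and its fallback to “Identity” for the remaining options is an assumption about the implementation rather than something stated on this page.

```python
# Hypothetical stand-in for the documented dispatch rules of get_norm_module.
# It returns layer *names* instead of nn.Module instances, so no torch needed.
CONV_NORMALIZATIONS = frozenset(
    {"none", "weight_norm", "spectral_norm",
     "time_layer_norm", "layer_norm", "time_group_norm"}
)


def pick_norm_layer(norm: str = "none", causal: bool = False) -> str:
    """Return the name of the normalization layer chosen for `norm`."""
    # Unknown names fail like the documented AssertionError.
    assert norm in CONV_NORMALIZATIONS, f"unknown normalization: {norm!r}"
    if norm == "layer_norm":
        # Normalizes each time step over channels, so causal mode is fine.
        return "ConvLayerNorm"
    if norm == "time_group_norm":
        # Statistics span the whole time axis, so causal mode is rejected.
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        return "GroupNorm"
    # Assumption: the remaining options (e.g. weight/spectral norm, which
    # reparametrize the conv weights) add no separate layer here.
    return "Identity"


if __name__ == "__main__":
    for name in sorted(CONV_NORMALIZATIONS):
        for causal in (False, True):
            try:
                result = pick_norm_layer(name, causal)
            except ValueError as err:
                result = f"ValueError: {err}"
            print(f"{name:16} causal={causal!s:5} -> {result}")
```

Only “time_group_norm” combined with causal=True raises; every other valid combination yields a layer name.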
Examples
>>> import torch.nn as nn
>>> from espnet2.gan_codec.shared.encoder.seanet import ConvLayerNorm, get_norm_module
>>> conv_module = nn.Conv1d(16, 33, kernel_size=3)
>>> norm_module = get_norm_module(conv_module, causal=False, norm="layer_norm")
>>> assert isinstance(norm_module, ConvLayerNorm)
>>> norm_module = get_norm_module(conv_module, causal=True, norm="layer_norm")
>>> assert isinstance(norm_module, ConvLayerNorm)
>>> norm_module = get_norm_module(conv_module, causal=True, norm="time_group_norm")
Traceback (most recent call last):
    ...
ValueError: GroupNorm doesn't support causal evaluation.
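Why does “time_group_norm” reject causal=True while “layer_norm” accepts it? nn.GroupNorm computes its mean and variance over the whole time axis (for each group of channels), so the normalized value at an early frame depends on later frames. The stdlib-only sketch below contrasts this with per-frame normalization over channels, which is what we assume ConvLayerNorm does. Both functions are hypothetical illustrations (single group, learned affine parameters omitted), not ESPnet code; frames are laid out as a list of time steps, each a list of channel values.

```python
from __future__ import annotations

import statistics

EPS = 1e-5


def group_norm_one_group(frames: list[list[float]]) -> list[list[float]]:
    """GroupNorm-style with one group: statistics pooled over all frames."""
    values = [v for frame in frames for v in frame]  # includes future frames
    mean = statistics.fmean(values)
    scale = (statistics.pvariance(values, mu=mean) + EPS) ** 0.5
    return [[(v - mean) / scale for v in frame] for frame in frames]


def per_frame_layer_norm(frames: list[list[float]]) -> list[list[float]]:
    """LayerNorm over channels: each frame uses only its own values."""
    out = []
    for frame in frames:
        mean = statistics.fmean(frame)
        scale = (statistics.pvariance(frame, mu=mean) + EPS) ** 0.5
        out.append([(v - mean) / scale for v in frame])
    return out


if __name__ == "__main__":
    frames = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    # Change only the last (i.e. "future") frame.
    changed = frames[:-1] + [[50.0, 60.0]]

    # First-frame output changes under GroupNorm-style pooling: not causal.
    print(group_norm_one_group(frames)[0], group_norm_one_group(changed)[0])
    # First-frame output is unchanged under per-frame normalization: causal.
    print(per_frame_layer_norm(frames)[0], per_frame_layer_norm(changed)[0])
```

Changing only the last frame alters the first output of the pooled version but leaves the per-frame version's earlier outputs untouched, which is the property a causal (streaming) convolution stack relies on.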