espnet2.enh.layers.ncsnpp_utils.normalization.get_normalization
espnet2.enh.layers.ncsnpp_utils.normalization.get_normalization(config, conditional=False)
Obtain normalization modules from the config file.
This function looks up the normalization layer named in the provided configuration and returns it without instantiating it. Both conditional and non-conditional normalization layers are supported; the conditional flag selects between them.
- Parameters:
  - config (dict) – A configuration dictionary containing model parameters. It must specify the normalization type as the ‘normalization’ entry of its ‘model’ section, as in the examples below.
  - conditional (bool) – Whether to return a conditional normalization layer. Defaults to False.
- Returns: The uninstantiated normalization layer class, or a functools.partial wrapping a layer class (as in the conditional example below).
- Return type: callable
- Raises:
  - NotImplementedError – If conditional is True and the requested normalization type has no conditional implementation.
  - ValueError – If the specified normalization type is unknown.
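The dispatch described above can be illustrated without PyTorch. This is a minimal sketch under stated assumptions, not ESPnet's source: the stub classes are hypothetical stand-ins for the real layers, the nested config lookup follows this page's examples, and the ‘num_classes’ entry bound by the partial is an assumption made for illustration.

```python
import functools


# Hypothetical stubs standing in for the real normalization layers.
class InstanceNorm2dStub:
    def __init__(self, num_features):
        self.num_features = num_features


class ConditionalInstanceNorm2dPlusStub:
    def __init__(self, num_features, num_classes):
        self.num_features = num_features
        self.num_classes = num_classes


def get_normalization_sketch(config, conditional=False):
    """Return a layer factory (a class or a partial) named by the config."""
    norm = config["model"]["normalization"]
    if conditional:
        if norm == "InstanceNorm++":
            # Pre-bind the extra constructor argument so the caller only
            # supplies num_features later. 'num_classes' is an assumed key.
            return functools.partial(
                ConditionalInstanceNorm2dPlusStub,
                num_classes=config["model"]["num_classes"],
            )
        raise NotImplementedError(f"{norm} not implemented yet.")
    if norm == "InstanceNorm":
        # Non-conditional case: hand back the class itself, uninstantiated.
        return InstanceNorm2dStub
    # Other non-conditional types are omitted from this sketch.
    raise ValueError(f"Unknown normalization: {norm}")


config = {"model": {"normalization": "InstanceNorm++", "num_classes": 10}}
factory = get_normalization_sketch(config, conditional=True)
layer = factory(64)  # the caller instantiates the layer later
print(type(layer).__name__, layer.num_features, layer.num_classes)
# prints: ConditionalInstanceNorm2dPlusStub 64 10
```

Returning a factory instead of an instance lets the caller choose the channel count at each place a layer is built, while the config fixes only the layer type.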
Examples
>>> config = {'model': {'normalization': 'InstanceNorm'}}
>>> normalization_layer = get_normalization(config)
>>> print(normalization_layer)  # Output: <class 'torch.nn.modules.instancenorm.InstanceNorm2d'>
>>> config = {'model': {'normalization': 'InstanceNorm++'}}
>>> normalization_layer = get_normalization(config, conditional=True)
>>> print(normalization_layer)  # Output: a functools.partial object wrapping ConditionalInstanceNorm2dPlus
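Because the two examples return different kinds of callables (a class and a functools.partial), code that inspects the result can unwrap the partial through its standard func and keywords attributes. The helper below is hypothetical and not part of ESPnet, and the stub class stands in for a real layer:

```python
import functools


def describe_factory(factory):
    """Report which class a factory builds and any pre-bound keywords.

    Handles both return kinds: a bare class, or a functools.partial
    wrapping a class.
    """
    if isinstance(factory, functools.partial):
        return factory.func.__name__, dict(factory.keywords)
    return factory.__name__, {}


class ConditionalNormStub:  # hypothetical stand-in for a real layer class
    def __init__(self, num_features, num_classes):
        self.num_features = num_features
        self.num_classes = num_classes


print(describe_factory(ConditionalNormStub))
# prints: ('ConditionalNormStub', {})
print(describe_factory(functools.partial(ConditionalNormStub, num_classes=10)))
# prints: ('ConditionalNormStub', {'num_classes': 10})
```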