espnet2.gan_codec.shared.encoder.seanet_2d.get_activation
espnet2.gan_codec.shared.encoder.seanet_2d.get_activation(activation: str | None = None, channels=None, **kwargs)
Get the specified activation function.
This function returns an activation function as a PyTorch module (nn.Module), selected by name. It also supports custom activation functions such as ‘snake’, which requires the number of channels to be specified.
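The dispatch described above can be sketched as follows. This is a minimal illustration under stated assumptions, not ESPnet's implementation. `SnakeSketch` and `get_activation_sketch` are hypothetical names. The Snake module uses the commonly cited form x + sin²(αx)/α with one learnable α per channel, which may differ from ESPnet's own Snake class in initialization, numerical guard, and shape handling.

```python
import torch
from torch import nn


class SnakeSketch(nn.Module):
    """Hypothetical Snake activation (assumed form): x + sin^2(alpha * x) / alpha."""

    def __init__(self, channels: int):
        super().__init__()
        # One learnable alpha per channel (assumed initialization: all ones).
        self.alpha = nn.Parameter(torch.ones(channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reshape alpha to (1, C, 1, ..., 1) so it broadcasts over (B, C, ...) inputs.
        alpha = self.alpha.view(1, -1, *([1] * (x.dim() - 2)))
        # Small epsilon (assumed) guards against division by zero.
        return x + torch.sin(alpha * x).pow(2) / (alpha + 1e-9)


def get_activation_sketch(activation: str, channels=None, **kwargs) -> nn.Module:
    """Hypothetical re-implementation of the documented dispatch."""
    if activation == "snake":
        # Snake has per-channel parameters, so the channel count is mandatory.
        assert channels is not None, "'snake' requires `channels`"
        return SnakeSketch(channels)
    # Every other name is looked up on torch.nn and instantiated with **kwargs.
    return getattr(nn, activation)(**kwargs)


snake = get_activation_sketch("snake", channels=8)
out = snake(torch.randn(2, 8, 16, 16))  # works for 2D feature maps (B, C, H, W)
```

Calling `get_activation_sketch("snake")` without `channels` fails the assertion, mirroring the documented AssertionError.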
Parameters:
- activation (str) – The name of the activation function to use. Either ‘snake’ or the name of an activation class in torch.nn, e.g. ‘ReLU’, ‘ELU’, or ‘LeakyReLU’.
- channels (Optional[int]) – The number of channels, required by activation functions that need it, such as ‘snake’.
- **kwargs (Any) – Additional keyword arguments for the activation function, e.g. negative_slope for ‘LeakyReLU’.
Returns: The corresponding activation function as a PyTorch module.
Return type: nn.Module
Raises: AssertionError – If ‘snake’ is specified without providing channels.
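Because the extra keyword arguments go to the activation, they must match what the target class accepts. The sketch below assumes the non-‘snake’ branch behaves like `getattr(torch.nn, name)(**kwargs)`, as the NOTE in this section describes; `build_act` is a hypothetical helper, not part of ESPnet. It also shows that the returned nn.Module composes with other layers, here a 2D convolution.

```python
import torch
from torch import nn


def build_act(activation: str, **kwargs) -> nn.Module:
    # Hypothetical stand-in for the non-'snake' branch: kwargs go straight to the constructor.
    return getattr(nn, activation)(**kwargs)


# Valid: LeakyReLU accepts negative_slope.
leaky = build_act("LeakyReLU", negative_slope=0.2)

# Invalid: ReLU has no negative_slope argument, so construction fails with TypeError.
relu_rejected = False
try:
    build_act("ReLU", negative_slope=0.2)
except TypeError as err:
    relu_rejected = True
    print("rejected:", err)

# The returned object is an ordinary nn.Module and stacks like any other layer.
block = nn.Sequential(nn.Conv2d(1, 8, kernel_size=3, padding=1), leaky)
y = block(torch.randn(2, 1, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
```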
Examples
>>> relu = get_activation('ReLU')
>>> snake_activation = get_activation('snake', channels=64)
NOTE
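The function uses getattr to look up the activation class in the torch.nn module by name, so any name other than ‘snake’ must exactly match a torch.nn class name (for example, ‘ReLU’ or ‘LeakyReLU’). An unrecognized name makes getattr fail with an AttributeError.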
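Since an invalid name only surfaces when getattr fails, it can help to check names up front. The sketch below is a hypothetical helper (`is_torch_activation` is not part of ESPnet or PyTorch). It accepts ‘snake’, which the documentation says is handled separately, and otherwise checks that the name resolves to an nn.Module subclass on torch.nn.

```python
from torch import nn


def is_torch_activation(name: str) -> bool:
    """Hypothetical pre-check for names passed to get_activation."""
    if name == "snake":
        # Handled by get_activation itself, per its documentation.
        return True
    # getattr with a default avoids the AttributeError for unknown names.
    cls = getattr(nn, name, None)
    # Only confirms the name is an nn.Module subclass; e.g. 'Linear' would also pass.
    return isinstance(cls, type) and issubclass(cls, nn.Module)


for candidate in ["ReLU", "SiLU", "snake", "NotAnActivation"]:
    print(candidate, is_torch_activation(candidate))
```

The check is deliberately loose: it catches typos and non-class attributes such as submodules, but it cannot tell an activation apart from any other layer class.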