espnet2.asr_transducer.activation.Swish
class espnet2.asr_transducer.activation.Swish(beta: float = 1.0, use_builtin: bool = False)
Bases: Module
Swish activation definition.
Swish(x) = (beta * x) * sigmoid(x), where beta = 1 defines standard Swish activation.
References
- https://arxiv.org/abs/2108.12943
- https://arxiv.org/abs/1710.05941
- E-swish variant: https://arxiv.org/abs/1801.07145
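As a quick numeric check of the definition above (plain PyTorch, nothing ESPnet-specific; beta = 2 is an arbitrary E-Swish setting chosen for illustration):

```python
import torch

x = torch.tensor([0.0, 1.0, 2.0])
beta = 2.0

# E-Swish per the definition above: (beta * x) * sigmoid(x).
y = (beta * x) * torch.sigmoid(x)
print(y)  # tensor([0.0000, 1.4621, 3.5232])

# With beta = 1 this reduces to standard Swish, i.e. PyTorch's SiLU.
print(torch.allclose(x * torch.sigmoid(x), torch.nn.functional.silu(x)))  # True
```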
beta
Beta parameter for E-Swish. Should be greater than or equal to 1.
- Type: float
- If beta < 1, standard Swish is used.
use_builtin
Whether to use the built-in PyTorch function if available.
- Type: bool
Parameters:
- beta – Beta parameter for E-Swish activation. (beta >= 1)
- use_builtin – If True, utilize the built-in PyTorch SiLU function if available.
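Taken together, these parameters imply a simple dispatch: E-Swish when beta > 1, otherwise standard Swish, computed either manually or via the built-in SiLU. Below is a minimal sketch consistent with that documented behavior, not necessarily the exact ESPnet implementation:

```python
import torch


class SwishSketch(torch.nn.Module):
    """Minimal sketch of the documented Swish / E-Swish behavior."""

    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()
        self.beta = beta

        if beta > 1:
            # E-Swish: (beta * x) * sigmoid(x).
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        elif use_builtin:
            # Standard Swish (beta <= 1) is exactly PyTorch's built-in SiLU.
            self.swish = torch.nn.SiLU()
        else:
            self.swish = lambda x: x * torch.sigmoid(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.swish(x)
```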
Examples
>>> import torch
>>> from espnet2.asr_transducer.activation import Swish
>>> swish = Swish(beta=1.0)
>>> input_tensor = torch.tensor([0.0, 1.0, 2.0])
>>> output_tensor = swish(input_tensor)
>>> print(output_tensor)
tensor([0.0000, 0.7311, 1.7616])
>>> swish_e = Swish(beta=2.0)  # E-Swish: (2 * x) * sigmoid(x)
>>> output_tensor_e = swish_e(input_tensor)
>>> print(output_tensor_e)
tensor([0.0000, 1.4621, 3.5232])
>>> swish_builtin = Swish(beta=1.0, use_builtin=True)  # delegates to torch.nn.SiLU
>>> print(swish_builtin(input_tensor))
tensor([0.0000, 0.7311, 1.7616])
NOTE
The Swish activation function has been shown to improve performance in certain deep learning tasks compared to traditional activation functions like ReLU.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor) → Tensor
Compute the forward pass, applying the configured Swish (or E-Swish) function to x and returning the activated tensor.
espnet2.asr_transducer.activation.get_activation
Activation functions for Transducer models.
This module provides various activation functions, including FTSwish, Mish, Smish, and Swish, suitable for use in Transducer models. Each activation function is defined as a class that inherits from torch.nn.Module, allowing seamless integration into PyTorch models.
The main function, get_activation, returns the desired activation function based on the provided type and parameters, as pictured in the sketch after the parameter list below.
- FTSwish: Flatten-T Swish activation function.
- Mish: Mish activation function.
- Smish: Smish activation function.
- Swish: Swish activation function.
- Parameters:
- activation_type (str) – The type of activation function to retrieve.
- ftswish_threshold (float) – Threshold value for FTSwish activation formulation (default: -0.2).
- ftswish_mean_shift (float) – Mean shifting value for FTSwish activation (default: 0.0).
- hardtanh_min_val (float) – Minimum value for HardTanh (default: -1.0).
- hardtanh_max_val (float) – Maximum value for HardTanh (default: 1.0).
- leakyrelu_neg_slope (float) – Negative slope for LeakyReLU (default: 0.01).
- smish_alpha (float) – Alpha value for Smish activation (default: 1.0).
- smish_beta (float) – Beta value for Smish activation (default: 1.0).
- softplus_beta (float) – Beta value for softplus in Mish (default: 1.0).
- softplus_threshold (int) – Threshold for softplus in Mish (default: 20).
- swish_beta (float) – Beta value for Swish (default: 1.0).
- Returns: The specified activation function as a PyTorch module.
- Return type: torch.nn.Module
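The mapping from activation_type to a configured module can be pictured as a small dispatcher. The sketch below is one plausible shape under the parameters listed above, not the exact ESPnet code; the key names and the covered subset are assumptions (the real function also covers ftswish, mish, and smish):

```python
import torch


def get_activation_sketch(
    activation_type: str,
    hardtanh_min_val: float = -1.0,
    hardtanh_max_val: float = 1.0,
    leakyrelu_neg_slope: float = 0.01,
    swish_beta: float = 1.0,
) -> torch.nn.Module:
    # Hypothetical key names; only a subset of the documented types.
    activations = {
        "hardtanh": torch.nn.Hardtanh(hardtanh_min_val, hardtanh_max_val),
        "leaky_relu": torch.nn.LeakyReLU(leakyrelu_neg_slope),
        "relu": torch.nn.ReLU(),
        # The real module builds its own Swish(beta=swish_beta);
        # SiLU stands in here for the beta == 1 case.
        "swish": torch.nn.SiLU(),
    }

    if activation_type not in activations:
        raise ValueError(f"Unsupported activation type: {activation_type}")

    return activations[activation_type]
```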
Examples
>>> import torch
>>> from espnet2.asr_transducer.activation import get_activation
>>> activation = get_activation('smish', smish_alpha=1.0, smish_beta=1.0)
>>> output = activation(torch.tensor([-1.0, 0.0, 1.0]))
- Raises: ValueError – If the specified activation type is not supported.
NOTE
Ensure that the input tensor is of type torch.Tensor when using the activation functions.