espnet2.asr_transducer.activation.Mish
class espnet2.asr_transducer.activation.Mish(softplus_beta: float = 1.0, softplus_threshold: int = 20, use_builtin: bool = False)
Bases: Module
Mish activation definition.
The Mish activation function is defined as:
Mish(x) = x * tanh(softplus(x))
Where softplus is defined as:
softplus(x) = log(1 + exp(x))
This activation function has been shown to improve the performance of deep neural networks in various tasks.
Reference: https://arxiv.org/abs/1908.08681.
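The definition above can be checked directly with standard PyTorch primitives. The snippet below is an illustrative sketch, independent of the ESPnet class itself:

    import torch
    import torch.nn.functional as F

    # Mish(x) = x * tanh(softplus(x)), computed directly from the definition
    x = torch.tensor([-1.0, 0.0, 1.0])
    out = x * torch.tanh(F.softplus(x))
    print(out)  # tensor([-0.3034, 0.0000, 0.8651])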
- Parameters:
- softplus_beta (float) – Beta value for the softplus activation formulation. Typically, this should satisfy the condition 0 < softplus_beta < 2.
- softplus_threshold (int) – Values above this threshold revert to a linear function. Typically, it should satisfy the condition 10 < softplus_threshold < 20.
- use_builtin (bool) – Flag to indicate whether to use the built-in PyTorch Mish activation function if available (introduced in PyTorch 1.9); the interplay of these parameters is illustrated in the sketch below.
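The following is a minimal sketch of how these parameters map onto the computation. It is not the exact ESPnet implementation (the real class is espnet2.asr_transducer.activation.Mish), but it illustrates the roles of softplus_beta, softplus_threshold, and use_builtin:

    import torch
    import torch.nn.functional as F

    class MishSketch(torch.nn.Module):
        # Illustrative re-implementation, for explanation only.
        def __init__(self, softplus_beta=1.0, softplus_threshold=20, use_builtin=False):
            super().__init__()
            self.softplus_beta = softplus_beta
            self.softplus_threshold = softplus_threshold
            self.use_builtin = use_builtin
            if use_builtin:
                # torch.nn.Mish is available from PyTorch 1.9 onwards.
                self.builtin = torch.nn.Mish()

        def forward(self, x):
            if self.use_builtin:
                return self.builtin(x)
            # softplus falls back to a linear function above the threshold.
            return x * torch.tanh(
                F.softplus(x, beta=self.softplus_beta, threshold=self.softplus_threshold)
            )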
Examples
>>> mish = Mish(softplus_beta=1.0, softplus_threshold=20)
>>> input_tensor = torch.tensor([-1.0, 0.0, 1.0])
>>> output_tensor = mish(input_tensor)
>>> print(output_tensor)
tensor([-0.3034, 0.0000, 0.8651])
NOTE
The Mish activation is continuously differentiable and non-monotonic, which can lead to better training dynamics.
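Because the function is smooth everywhere, its gradients are well defined at every input. This can be verified with autograd, assuming the class is importable as documented above:

    import torch
    from espnet2.asr_transducer.activation import Mish

    mish = Mish()
    x = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
    mish(x).sum().backward()
    print(x.grad)  # finite gradients everywhere, including at x = 0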
- Raises:
- AssertionError – If softplus_beta or softplus_threshold is not within its valid range.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor) → Tensor
Apply the Mish activation to the input tensor x and return the result.
Activation functions for Transducer models.
This module provides various activation functions, including FTSwish, Mish, Smish, and Swish, which can be utilized in neural network architectures. The activation functions are implemented as subclasses of torch.nn.Module and can be instantiated with configurable parameters.
Classes:
- FTSwish: Implements the Flatten-T Swish activation function.
- Mish: Implements the Mish activation function.
- Smish: Implements the Smish activation function.
- Swish: Implements the Swish activation function.
Example usage:

    # Create an instance of FTSwish
    ftswish = FTSwish(threshold=-0.1, mean_shift=0.5)

    # Apply the activation function to a tensor
    input_tensor = torch.tensor([-1.0, 0.0, 1.0])
    output_tensor = ftswish(input_tensor)
    print(output_tensor)

    # Get a specific activation function using the helper function
    activation_function = get_activation("mish", softplus_beta=1.0)
    output_tensor = activation_function(input_tensor)
    print(output_tensor)
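Since get_activation returns a torch.nn.Module instance, the result can be dropped into a model like any other layer. The sketch below is a hypothetical usage example; only the get_activation("mish", softplus_beta=1.0) call is taken from the documentation above, the surrounding layers and sizes are illustrative:

    import torch
    from espnet2.asr_transducer.activation import get_activation

    # Hypothetical feed-forward block using Mish as the non-linearity.
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        get_activation("mish", softplus_beta=1.0),
        torch.nn.Linear(16, 4),
    )
    print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 4])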