espnet2.asr_transducer.activation.get_activation
espnet2.asr_transducer.activation.get_activation(activation_type: str, ftswish_threshold: float = -0.2, ftswish_mean_shift: float = 0.0, hardtanh_min_val: int = -1.0, hardtanh_max_val: int = 1.0, leakyrelu_neg_slope: float = 0.01, smish_alpha: float = 1.0, smish_beta: float = 1.0, softplus_beta: float = 1.0, softplus_threshold: int = 20, swish_beta: float = 1.0) → Module
Return the specified activation function as a PyTorch module.
This function provides a way to obtain various activation functions based on the specified type. It supports standard activations such as ReLU, Tanh, and others, as well as custom formulations like FTSwish, Mish, Smish, and Swish. The parameters can be adjusted to customize the behavior of certain activation functions.
- Parameters:
- activation_type (str) – The type of activation function to return. Options include: ‘ftswish’, ‘hardtanh’, ‘leaky_relu’, ‘mish’, ‘relu’, ‘selu’, ‘smish’, ‘swish’, ‘tanh’, ‘identity’.
- ftswish_threshold (float) – Threshold value for FTSwish activation formulation.
- ftswish_mean_shift (float) – Mean shifting value for FTSwish activation formulation.
- hardtanh_min_val (int) – Minimum value of the linear region range for HardTanh.
- hardtanh_max_val (int) – Maximum value of the linear region range for HardTanh.
- leakyrelu_neg_slope (float) – Negative slope value for LeakyReLU activation.
- smish_alpha (float) – Alpha value for Smish activation formulation.
- smish_beta (float) – Beta value for Smish activation formulation.
- softplus_beta (float) – Beta value for softplus activation formulation in Mish.
- softplus_threshold (int) – Values above this revert to a linear function in Mish.
- swish_beta (float) – Beta value for Swish variant formulation.
- Returns: A PyTorch activation function module corresponding to the specified activation_type.
- Return type: torch.nn.Module
- Raises: KeyError – If the specified activation_type is not recognized.
Examples
>>> import torch
>>> from espnet2.asr_transducer.activation import get_activation
>>> activation = get_activation('relu')
>>> x = torch.tensor([-1.0, 0.0, 1.0])
>>> activation(x)
tensor([0., 0., 1.])
>>> activation = get_activation('ftswish', ftswish_threshold=-0.1)
>>> activation(x)  # negative inputs are flattened to the threshold value
tensor([-0.1000, -0.1000,  0.6311])
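Parameterized variants follow the same pattern. As a minimal sketch, assuming 'hardtanh' behaves like torch.nn.Hardtanh with the given bounds, a wider linear region can be requested through the keyword arguments:

>>> activation = get_activation('hardtanh', hardtanh_min_val=-2.0, hardtanh_max_val=2.0)
>>> activation(torch.tensor([-3.0, 0.5, 3.0]))
tensor([-2.0000,  0.5000,  2.0000])

An unrecognized activation_type raises KeyError, as documented above ('gelu' is used here only as an example of an unsupported name):

>>> get_activation('gelu')
Traceback (most recent call last):
    ...
KeyError: 'gelu'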
NOTE
Ensure that the chosen activation function is compatible with your model architecture. For custom functions like FTSwish, additional parameters may be required for proper behavior.
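As a minimal sketch of such usage (the layer sizes and the choice of 'swish' here are illustrative, not part of the API), the returned module can be dropped into a standard PyTorch model like any other torch.nn.Module:

>>> import torch
>>> from espnet2.asr_transducer.activation import get_activation
>>> model = torch.nn.Sequential(
...     torch.nn.Linear(80, 256),
...     get_activation('swish', swish_beta=1.0),
...     torch.nn.Linear(256, 256),
... )
>>> model(torch.randn(4, 80)).shape
torch.Size([4, 256])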