espnet2.asr_transducer.activation.FTSwish
class espnet2.asr_transducer.activation.FTSwish(threshold: float = -0.2, mean_shift: float = 0)
Bases: Module
Flatten-T Swish activation definition.
FTSwish(x) = x * sigmoid(x) + threshold for x >= 0, and FTSwish(x) = threshold for x < 0.
For non-negative inputs the activation behaves like Swish shifted by the threshold; negative inputs are flattened to the constant threshold value. This is intended to improve gradient flow during training, and the function is used here in Transducer-based speech recognition models. A minimal sketch of the computation follows the parameter list below.
Reference: https://arxiv.org/abs/1812.06247
- Parameters:
- threshold (float) – Threshold value for FTSwish activation formulation. Must be less than 0.
- mean_shift (float) – Mean shifting value for FTSwish activation formulation. Applied only if not equal to 0 (disabled by default).
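The forward computation implied by the formula above can be sketched as a plain function. This is a minimal illustration, assuming the mean shift is subtracted from the result when non-zero; it is not necessarily the exact ESPnet implementation:

import torch

def ftswish_forward(x: torch.Tensor, threshold: float = -0.2, mean_shift: float = 0.0) -> torch.Tensor:
    # Swish-like branch shifted by threshold for x >= 0, flat threshold for x < 0.
    out = torch.where(x >= 0, x * torch.sigmoid(x) + threshold, torch.full_like(x, threshold))
    # Optional mean shift (assumed subtraction; disabled when 0).
    if mean_shift != 0:
        out = out - mean_shift
    return out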
####### Examples
>>> ftswish = FTSwish(threshold=-0.5)
>>> input_tensor = torch.tensor([-1.0, 0.0, 1.0])
>>> output_tensor = ftswish(input_tensor)
>>> print(output_tensor)
tensor([-0.5000, -0.5000, 0.2311])
- Raises: AssertionError – If the threshold is not less than 0.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor) → Tensor
Forward computation.
espnet2.asr_transducer.activation.get_activation
Activation functions for Transducer models.
This module provides various activation functions, including FTSwish, Mish, Smish, and Swish, suitable for use in Transducer models. Each activation function is defined as a class that inherits from torch.nn.Module, allowing seamless integration into PyTorch models.
The main function, get_activation, returns the desired activation function based on the provided type and parameters; the standard formulations of the custom activations are sketched after the list below.
- FTSwish
Flatten-T Swish activation function.
- Mish
Mish activation function.
- Smish
Smish activation function.
- Swish
Swish activation function.
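FTSwish is sketched above. For orientation, the commonly cited formulations behind the remaining classes can be written as plain functions; this is a sketch, and the exact placement of Smish's alpha and beta (and Swish's beta) is assumed from the usual definitions rather than taken from the ESPnet source:

import torch
import torch.nn.functional as F

def mish(x: torch.Tensor, softplus_beta: float = 1.0, softplus_threshold: int = 20) -> torch.Tensor:
    # Mish: x * tanh(softplus(x)), with the softplus parameters listed below.
    return x * torch.tanh(F.softplus(x, beta=softplus_beta, threshold=softplus_threshold))

def smish(x: torch.Tensor, alpha: float = 1.0, beta: float = 1.0) -> torch.Tensor:
    # Smish: alpha * x * tanh(log(1 + sigmoid(beta * x))) (parameter placement assumed).
    return alpha * x * torch.tanh(torch.log1p(torch.sigmoid(beta * x)))

def swish(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    # Swish: x * sigmoid(beta * x).
    return x * torch.sigmoid(beta * x)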
- Parameters:
- activation_type (str) – The type of activation function to retrieve.
- ftswish_threshold (float) – Threshold value for FTSwish activation formulation (default: -0.2).
- ftswish_mean_shift (float) – Mean shifting value for FTSwish activation (default: 0.0).
- hardtanh_min_val (float) – Minimum value for HardTanh (default: -1.0).
- hardtanh_max_val (float) – Maximum value for HardTanh (default: 1.0).
- leakyrelu_neg_slope (float) – Negative slope for LeakyReLU (default: 0.01).
- smish_alpha (float) – Alpha value for Smish activation (default: 1.0).
- smish_beta (float) – Beta value for Smish activation (default: 1.0).
- softplus_beta (float) – Beta value for softplus in Mish (default: 1.0).
- softplus_threshold (int) – Threshold for softplus in Mish (default: 20).
- swish_beta (float) – Beta value for Swish (default: 1.0).
- Returns: The specified activation function as a PyTorch module.
- Return type: torch.nn.Module
####### Examples
>>> activation = get_activation('smish', smish_alpha=1.0, smish_beta=1.0)
>>> output = activation(torch.tensor([-1.0, 0.0, 1.0]))
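Because the returned activation is an ordinary torch.nn.Module, it composes with other PyTorch layers; a small usage sketch:
>>> act = get_activation('smish')
>>> layer = torch.nn.Sequential(torch.nn.Linear(4, 4), act)
>>> layer(torch.randn(2, 4)).shape
torch.Size([2, 4])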
- Raises: ValueError – If the specified activation type is not supported.
NOTE
Ensure that the input tensor is of type torch.Tensor when using the activation functions.