espnet2.asr.encoder.beats_encoder.get_activation_fn
espnet2.asr.encoder.beats_encoder.get_activation_fn(activation: str)
Returns the activation function corresponding to activation.
This function maps the string name of an activation function to its corresponding PyTorch callable. Supported activations include ReLU, GELU, Tanh, and others.
- Parameters: activation (str) – The name of the activation function. Supported values are: “relu”, “gelu”, “gelu_fast”, “gelu_accurate”, “tanh”, “linear”, and “glu”.
- Returns: The corresponding activation function.
- Return type: Callable
- Raises: RuntimeError – If the specified activation function is not supported.
Examples
>>> import torch
>>> relu_fn = get_activation_fn("relu")
>>> relu_fn(torch.tensor([-1.0, 0.0, 1.0]))
tensor([0., 0., 1.])
>>> gelu_fn = get_activation_fn("gelu")
>>> gelu_fn(torch.tensor([-1.0, 0.0, 1.0]))
tensor([-0.1587, 0.0000, 0.8413])
>>> tanh_fn = get_activation_fn("tanh")
>>> tanh_fn(torch.tensor([-1.0, 0.0, 1.0]))
tensor([-0.7616, 0.0000, 0.7616])
NOTE
The activation name “gelu_fast” has been renamed to “gelu_accurate”. If “gelu_fast” is requested, a deprecation warning is issued and the “gelu_accurate” function is returned instead.
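The dispatch described above can be sketched as a simple name-to-callable mapping. The following is a minimal illustration, not the actual espnet2 implementation: it uses scalar math-module functions instead of the PyTorch tensor operations the real function returns, and it omits “glu”, which operates on tensor halves rather than elementwise values.

```python
import math
import warnings


def get_activation_fn_sketch(activation: str):
    """Sketch of a name-to-activation dispatcher (scalar version).

    The real get_activation_fn returns PyTorch callables; this
    hypothetical variant uses math-module equivalents for clarity.
    """

    def gelu(x: float) -> float:
        # Exact ("accurate") GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    if activation == "relu":
        return lambda x: max(0.0, x)
    if activation in ("gelu", "gelu_accurate"):
        return gelu
    if activation == "gelu_fast":
        # Renamed alias: warn, then fall back to gelu_accurate behavior.
        warnings.warn("gelu_fast has been renamed to gelu_accurate")
        return gelu
    if activation == "tanh":
        return math.tanh
    if activation == "linear":
        return lambda x: x
    raise RuntimeError(f"activation function {activation!r} not supported")
```

Calling the sketch with an unknown name raises RuntimeError, mirroring the behavior documented above.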