espnet2.gan_codec.shared.encoder.snake_activation.Snake1d
class espnet2.gan_codec.shared.encoder.snake_activation.Snake1d
Bases: Module
Snake1d is a PyTorch module that applies the Snake activation function, snake(x) = x + (1/α) sin²(αx), to its input. The sin² term adds a periodic ripple on top of the identity, so the output tracks the input while gaining a non-linearity whose frequency is controlled by α.
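A minimal scalar sketch of the formula, assuming the standard Snake definition stated above. The helper name `snake_scalar` is hypothetical and is not part of ESPnet; it only illustrates the arithmetic on plain floats.

```python
import math


def snake_scalar(x: float, alpha: float = 1.0) -> float:
    """Snake activation for a single value: x + (1 / alpha) * sin(alpha * x) ** 2."""
    return x + math.sin(alpha * x) ** 2 / alpha


# The periodic term sin^2(alpha * x) / alpha is never negative,
# so the output never falls below the identity line y = x.
for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  snake={snake_scalar(x):+.4f}")
```

At x = 0 and at every multiple of π/α the ripple vanishes and the output equals the input exactly.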
alpha
A tensor of ones with shape (1, channels, 1), i.e. one α per channel, that sets the frequency (α) and amplitude (1/α) of the periodic term.
Type: torch.Tensor
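How α shapes the activation can be checked numerically. This plain-Python sketch (hypothetical helper `periodic_term`, assuming the standard Snake formula) prints the period π/α and the peak ripple 1/α for a few values of α.

```python
import math


def periodic_term(x: float, alpha: float) -> float:
    """The ripple Snake adds on top of the identity: sin(alpha * x) ** 2 / alpha."""
    return math.sin(alpha * x) ** 2 / alpha


for alpha in (0.5, 1.0, 4.0):
    period = math.pi / alpha          # sin^2(alpha * x) repeats every pi / alpha
    peak = 1.0 / alpha                # largest ripple, reached where |sin(alpha * x)| = 1
    x_peak = math.pi / (2.0 * alpha)  # first such point for x > 0
    print(f"alpha={alpha}: period={period:.4f}, peak={peak:.4f}, "
          f"ripple at x_peak={periodic_term(x_peak, alpha):.4f}")
```

For small α the ripple behaves like αx², so the activation approaches the identity; large α gives fast, small ripples.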
Parameters:x (torch.Tensor) – The input tensor to which the Snake activation is applied. Its first two dimensions are the batch size N and the number of channels C, typically (N, C, T) for a 1-D signal of length T; inputs with extra trailing dimensions, such as (N, C, H, W), are also accepted.
Returns: The output tensor after applying the Snake activation, with the same shape as the input tensor.
Return type: torch.Tensor
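The shape contract above can be sketched as a standalone function. This assumes that trailing dimensions are flattened to (batch, channels, -1) so that α broadcasts over the channel axis; `snake_sketch` is a hypothetical name, and the small `eps` guarding the division is an added assumption rather than a documented detail.

```python
import torch


def snake_sketch(x: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    """Apply x + sin^2(alpha * x) / alpha, broadcasting alpha over the channel axis.

    x:     (batch, channels, *rest); rest may be (time,) or (height, width), etc.
    alpha: (1, channels, 1) or (1, 1, 1).
    """
    shape = x.shape
    # Flatten all trailing dimensions so alpha's (1, C, 1) shape lines up with (N, C, L).
    x = x.reshape(shape[0], shape[1], -1)
    x = x + torch.sin(alpha * x).pow(2) / (alpha + eps)
    # Restore the caller's original shape.
    return x.reshape(shape)


alpha = torch.ones(1, 3, 1)
print(snake_sketch(torch.randn(2, 3, 4), alpha).shape)     # torch.Size([2, 3, 4])
print(snake_sketch(torch.randn(2, 3, 4, 4), alpha).shape)  # torch.Size([2, 3, 4, 4])
```

Flattening before the activation keeps the broadcast rule identical for 1-D signals and 2-D feature maps.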
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.snake_activation import Snake1d
>>> snake1d = Snake1d(channels=3)  # one alpha per input channel
>>> input_tensor = torch.randn(2, 3, 4)  # (batch, channels, time)
>>> output_tensor = snake1d(input_tensor)
>>> print(output_tensor.shape)
torch.Size([2, 3, 4])
NOTE
Changing alpha changes the activation: larger values give faster, smaller ripples (period π/α, amplitude 1/α), and as alpha approaches 0 the activation approaches the identity. By default, alpha is a tensor of ones.
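One way to modify alpha, sketched on a hypothetical stand-in module `SnakeSketch` (not the ESPnet class). It assumes alpha is stored as an `nn.Parameter` with one entry per channel, and overwrites it in place under `torch.no_grad()` so the edit is not recorded by autograd.

```python
import torch
from torch import nn


class SnakeSketch(nn.Module):
    """Hypothetical stand-in for Snake1d: one alpha per channel, initialized to ones."""

    def __init__(self, channels: int = 1):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.shape
        x = x.reshape(shape[0], shape[1], -1)
        x = x + torch.sin(self.alpha * x).pow(2) / (self.alpha + 1e-9)
        return x.reshape(shape)


act = SnakeSketch(channels=2)
x = torch.full((1, 2, 3), 0.5)

# Overwrite alpha in place; no_grad keeps the edit out of the autograd graph.
with torch.no_grad():
    act.alpha[:, 0] = 4.0   # channel 0: fast ripples, peak 1/4
    act.alpha[:, 1] = 0.25  # channel 1: slow ripples, close to the identity near 0
print(act(x))
```

The parameter still has `requires_grad=True` after the edit, so it would continue to be trained if the module were optimized.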
forward(x)
Applies the Snake transformation to the input tensor.
It computes snake(x) = x + (1/α) sin²(αx) elementwise, with alpha broadcast over the input. The periodic sin² term gives the network an inductive bias toward periodic structure, which is useful for modeling periodic signals such as audio waveforms.
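Differentiating the formula gives d/dx snake(x) = 1 + 2 sin(αx) cos(αx) = 1 + sin(2αx), which always lies in [0, 2]: Snake is monotonically non-decreasing and its slope is bounded. The plain-Python sketch below, with hypothetical helper names, checks this derivative against central finite differences; it assumes the standard Snake definition rather than inspecting the ESPnet source.

```python
import math


def snake_scalar(x: float, alpha: float = 1.0) -> float:
    return x + math.sin(alpha * x) ** 2 / alpha


def snake_grad(x: float, alpha: float = 1.0) -> float:
    # d/dx [x + sin^2(ax)/a] = 1 + 2 sin(ax) cos(ax) = 1 + sin(2ax)
    return 1.0 + math.sin(2.0 * alpha * x)


h = 1e-6
for alpha in (0.5, 1.0, 3.0):
    for x in (-1.3, 0.0, 0.7, 2.9):
        # Central difference approximation of the derivative.
        numeric = (snake_scalar(x + h, alpha) - snake_scalar(x - h, alpha)) / (2 * h)
        assert abs(numeric - snake_grad(x, alpha)) < 1e-6
print("analytic gradient matches finite differences")
```

A slope that never goes negative or above 2 means gradients passing through the activation are neither sign-flipped nor amplified by more than a factor of 2.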
alpha
A tensor containing the alpha parameters for the Snake transformation, initialized to a tensor of ones.
Type: torch.Tensor
Parameters:x (torch.Tensor) – The input tensor, whose first two dimensions are batch_size and channels, followed by any number of trailing dimensions, e.g. (batch_size, channels, time) or (batch_size, channels, height, width). Trailing dimensions are flattened to (batch_size, channels, -1) for the computation, and the original shape is restored afterwards.
Returns: The transformed tensor after applying the Snake operation, with the same shape as the input tensor x.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.encoder.snake_activation import Snake1d
>>> model = Snake1d(channels=3)  # one alpha per input channel
>>> input_tensor = torch.randn(2, 3, 4, 4) # (batch_size, channels, height, width)
>>> output_tensor = model(input_tensor)
>>> print(output_tensor.shape)
torch.Size([2, 3, 4, 4])
NOTE
Modifying alpha changes the frequency and amplitude of the periodic term, and therefore the shape of the transformation.