espnet2.asr.layers.multiconv_cgmlp.MultiConvolutionalGatingMLP
class espnet2.asr.layers.multiconv_cgmlp.MultiConvolutionalGatingMLP(size: int, linear_units: int, arch_type: str, kernel_sizes: str, merge_conv_kernel: int, use_non_linear: bool, dropout_rate: float, use_linear_after_conv: bool, activation, gate_activation: str)
Bases: Module
Convolutional Gating MLP (cgMLP).
This class implements a multi-convolutional gating MLP: an MLP block built around a multi-convolutional spatial gating unit (M-CSGU). The arch_type setting selects how features from convolutions with different kernel sizes are combined before gating. The block is intended for tasks such as speech recognition.
channel_proj1
A sequential layer projecting input features to a higher-dimensional space with a GELU activation.
- Type: torch.nn.Sequential
csgu
An instance of the MultiConvolutionalSpatialGatingUnit class that performs the convolutional gating.
channel_proj2
A linear layer that projects the output back to the original size.
- Type: torch.nn.Linear
Parameters:
- size (int) – The dimensionality of the input features.
- linear_units (int) – The number of linear units in the first projection.
- arch_type (str) – The architecture type for the convolutional gating, can be ‘sum’, ‘weighted_sum’, ‘concat’, or ‘concat_fusion’.
- kernel_sizes (str) – A comma-separated string specifying the sizes of the convolutional kernels to be used.
- merge_conv_kernel (int) – The kernel size for merging convolutions, applicable in ‘concat_fusion’ architecture.
- use_non_linear (bool) – Whether to apply non-linear activation after convolution.
- dropout_rate (float) – The dropout rate to be applied.
- use_linear_after_conv (bool) – Whether to apply a linear layer after the convolutional layers.
- activation – The activation function to use for the model.
- gate_activation (str) – The activation function to use for gating; typically ‘identity’ or other activation functions.
Returns: The output tensor, with the same dimensionality as the input features, after passing through the MLP.
Return type: torch.Tensor
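The kernel_sizes and arch_type parameters above can be illustrated with a minimal sketch. This is not the ESPnet implementation: ToyMultiConv and parse_kernel_sizes are hypothetical names, only 'sum' and a simplified 'concat' are shown, and the linear fusion layer used for 'concat' is an assumption made for illustration. Gating, dropout, and the non-linearity are left out.

```python
from typing import List

import torch


def parse_kernel_sizes(kernel_sizes: str) -> List[int]:
    """Turn a comma-separated string such as '3,5' into [3, 5]."""
    return [int(k) for k in kernel_sizes.split(",")]


class ToyMultiConv(torch.nn.Module):
    """Illustrative only: depthwise convs with several kernel sizes, then a merge."""

    def __init__(self, channels: int, kernel_sizes: str, arch_type: str = "sum"):
        super().__init__()
        if arch_type not in ("sum", "concat"):
            raise NotImplementedError(f"unsupported arch_type: {arch_type}")
        self.arch_type = arch_type
        sizes = parse_kernel_sizes(kernel_sizes)
        # Odd kernels with padding (k - 1) // 2 keep the time length unchanged.
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(
                    channels, channels, k, padding=(k - 1) // 2, groups=channels
                )
                for k in sizes
            ]
        )
        # Assumption of this sketch: 'concat' maps the stacked channels
        # back to `channels` with a single linear layer.
        self.fuse = (
            torch.nn.Linear(channels * len(sizes), channels)
            if arch_type == "concat"
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, T, C); Conv1d expects (N, C, T).
        x = x.transpose(1, 2)
        outs = [conv(x) for conv in self.convs]
        if self.arch_type == "sum":
            # Element-wise sum over the per-kernel outputs.
            return torch.stack(outs, dim=0).sum(dim=0).transpose(1, 2)
        # Concatenate along channels, then fuse back to C channels.
        return self.fuse(torch.cat(outs, dim=1).transpose(1, 2))


x = torch.randn(2, 7, 8)  # (N, T, C)
y_sum = ToyMultiConv(8, "3,5", "sum")(x)
y_cat = ToyMultiConv(8, "3,5", "concat")(x)
```

In this sketch both merge strategies return a tensor with the same shape as the input, which is what lets the surrounding projections stay unchanged when arch_type is switched.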
####### Examples
>>> import torch
>>> from espnet2.asr.layers.multiconv_cgmlp import MultiConvolutionalGatingMLP
>>> model = MultiConvolutionalGatingMLP(
... size=256,
... linear_units=512,
... arch_type='sum',
... kernel_sizes='3,5',
... merge_conv_kernel=3,
... use_non_linear=True,
... dropout_rate=0.1,
... use_linear_after_conv=True,
... activation=torch.nn.GELU(),
... gate_activation='identity'
... )
>>> input_tensor = torch.randn(10, 50, 256) # (N, T, D): batch of 10, 50 frames
>>> output = model(input_tensor)
>>> output.shape
torch.Size([10, 50, 256])
NOTE
The implementation assumes an input tensor of shape (N, T, D) with the expected type. Ensure that the last dimension D equals the size given at initialization.
- Raises: NotImplementedError – If an unsupported architecture type is provided.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x, mask=None)
Forward pass for the MultiConvolutionalGatingMLP.
This method computes the output of the MultiConvolutionalGatingMLP by applying a series of linear transformations followed by the MultiConvolutionalSpatialGatingUnit (M-CSGU). The input is processed through the channel projection layers and the spatial gating unit to produce the final output.
- Parameters:
- x (Union[torch.Tensor, tuple]) – Input tensor of shape (N, T, D) or a tuple containing the input tensor and positional embedding.
- mask (torch.Tensor, optional) – An optional mask tensor to apply attention. Default is None.
- Returns: The output tensor of shape (N, T, D) or a tuple containing the output tensor and positional embedding if provided.
- Return type: torch.Tensor or tuple
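The data flow and the tensor-or-tuple behaviour described above can be sketched with a small stand-in module. ToyGatingMLP is a hypothetical name, not the ESPnet class. The gate here is a plain split-and-multiply with no convolution, mask is accepted but unused, and halving the channel count inside the gate is an assumption of this sketch.

```python
import torch


class ToyGatingMLP(torch.nn.Module):
    """Illustrative stand-in: project up, gate, project back."""

    def __init__(self, size: int, linear_units: int):
        super().__init__()
        self.channel_proj1 = torch.nn.Sequential(
            torch.nn.Linear(size, linear_units), torch.nn.GELU()
        )
        # Assumption: gating halves the channels, so the final
        # projection takes linear_units // 2 inputs.
        self.norm = torch.nn.LayerNorm(linear_units // 2)
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)

    def gate(self, x: torch.Tensor) -> torch.Tensor:
        # Split channels in two; the normalized half gates the other half.
        x_r, x_g = x.chunk(2, dim=-1)
        return x_r * self.norm(x_g)

    def forward(self, x, mask=None):
        # Accept either a tensor or a (tensor, pos_emb) tuple.
        if isinstance(x, tuple):
            xs, pos_emb = x
        else:
            xs, pos_emb = x, None
        xs = self.channel_proj2(self.gate(self.channel_proj1(xs)))
        # Hand the positional embedding back untouched when one was given.
        return (xs, pos_emb) if pos_emb is not None else xs


model = ToyGatingMLP(size=16, linear_units=32)
x = torch.randn(4, 10, 16)    # (N, T, D)
pos = torch.randn(1, 10, 16)  # placeholder positional embedding
out = model(x)                # tensor in, tensor out
out_t, pos_out = model((x, pos))  # tuple in, tuple out
```

Returning the positional embedding alongside the output lets a block of this kind be chained with layers that expect the same tuple convention.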
####### Examples
>>> model = MultiConvolutionalGatingMLP(size=128, linear_units=256,
... arch_type='sum', kernel_sizes='3,5', merge_conv_kernel=3,
... use_non_linear=True, dropout_rate=0.1, use_linear_after_conv=True,
... activation=torch.nn.GELU(), gate_activation='relu')
>>> input_tensor = torch.randn(32, 10, 128) # (N, T, D)
>>> output = model(input_tensor)
>>> output.shape
torch.Size([32, 10, 128])
NOTE
The last dimension D of the input tensor x must match the size parameter given at initialization.