espnet2.asr_transducer.encoder.modules.convolution.ConvolutionalSpatialGatingUnit
class espnet2.asr_transducer.encoder.modules.convolution.ConvolutionalSpatialGatingUnit(size: int, kernel_size: int, norm_class: torch.nn.Module = torch.nn.LayerNorm, norm_args: Dict = {}, dropout_rate: float = 0.0, causal: bool = False)
Bases: Module
Convolutional Spatial Gating Unit module definition.
This module splits the input along its last dimension into two equal parts, applies normalization and a convolution to one part, and combines the two parts through element-wise multiplication. It is intended as a convolutional gating building block for encoder architectures.
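The gating described above can be sketched in a few lines of PyTorch. This is a minimal illustration under stated assumptions, not the ESPnet implementation: the class name `TinySpatialGate` is hypothetical, the convolution is assumed to be depthwise with symmetric padding, and activation, dropout, masking, and caching are left out.

```python
import torch


class TinySpatialGate(torch.nn.Module):
    """Hypothetical, simplified gating unit (not the ESPnet class)."""

    def __init__(self, size: int, kernel_size: int) -> None:
        super().__init__()
        channels = size // 2  # each half of the feature dimension
        self.norm = torch.nn.LayerNorm(channels)
        # Assumption: depthwise convolution with symmetric padding.
        self.conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=channels,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Split (B, T, size) into two (B, T, size // 2) halves.
        x_r, x_g = x.chunk(2, dim=-1)
        # Normalize the gate half, then convolve along time.
        # Conv1d expects (B, C, T), hence the transposes.
        x_g = self.conv(self.norm(x_g).transpose(1, 2)).transpose(1, 2)
        # Element-wise gating of the untouched half.
        return x_r * x_g


gate = TinySpatialGate(size=64, kernel_size=3)
out = gate(torch.rand(10, 32, 64))
print(out.shape)
```

In this sketch the output carries `size // 2` features, because the gate half is consumed by the multiplication.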
- Parameters:
- size – Initial size to determine the number of channels. The input will be split into two equal parts.
- kernel_size – Size of the convolving kernel.
- norm_class – Normalization module class (default: torch.nn.LayerNorm).
- norm_args – Normalization module arguments (default: empty dictionary).
- dropout_rate – Dropout rate to apply after the gating operation (default: 0.0).
- causal – Whether to use causal convolution (set to True if streaming).
kernel_size
Size of the convolution kernel.
- Type: int
lorder
The left order for causal convolution.
- Type: int
conv
The convolutional layer for gating.
- Type: torch.nn.Conv1d
norm
The normalization layer.
- Type: torch.nn.Module
activation
The activation function.
- Type: torch.nn.Module
dropout
The dropout layer.
- Type: torch.nn.Dropout
####### Examples
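>>> import torch
>>> from espnet2.asr_transducer.encoder.modules.convolution import ConvolutionalSpatialGatingUnit
>>> unit = ConvolutionalSpatialGatingUnit(size=64, kernel_size=3)
>>> input_tensor = torch.rand(10, 32, 64)  # (B, T, D_hidden)
>>> output, cache = unit(input_tensor)
>>> output.shape  # documented as (B, ?, D_hidden)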
NOTE
The input tensor is expected to have a shape of (B, T, D_hidden), where B is the batch size, T is the sequence length, and D_hidden is the dimensionality of the hidden state. The output tensor will have a shape of (B, ?, D_hidden), where ‘?’ depends on the operations performed.
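One way to reason about the "?" dimension is the standard output-length formula for a 1-D convolution. The helper `conv1d_out_length` below is a hypothetical illustration, and the padding choices (`(kernel_size - 1) // 2` on each side for the non-causal case, `kernel_size - 1` prepended frames for the causal case) are assumptions that this page does not confirm.

```python
def conv1d_out_length(length: int, kernel_size: int, padding: int = 0,
                      stride: int = 1, dilation: int = 1) -> int:
    """Output length of a 1-D convolution with `padding` on both sides."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


T = 32

# Assumed non-causal padding: (kernel_size - 1) // 2 on each side.
for k in (3, 4, 5):
    print(k, conv1d_out_length(T, k, padding=(k - 1) // 2))

# Assumed causal setup: kernel_size - 1 frames prepended, no padding.
k = 3
print(conv1d_out_length(T + (k - 1), k))
```

Under these assumptions an odd kernel size keeps T unchanged, while an even kernel size shortens the sequence by one frame.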
- Raises: ValueError – If size is not a positive even integer.
Construct a ConvolutionalSpatialGatingUnit object.
Compute convolution module.
This method processes the input tensor x through a series of convolutional operations, applying a gating mechanism, and handling optional masking and caching for causal convolutions.
Parameters:
- x – ConvolutionalSpatialGatingUnit input sequences. Shape (B, T, D_hidden), where B is the batch size, T is the sequence length, and D_hidden is the number of features.
- mask – Optional source mask. Shape (B, T_2), used to prevent certain positions in the input from being processed.
- cache – Optional input cache for maintaining state across calls. Shape (1, D_hidden, conv_kernel).
Returns:
- x: ConvolutionalSpatialGatingUnit output sequences. Shape (B, ?, D_hidden), where the second dimension may vary based on the convolution operation.
- cache: ConvolutionalSpatialGatingUnit output cache. Shape (1, D_hidden, conv_kernel).
Return type: Tuple[torch.Tensor, torch.Tensor]
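The cache can be read as the left context carried from one chunk to the next in streaming mode. The sketch below illustrates that idea under assumptions and is not the library's code: it uses a plain depthwise `torch.nn.Conv1d`, a hypothetical helper `causal_step`, and a cache holding the last `kernel_size - 1` frames. It checks that chunk-by-chunk processing matches a single pass over the full sequence.

```python
import torch

torch.manual_seed(0)

channels, kernel_size = 4, 3
lorder = kernel_size - 1  # assumed number of left-context frames
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)


def causal_step(x, cache=None):
    """Hypothetical helper: x is (B, C, T); returns (output, new_cache)."""
    if cache is None:
        # First chunk: no history, so pad the left side with zeros.
        x = torch.nn.functional.pad(x, (lorder, 0))
    else:
        # Later chunks: prepend the frames saved from the previous call.
        x = torch.cat([cache, x], dim=2)
    new_cache = x[..., -lorder:]  # keep the last frames for the next call
    return conv(x), new_cache


x = torch.randn(1, channels, 8)

with torch.no_grad():
    # Single pass over the whole sequence.
    full, _ = causal_step(x)

    # Two chunks of four frames, passing the cache along.
    cache, pieces = None, []
    for chunk in x.split(4, dim=2):
        y, cache = causal_step(chunk, cache)
        pieces.append(y)
    streamed = torch.cat(pieces, dim=2)

print(torch.allclose(full, streamed, atol=1e-6))
```

Because each chunk sees exactly the frames it would have seen in the full sequence, the streamed output should agree with the full pass up to floating-point tolerance.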
####### Examples
>>> unit = ConvolutionalSpatialGatingUnit(size=128, kernel_size=3)
>>> input_tensor = torch.randn(10, 20, 128)  # Batch of 10, 20 timesteps, 128 features (matches size)
>>> output, new_cache = unit(input_tensor)
>>> print(output.shape)  # Output shape may vary based on convolutions
NOTE
- The last dimension of the input tensor x is split into two halves for gating purposes.
- If causal is set to True during initialization, the input will be processed in a way that prevents future information from being used.
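The causal property can be checked empirically: with left-only padding, changing a future frame must leave earlier outputs untouched. This sketch assumes a depthwise `torch.nn.Conv1d` with `kernel_size - 1` zero frames prepended; the helper `causal_conv` is hypothetical, and the setup is illustrative rather than a statement of how the library handles the causal case.

```python
import torch

torch.manual_seed(0)

channels, kernel_size, T = 4, 3, 10
conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)


def causal_conv(x: torch.Tensor) -> torch.Tensor:
    """Left-pad (B, C, T) input so output frame t sees only frames <= t."""
    return conv(torch.nn.functional.pad(x, (kernel_size - 1, 0)))


x = torch.randn(1, channels, T)
x_future = x.clone()
x_future[..., 6:] += 1.0  # perturb frames 6 and later

with torch.no_grad():
    y, y_future = causal_conv(x), causal_conv(x_future)

# Frames before index 6 should match; frame 6 onward may differ.
print(torch.allclose(y[..., :6], y_future[..., :6]))
```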