espnet2.asr_transducer.decoder.modules.mega.multi_head_damped_ema.MultiHeadDampedEMA
class espnet2.asr_transducer.decoder.modules.mega.multi_head_damped_ema.MultiHeadDampedEMA(size: int, num_heads: int = 4, activation: Module = ReLU(), truncation_length: int | None = None)
Bases: Module
Multi-head Damped Exponential Moving Average (EMA) module for MEGA block.
This module implements a multi-head damped EMA mechanism, commonly used as a component of attention mechanisms for sequence processing. The design is adapted from the Fairseq implementation and follows the conventions of the Hugging Face Transformers library.
damping_factor
Parameter representing the damping factor for the EMA.
- Type: torch.nn.Parameter
decay_factor
Parameter representing the decay factor for the EMA.
- Type: torch.nn.Parameter
ema_expansion_matrix
Parameter for the EMA expansion matrix.
- Type: torch.nn.Parameter
kernel_projection_matrix
Parameter for the kernel projection matrix.
- Type: torch.nn.Parameter
residual_weight
Parameter representing the residual weight.
- Type: torch.nn.Parameter
scaling
Scaling factor computed as the square root of the inverse of the number of heads.
- Type: float
truncation_length
Maximum length for truncation, if specified.
- Type: Optional[int]
activation
Activation function to apply to the output.
- Type: torch.nn.Module
num_heads
Number of attention heads used in the module.
- Type: int
Parameters:
- size (int) – The size of the module (feature dimension).
- num_heads (int, optional) – The number of attention heads. Defaults to 4.
- activation (torch.nn.Module, optional) – The activation function. Defaults to ReLU().
- truncation_length (Optional[int], optional) – The maximum length for truncation. Defaults to None.
################# Examples
>>> ema = MultiHeadDampedEMA(size=128, num_heads=4)
>>> input_tensor = torch.randn(10, 32, 128) # (L, B, D)
>>> output, new_state = ema(input_tensor)
>>> print(output.shape) # torch.Size([32, 10, 128]), i.e. (B, L, D)
- Raises: ValueError – If the input tensor does not have the expected shape.
########## NOTE The implementation includes methods for computing EMA coefficients, resetting parameters, and applying the EMA in a forward pass.
Construct a MultiHeadDampedEMA object.
compute_ema_coefficients() → Tuple[Tensor, Tensor]
Compute EMA coefficients.
This method computes the damping factor and the previous timestep weight, which are essential for the exponential moving average (EMA) calculations. The damping factor represents the P-th order coefficient, while the previous timestep weight represents the Q-th order coefficient.
- Returns:
- damping_factor: Damping factor / P-th order coefficient. Shape: (size, num_heads, 1)
- prev_timestep_weight: Previous timestep weight / Q-th order coefficient. Shape: (size, num_heads, 1)
- Return type: Tuple[torch.Tensor, torch.Tensor]
################# Examples
>>> ema = MultiHeadDampedEMA(size=10, num_heads=4)
>>> damping, prev_weight = ema.compute_ema_coefficients()
>>> print(damping.shape) # Output: torch.Size([10, 4, 1])
>>> print(prev_weight.shape) # Output: torch.Size([10, 4, 1])
########## NOTE The damping factor is computed using a sigmoid function applied to the damping_factor parameter, and the previous timestep weight is computed using the formula: prev_timestep_weight = 1.0 - damping_factor * decay_factor.
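The standalone sketch below mirrors the coefficient computation described in this note; damping_param and decay_param are hypothetical stand-ins for the module's damping_factor and decay_factor parameters.
>>> import torch
>>> size, num_heads = 10, 4
>>> damping_param = torch.randn(size, num_heads, 1)
>>> decay_param = torch.randn(size, num_heads, 1)
>>> damping = torch.sigmoid(damping_param)  # P-th order coefficient
>>> prev_weight = 1.0 - damping * torch.sigmoid(decay_param)  # Q-th order coefficient
>>> damping.shape, prev_weight.shape
(torch.Size([10, 4, 1]), torch.Size([10, 4, 1]))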
compute_ema_kernel(length: int) → Tensor
Compute EMA kernel / Vandermonde product.
This method calculates the Exponential Moving Average (EMA) kernel using the damped factors and the EMA expansion matrix. The resulting kernel represents the effect of applying the EMA over a sequence of specified length.
- Parameters: length (int) – The sequence length for which to compute the EMA kernel.
- Returns: The EMA kernel / Vandermonde product, shaped (size, L), where ‘size’ corresponds to the module size and ‘L’ corresponds to the input sequence length.
- Return type: torch.Tensor
################# Examples
>>> ema_module = MultiHeadDampedEMA(size=10, num_heads=4)
>>> kernel = ema_module.compute_ema_kernel(length=5)
>>> print(kernel.shape)
torch.Size([10, 5])
########## NOTE The EMA kernel is computed based on the current damping factors and expansion matrix. This is crucial for the operation of the EMA mechanism in the model.
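As an illustration of the Vandermonde product mentioned above, the following standalone sketch builds a damped-EMA kernel from assumed coefficient shapes; the names p, q, beta, and gamma are hypothetical and do not reflect the module's internals.
>>> import torch
>>> size, num_heads, length = 10, 4, 5
>>> p = torch.sigmoid(torch.randn(size, num_heads, 1))  # damping factor
>>> q = 1.0 - p * torch.sigmoid(torch.randn(size, num_heads, 1))  # decay weight
>>> beta = torch.randn(size, num_heads, 1)  # EMA expansion matrix
>>> gamma = torch.randn(size, num_heads)  # kernel projection matrix
>>> # Vandermonde structure: powers 0..L-1 of the per-head weight q.
>>> vander = torch.arange(length).view(1, 1, length) * torch.log(q)
>>> kernel = torch.einsum("dnl,dn->dl", (p * beta) * torch.exp(vander), gamma)
>>> kernel.shape
torch.Size([10, 5])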
ema_one_step(x: Tensor, state: Tensor | None = None) → Tuple[Tensor, Tensor]
Perform exponential moving average for a single step.
This method computes the exponential moving average (EMA) for the given input tensor x at a single time step. It utilizes the current state of the EMA to update the new state and generate the output sequence.
- Parameters:
- x – MultiHeadDampedEMA input sequences. Shape: (B, D, 1), where B is the batch size and D is the dimension of the input.
- state – Optional; MultiHeadDampedEMA state from the previous step. Shape: (B, D, num_heads). If not provided, the EMA is computed without incorporating any prior state.
- Returns:
- out: MultiHeadDampedEMA output sequences. Shape: (B, 1, D), representing the output after applying the EMA to the input sequences.
- new_state: MultiHeadDampedEMA state for the current step. Shape: (B, D, num_heads), which can be used in subsequent EMA computations.
- Return type: Tuple[torch.Tensor, torch.Tensor]
################# Examples
>>> ema_module = MultiHeadDampedEMA(size=128, num_heads=4)
>>> input_tensor = torch.rand(32, 128, 1) # Batch of 32
>>> initial_state = torch.zeros(32, 128, 4) # Initial state
>>> output, new_state = ema_module.ema_one_step(input_tensor, initial_state)
########## NOTE The output out is computed by applying the EMA to the input x and adding the contribution from the previous state if provided.
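For intuition, here is a minimal standalone sketch of the recurrence this method performs, under the documented shapes; the coefficient names p, q, beta, and gamma are illustrative only.
>>> import torch
>>> B, D, H = 32, 128, 4
>>> x = torch.rand(B, D)  # current input, squeezed from (B, D, 1)
>>> state = torch.zeros(B, D, H)  # previous per-head EMA state
>>> p = torch.sigmoid(torch.randn(D, H))  # damping factor
>>> q = 1.0 - p * torch.sigmoid(torch.randn(D, H))  # previous-timestep weight
>>> beta, gamma = torch.randn(D, H), torch.randn(D, H)  # expansion / projection
>>> new_state = p * beta * x.unsqueeze(-1) + q * state  # (B, D, H)
>>> out = torch.einsum("bdh,dh->bd", new_state, gamma).unsqueeze(1)  # (B, 1, D)
>>> out.shape, new_state.shape
(torch.Size([32, 1, 128]), torch.Size([32, 128, 4]))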
forward(x: Tensor, mask: Tensor | None = None, state: Dict[str, Tensor] | None = None) → Tensor | None
Compute multi-dimensional damped EMA.
This method computes the multi-dimensional damped Exponential Moving Average (EMA) for the input sequences using the current state and optionally applies a mask to the input.
- Parameters:
- x – MultiHeadDampedEMA input sequence. Shape (L, B, D), where:
- L is the sequence length,
- B is the batch size,
- D is the feature dimension.
- mask – Optional sequence mask. Shape (B, 1, L). A mask can be used to ignore certain time steps in the input sequence.
- state – Optional MultiHeadDampedEMA state. Shape (B, D, num_heads), where num_heads is the number of attention heads.
- Returns:
- x: MultiHeadDampedEMA output sequence. Shape (B, L, D).
- new_state: MultiHeadDampedEMA state. Shape (B, D, num_heads) or None if state was not provided.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
################# Examples
>>> model = MultiHeadDampedEMA(size=128, num_heads=4)
>>> input_tensor = torch.randn(10, 32, 128) # (L, B, D)
>>> mask_tensor = torch.zeros(32, 1, 10) # (B, 1, L)
>>> output, new_state = model(input_tensor, mask=mask_tensor)
########## NOTE If mask is provided, the input tensor x will have masked elements set to zero before processing. If state is not provided, the method will compute the output without using any previous state information.
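The following sketch illustrates the masking behavior described in this note, assuming a boolean mask whose nonzero entries mark time steps to zero out (a convention assumed here, not taken from the source):
>>> import torch
>>> L, B, D = 10, 32, 128
>>> x = torch.randn(L, B, D)
>>> mask = torch.zeros(B, 1, L, dtype=torch.bool)
>>> mask[:, :, 7:] = True  # e.g. the last three steps are padding
>>> x = x.masked_fill(mask.permute(2, 0, 1), 0.0)  # zero the masked steps
>>> torch.all(x[7:] == 0).item()
True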
get_ema_coefficients() → Tuple[Tensor, Tensor]
Get EMA coefficients.
This method retrieves the damping factor and the previous timestep weight coefficients used in the exponential moving average (EMA) calculations. The coefficients are computed using the sigmoid activation function applied to the damping and decay factors.
If the coefficients have not been computed yet, this method will call compute_ema_coefficients() to generate them.
- Returns:
- Damping factor / P-th order coefficient. Shape: (size, num_heads, 1)
- Previous timestep weight / Q-th order coefficient. Shape: (size, num_heads, 1)
- Return type: Tuple[torch.Tensor, torch.Tensor]
################# Examples
>>> ema_module = MultiHeadDampedEMA(size=128, num_heads=4)
>>> damping, prev_weight = ema_module.get_ema_coefficients()
>>> print(damping.shape) # Output: torch.Size([128, 4, 1])
>>> print(prev_weight.shape) # Output: torch.Size([128, 4, 1])
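A hypothetical sketch of the lazy-caching behavior described above (simplified; the attribute and shapes are illustrative, not the module's actual internals):
>>> import torch
>>> class CoeffCache:
...     def __init__(self):
...         self._coeffs = None  # computed on first access, then reused
...     def compute_ema_coefficients(self):
...         p = torch.sigmoid(torch.randn(128, 4, 1))
...         return p, 1.0 - p * torch.sigmoid(torch.randn(128, 4, 1))
...     def get_ema_coefficients(self):
...         if self._coeffs is None:
...             self._coeffs = self.compute_ema_coefficients()
...         return self._coeffs
>>> damping, prev_weight = CoeffCache().get_ema_coefficients()
>>> damping.shape
torch.Size([128, 4, 1])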
reset_parameters(val: float = 0.0, std1: float = 0.2, std2: float = 1.0) → None
Reset module parameters.
This method initializes the parameters of the MultiHeadDampedEMA module using a normal distribution. It sets the damping and decay factors, the EMA expansion matrix, the kernel projection matrix, and the residual weight.
- Parameters:
- val – Initialization mean for the parameters. Defaults to 0.0.
- std1 – Standard deviation for the damping and decay factors. Defaults to 0.2.
- std2 – Standard deviation for the kernel projection matrix and residual weight. Defaults to 1.0.
################# Examples
>>> ema = MultiHeadDampedEMA(size=10, num_heads=4)
>>> ema.reset_parameters(val=0.1, std1=0.3, std2=0.5)
########## NOTE The parameters are initialized in-place, and this function does not return any value. Use this method to reinitialize the model parameters, especially before training.
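As a rough sketch of the normal-distribution initialization described here (hypothetical tensors, not the module's exact parameter layout):
>>> import torch
>>> val, std1, std2 = 0.0, 0.2, 1.0
>>> damping_param = torch.empty(10, 4, 1)
>>> kernel_projection = torch.empty(10, 4)
>>> with torch.no_grad():
...     _ = damping_param.normal_(mean=val, std=std1)
...     _ = kernel_projection.normal_(mean=val, std=std2)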