espnet2.asr_transducer.normalization.RMSNorm
class espnet2.asr_transducer.normalization.RMSNorm(normalized_shape: int, eps: float = 1e-05, partial: float = 0.0)
Bases: Module
RMSNorm module definition.
This class implements Root Mean Square Layer Normalization (RMSNorm), which rescales the input by the root mean square of its values along the feature dimension. Unlike LayerNorm, RMSNorm does not subtract the mean, which makes it cheaper to compute while still keeping activations at a consistent scale during training.
Reference: https://arxiv.org/pdf/1910.07467.pdf
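For orientation, the normalization can be written as below. Here $x \in \mathbb{R}^{d}$ is one feature vector, $g$ is the learnable scale, and $\epsilon$ is the stability term. The placement of $\epsilon$ (added to the RMS rather than inside the square root) and the rounding used for $k$ are assumptions, since implementations differ on both points.

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^{2}},
\qquad
y_i = g_i \, \frac{x_i}{\mathrm{RMS}(x) + \epsilon}

% Partial variant: estimate the RMS from the first k of the d features, k \approx p \cdot d
\mathrm{RMS}_{p}(x) = \sqrt{\frac{1}{k} \sum_{i=1}^{k} x_i^{2}}
```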
- Parameters:
- normalized_shape (int) – Size of the input's last (feature) dimension, over which normalization is applied.
- eps (float, optional) – A small value added to the denominator for numerical stability. Default is 1e-5.
- partial (float, optional) – Fraction of the input features used to compute the RMS statistics. Partial statistics are used only when the value lies strictly between 0 and 1. Default is 0.0, in which case the RMS statistics are computed over the entire input.
normalized_shape
The expected size of the input's feature dimension.
- Type: int
partial
A boolean indicating whether to use partial normalization.
- Type: bool
p
The fraction of the input to be used for RMS statistics.
- Type: float
eps
The small value for numerical stability.
- Type: float
scale
Learnable scale parameter.
- Type: torch.nn.Parameter
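The way these attributes interact can be illustrated with a minimal sketch. `RMSNormSketch` is a hypothetical stand-in written for this page, not the ESPnet implementation. The initial value of `scale`, the rule for enabling partial statistics, and the placement of `eps` are all assumptions.

```python
import torch


class RMSNormSketch(torch.nn.Module):
    """Illustrative RMSNorm (hypothetical; not the ESPnet class)."""

    def __init__(
        self, normalized_shape: int, eps: float = 1e-5, partial: float = 0.0
    ) -> None:
        super().__init__()
        self.normalized_shape = normalized_shape
        # Assumption: partial statistics are enabled only for 0 < partial < 1.
        self.partial = 0.0 < partial < 1.0
        self.p = partial
        self.eps = eps
        # Assumption: the scale starts at one, so the module initially only normalizes.
        self.scale = torch.nn.Parameter(torch.ones(normalized_shape))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.partial:
            # Estimate the RMS from the first p * D features only.
            d_x = int(self.normalized_shape * self.p)
            stats_x = x[..., :d_x]
        else:
            d_x = self.normalized_shape
            stats_x = x
        # RMS = ||x||_2 / sqrt(d), computed over the last axis.
        rms_x = stats_x.norm(2, dim=-1, keepdim=True) * d_x ** -0.5
        # Assumption: eps is added to the RMS itself, outside the square root.
        return self.scale * (x / (rms_x + self.eps))


norm = RMSNormSketch(normalized_shape=256, partial=0.5)
y = norm(torch.randn(32, 10, 256))
print(y.shape)  # torch.Size([32, 10, 256])
```

With `partial=0.5`, only the first 128 features contribute to the statistic, but all 256 features are divided by it and scaled.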
####### Examples
>>> import torch
>>> from espnet2.asr_transducer.normalization import RMSNorm
>>> rms_norm = RMSNorm(normalized_shape=256, eps=1e-5, partial=0.5)
>>> input_tensor = torch.randn(32, 10, 256)  # (Batch, Time, Features)
>>> output_tensor = rms_norm(input_tensor)
>>> output_tensor.shape
torch.Size([32, 10, 256])
- Returns: The normalized output tensor with the same shape as the input tensor.
- Return type: torch.Tensor
Construct an RMSNorm object.
forward(x: Tensor) → Tensor
Apply RMS normalization to the input tensor.
This module applies Root Mean Square (RMS) normalization to the input tensor. RMS normalization helps in stabilizing the training of deep neural networks by ensuring that the input to each layer has a consistent scale.
Reference: https://arxiv.org/pdf/1910.07467.pdf
- Parameters:
- x – Input sequences. Shape is (B, T, D_hidden).
- normalized_shape – Expected size of the input's feature dimension.
- eps – Value added to the denominator for numerical stability. Default is 1e-5.
- partial – Value defining the part of the input used for RMS stats. If this value is between 0 and 1, only a portion of the input is used to compute the RMS statistics. Default is 0.0, which means full input is used.
normalized_shape
The expected size of the input tensor.
partial
A boolean indicating whether partial RMS statistics should be used.
p
The proportion of the input to use for RMS stats if partial is enabled.
eps
The epsilon value for numerical stability.
scale
A learnable parameter for scaling the normalized output.
####### Examples
>>> import torch
>>> from espnet2.asr_transducer.normalization import RMSNorm
>>> rms_norm = RMSNorm(normalized_shape=64, eps=1e-5, partial=0.5)
>>> input_tensor = torch.randn(32, 10, 64)  # (Batch, Time, Features)
>>> output_tensor = rms_norm(input_tensor)
>>> output_tensor.shape  # Output shape matches input shape
torch.Size([32, 10, 64])
- Returns: Output sequences after RMS normalization. Shape is (B, T, D_hidden).
- Return type: torch.Tensor
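The effect on a (B, T, D_hidden) tensor can be checked with plain tensor operations. This sketch leaves out the learnable scale and the partial option, and it assumes `eps` is added to the RMS. It only illustrates that each feature vector ends up with an RMS close to 1 while the shape stays the same.

```python
import torch

torch.manual_seed(0)
x = 5.0 * torch.randn(2, 3, 64)  # (B, T, D_hidden), deliberately far from unit scale

eps = 1e-5
# Per-frame RMS over the feature axis; keepdim=True lets it broadcast against x.
rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
y = x / (rms + eps)

# RMS of every normalized feature vector, one value per (batch, time) position.
y_rms = y.pow(2).mean(dim=-1).sqrt()
print(y.shape)
print(y_rms)
```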