espnet2.asr_transducer.normalization.RMSNorm
class espnet2.asr_transducer.normalization.RMSNorm(normalized_shape: int, eps: float = 1e-05, partial: float = 0.0)
Bases: Module
RMSNorm module definition.
Reference: https://arxiv.org/pdf/1910.07467.pdf
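For reference, the normalization described in the cited paper can be written as follows for one feature vector $\mathbf{x} \in \mathbb{R}^{D}$ (a single frame along D_hidden) with a learnable gain $\mathbf{g}$. Two details are assumptions read from the parameter descriptions rather than confirmed from the source: $\epsilon$ is added to the RMS value (the "denominator") rather than inside the square root, and partial RMSNorm computes the statistic over the first $k$ features, with $k = D$ when `partial` is disabled.

```latex
\mathrm{RMS}(\mathbf{x}) = \sqrt{\frac{1}{k}\sum_{i=1}^{k} x_i^{2}},
\qquad
y_i = g_i \, \frac{x_i}{\mathrm{RMS}(\mathbf{x}) + \epsilon},
\qquad i = 1, \dots, D
```

Unlike LayerNorm, no mean is subtracted and no bias is added, which is what makes RMSNorm cheaper to compute.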
- Parameters:
  - normalized_shape – Expected size of the last (feature) dimension of the input, i.e. D_hidden.
  - eps – Value added to the denominator for numerical stability.
  - partial – Fraction of the input features used to compute the RMS statistics (partial RMSNorm).
Construct an RMSNorm object.
forward(x: Tensor) → Tensor
Compute RMS normalization.
- Parameters: x – Input sequences. (B, T, D_hidden)
- Returns: Output sequences. (B, T, D_hidden)
- Return type: Tensor
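A framework-free sketch of what forward computes for each frame, written in plain Python so it runs without PyTorch. The helper names `rms_norm` and `rms_norm_sequences` are hypothetical, and several details are assumptions rather than confirmed ESPnet behavior: the gain defaults to all ones, `eps` is added to the RMS value, and `partial` in (0, 1] selects the leading `int(D * partial)` features while any other value disables it. This is an illustration of the technique, not the ESPnet implementation.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(
    x: Sequence[float],
    scale: Optional[Sequence[float]] = None,
    eps: float = 1e-5,
    partial: float = 0.0,
) -> List[float]:
    """RMS-normalize one feature vector (one frame along D_hidden).

    Assumptions: eps is added to the RMS value (not inside the square root),
    and 0 < partial <= 1 enables partial RMSNorm over the leading features.
    """
    d = len(x)
    if scale is None:
        scale = [1.0] * d  # learnable gain g, assumed to start at ones
    # Number of leading features used for the RMS statistics.
    k = int(d * partial) if 0.0 < partial <= 1.0 else d
    k = max(k, 1)  # hypothetical guard against an empty slice
    rms = math.sqrt(sum(v * v for v in x[:k]) / k)
    # Every feature is rescaled, even when only k of them set the statistic.
    return [g * v / (rms + eps) for g, v in zip(scale, x)]


def rms_norm_sequences(batch, **kwargs):
    """Apply rms_norm to every frame of a (B, T, D_hidden) nested list."""
    return [[rms_norm(frame, **kwargs) for frame in seq] for seq in batch]


out = rms_norm([3.0, 4.0])
print([round(v, 4) for v in out])  # roughly [0.8485, 1.1314]
```

The output has a mean square close to 1 regardless of the input scale, which is the property the module relies on. With the documented class, the analogous call is `RMSNorm(256)(x)` for a tensor `x` of shape (B, T, 256), and the output keeps that shape.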
