espnet2.spk.layers.rawnet_block.AFMS
class espnet2.spk.layers.rawnet_block.AFMS(nb_dim: int)
Bases: Module
Alpha-Feature Map Scaling (AFMS) module, which applies scaling to the output of
each residual block.
The module learns a scaling factor (alpha) for each feature dimension and modulates it with a feature map computed from the input, so the rescaling adapts to each input rather than being fixed after training.
References:
- RawNet2: https://www.isca-speech.org/archive/Interspeech_2020/pdfs/1011.pdf
- AFMS: https://www.koreascience.or.kr/article/JAKO202029757857763.page
alpha
Learnable scaling factors for each feature dimension.
- Type: nn.Parameter
fc
Fully connected layer to project features.
- Type: nn.Linear
sig
Sigmoid activation function.
- Type: nn.Sigmoid
Parameters:nb_dim (int) – The number of dimensions (features) for the input tensor.
Returns: The scaled output tensor after applying the AFMS.
Return type: torch.Tensor
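As a rough illustration of how the attributes above contribute to model size, the sketch below counts learnable parameters. It assumes that alpha holds one value per feature dimension and that fc maps nb_dim inputs to nb_dim outputs with a bias; neither shape is stated explicitly on this page, and the helper name afms_param_count is hypothetical.

```python
def afms_param_count(nb_dim: int) -> int:
    """Estimate the number of learnable parameters in an AFMS module.

    Assumptions (not confirmed by this page):
      - alpha has nb_dim elements,
      - fc is a Linear layer of shape nb_dim -> nb_dim with a bias term.
    The sigmoid has no parameters.
    """
    alpha_params = nb_dim
    fc_weight_params = nb_dim * nb_dim
    fc_bias_params = nb_dim
    return alpha_params + fc_weight_params + fc_bias_params


count_128 = afms_param_count(128)
print(count_128)  # 128 + 128 * 128 + 128 = 16640
```

Under these assumptions the fully connected layer dominates the count, so the cost of the module grows roughly quadratically with nb_dim.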
####### Examples
>>> import torch
>>> from espnet2.spk.layers.rawnet_block import AFMS
>>> afms = AFMS(nb_dim=128)
>>> input_tensor = torch.randn(32, 128, 10) # (batch_size, nb_dim, seq_len)
>>> output_tensor = afms(input_tensor)
>>> print(output_tensor.shape)
torch.Size([32, 128, 10])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Perform the forward pass of the AFMS module.
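This method applies Alpha-Feature Map Scaling to the input tensor x. It average-pools x over the length dimension, passes the pooled vector through the fully connected layer and the sigmoid to obtain per-channel scaling weights, and uses these weights together with the learned alpha parameter to rescale x. The output has the same shape as the input.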
- Parameters:x (torch.Tensor) – Input tensor of shape (batch_size, channels, length).
- Returns: Output tensor after applying the Alpha-Feature map scaling.
- Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.spk.layers.rawnet_block import AFMS
>>> afms = AFMS(nb_dim=64)
>>> input_tensor = torch.randn(10, 64, 100) # (batch_size, channels, length)
>>> output_tensor = afms(input_tensor)
>>> print(output_tensor.shape)
torch.Size([10, 64, 100])
NOTE
This method expects x to be a 3-dimensional tensor where the second dimension represents the feature channels and the third dimension represents the sequence length.
- Raises:ValueError – If the input tensor x does not have the expected shape.
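The computation described for forward can be sketched as a small standalone module. This is an illustrative re-implementation, not the ESPnet source: the class name AFMSSketch is hypothetical, the (nb_dim, 1) shape of alpha is an assumption, and the order of operations (add alpha, then multiply by the sigmoid gate) follows the AFMS variant described for RawNet2 and may differ from the actual code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AFMSSketch(nn.Module):
    """Hypothetical stand-in for AFMS, written for illustration only."""

    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        # Assumption: one learnable value per channel, broadcast over length.
        self.alpha = nn.Parameter(torch.ones(nb_dim, 1))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Average over the length axis:
        # (batch, channels, length) -> (batch, channels).
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        # Per-channel gate in (0, 1), reshaped to (batch, channels, 1).
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), 1)
        # Assumption (RawNet2-style AFMS): add alpha, then apply the gate.
        return (x + self.alpha) * y


sketch = AFMSSketch(nb_dim=64)
x = torch.randn(10, 64, 100)  # (batch_size, channels, length)
out = sketch(x)
print(out.shape)
```

Because the gate has a trailing dimension of size 1, it broadcasts across the length axis, which is why the output keeps the input's shape.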