espnet2.spk.loss.aamsoftmax_subcenter_intertopk.ArcMarginProduct_intertopk_subcenter
class espnet2.spk.loss.aamsoftmax_subcenter_intertopk.ArcMarginProduct_intertopk_subcenter(nout, nclasses, scale=32.0, margin=0.2, easy_margin=False, K=3, mp=0.06, k_top=5, do_lm=False)
Bases: AbsLoss
Implementation of large margin arc distance with inter-top-k penalty and sub-centers.
This class implements ArcMarginProduct extended with the inter-top-k penalty and sub-center techniques for improved speaker verification, following the referenced papers.
References
- MULTI-QUERY MULTI-HEAD ATTENTION POOLING AND INTER-TOPK PENALTY FOR SPEAKER VERIFICATION. https://arxiv.org/pdf/2110.05042.pdf
- Sub-center ArcFace: Boosting Face Recognition by Large-Scale Noisy Web Faces. https://ibug.doc.ic.ac.uk/media/uploads/documents/eccv_1445.pdf
- Parameters:
- nout (int) – Size of each input embedding.
- nclasses (int) – Number of output classes (speakers).
- scale (float , optional) – Norm of input feature. Defaults to 32.0.
- margin (float , optional) – Margin for cos(theta + margin). Defaults to 0.2.
- easy_margin (bool , optional) – Use easy margin if True. Defaults to False.
- K (int , optional) – Number of sub-centers. Defaults to 3.
- mp (float , optional) – Margin penalty of hard samples. Defaults to 0.06.
- k_top (int , optional) – Number of hard samples. Defaults to 5.
- do_lm (bool , optional) – Whether to perform large margin finetune. Defaults to False.
in_features
Number of input features.
- Type: int
out_features
Number of output classes.
- Type: int
scale
Scaling factor for output.
- Type: float
margin
Margin value.
- Type: float
K
Number of sub-centers.
- Type: int
k_top
Number of top-k samples.
- Type: int
mp
Margin penalty for hard samples.
- Type: float
do_lm
Flag for large margin finetuning.
- Type: bool
weight
Weight parameter for the classifier.
- Type: torch.Tensor
easy_margin
Flag for easy margin setting.
- Type: bool
cos_m
Cosine of the margin.
- Type: float
sin_m
Sine of the margin.
- Type: float
th
Threshold for cosine value.
- Type: float
mm
Margin adjustment factor.
- Type: float
mmm
Additional margin adjustment factor used to keep the output continuous.
- Type: float
cos_mp
Cosine of the margin penalty applied to hard (inter-top-k) samples.
- Type: float
sin_mp
Sine of the margin penalty applied to hard (inter-top-k) samples.
- Type: float
ce
Cross-entropy loss function.
- Type: nn.CrossEntropyLoss
######### Examples
>>> arc_margin = ArcMarginProduct_intertopk_subcenter(
... nout=512, nclasses=10, scale=32.0, margin=0.2, K=3, k_top=5
... )
>>> input_tensor = torch.randn(16, 512) # Batch of 16 samples
>>> label_tensor = torch.randint(0, 10, (16, 1)) # Random labels
>>> loss = arc_margin(input_tensor, label_tensor)
NOTE
Ensure that the input tensor is normalized before passing it to the forward method to achieve optimal results.
- Raises: ValueError – If the input and label sizes do not match.
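For additional context, the sketch below shows one plausible way to use this loss inside a training step. The encoder, optimizer, feature dimensions, and learning rate are illustrative assumptions, not part of this class; note that the loss module itself holds the trainable (sub-center) weight matrix, so it is included in the optimizer.

import torch
import torch.nn as nn
from espnet2.spk.loss.aamsoftmax_subcenter_intertopk import (
    ArcMarginProduct_intertopk_subcenter,
)

# Hypothetical setup: 512-dim speaker embeddings and 10 speaker classes.
criterion = ArcMarginProduct_intertopk_subcenter(nout=512, nclasses=10)
encoder = nn.Linear(80, 512)  # stand-in for a real speaker encoder
params = list(encoder.parameters()) + list(criterion.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

features = torch.randn(16, 80)           # batch of 16 utterance-level features
labels = torch.randint(0, 10, (16,))     # speaker indices in [0, nclasses)

embeddings = encoder(features)           # (16, 512) speaker embeddings
loss = criterion(embeddings, labels)     # margin-based classification loss
optimizer.zero_grad()
loss.backward()
optimizer.step()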
forward(input, label)
Computes the forward pass for the ArcMarginProduct.
This method applies the ArcMarginProduct calculation to the input features and computes the loss based on the provided labels. It normalizes the input embeddings and the sub-center weights, takes the maximum cosine similarity over the K sub-centers of each class, applies the additive angular margin to the target class and the inter-top-k margin penalty to the hardest non-target classes, scales the result, and computes the cross-entropy loss.
- Parameters:
- input (torch.Tensor) – The input features of shape (batch_size, in_features).
- label (torch.Tensor) – The ground truth class indices of shape (batch_size,) or (batch_size, 1).
- Returns: The computed loss value.
- Return type: torch.Tensor
- Raises: ValueError – If the input tensor does not match the expected dimensions.
######### Examples
>>> model = ArcMarginProduct_intertopk_subcenter(nout=512, nclasses=10)
>>> input = torch.randn(32, 512) # Example input for batch size 32
>>> label = torch.randint(0, 10, (32,)) # Random labels for 10 classes
>>> loss = model(input, label)
>>> print(loss)
NOTE
Ensure that the input tensor is normalized before passing it to this method for optimal performance.
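To make the steps above concrete, here is a simplified, standalone sketch of the sub-center pooling and inter-top-k penalty. The function name, the assumed weight layout of shape (K * nclasses, nout), and the omission of the easy_margin / numerical-stability handling are simplifications for illustration; it is not a verbatim copy of this class's forward.

import torch
import torch.nn.functional as F

def aam_subcenter_intertopk_logits(x, weight, label, scale,
                                   cos_m, sin_m, cos_mp, sin_mp, K, k_top):
    # Cosine similarity to every sub-center, then max-pool over the K sub-centers
    # of each class: (B, K * nclasses) -> (B, nclasses).
    cosine = F.linear(F.normalize(x), F.normalize(weight))
    cosine = cosine.reshape(-1, weight.size(0) // K, K).max(dim=2)[0]
    sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))

    # Target class gets the additive angular margin: cos(theta + margin).
    phi = cosine * cos_m - sine * sin_m
    # Hard non-target classes get a negative penalty margin: cos(theta - mp).
    phi_mp = cosine * cos_mp + sine * sin_mp

    one_hot = torch.zeros_like(cosine)
    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    if k_top > 0:
        # Pick the k_top largest non-target cosines as "hard" classes
        # (subtracting 2 pushes the target class out of the top-k).
        _, top_k_index = torch.topk(cosine - 2 * one_hot, k_top)
        top_k_one_hot = torch.zeros_like(cosine).scatter_(1, top_k_index, 1)
        output = (one_hot * phi
                  + top_k_one_hot * phi_mp
                  + (1.0 - one_hot - top_k_one_hot) * cosine)
    else:
        output = one_hot * phi + (1.0 - one_hot) * cosine
    return output * scale  # scaled logits, ready for cross-entropy

The class then applies its internal nn.CrossEntropyLoss to such scaled logits to obtain the returned loss value.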
update(margin=0.2)
Update the margin and recompute the margin-dependent constants.
This method sets a new margin value and refreshes the derived quantities used in the forward pass (the cosine/sine of the margin, the numerical-stability thresholds, and the penalty terms for hard samples). It is typically called during training to switch to a larger margin for large-margin fine-tuning.
- Parameters: margin (float , optional) – New margin for cos(theta + margin). Defaults to 0.2.
######### Examples
>>> arc_margin = ArcMarginProduct_intertopk_subcenter(nout=512, nclasses=10)
>>> arc_margin.update(margin=0.3)
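In practice, update() is usually called partway through training to switch to a larger margin (large-margin fine-tuning). A minimal scheduling sketch, with illustrative epoch thresholds and margin values that are assumptions rather than recommended settings, might look like this:

import torch
from espnet2.spk.loss.aamsoftmax_subcenter_intertopk import (
    ArcMarginProduct_intertopk_subcenter,
)

criterion = ArcMarginProduct_intertopk_subcenter(nout=512, nclasses=10, margin=0.2)

for epoch in range(1, 31):
    if epoch == 26:
        # Assumed schedule: enlarge the margin for the final fine-tuning epochs.
        criterion.update(margin=0.5)

    embeddings = torch.randn(16, 512)        # placeholder embeddings
    labels = torch.randint(0, 10, (16,))
    loss = criterion(embeddings, labels)
    # ... backward pass and optimizer step would follow in real training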