espnet2.s2st.losses.attention_loss.S2STAttentionLoss
class espnet2.s2st.losses.attention_loss.S2STAttentionLoss(vocab_size: int, padding_idx: int = -1, weight: float = 1.0, smoothing: float = 0.0, normalize_length: str2bool = False, criterion: Module = KLDivLoss())
Bases: AbsS2STLoss
Attention-based label smoothing loss for S2ST.
This class implements an attention-based label smoothing loss for speech-to-speech translation (S2ST). Label smoothing softens the one-hot targets, improving robustness to label noise and reducing overfitting.
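For reference, with smoothing factor ε and vocabulary size V, the standard label-smoothing construction (exact padding handling lives in LabelSmoothingLoss) replaces the one-hot target for the true token y with

$$
q(k) =
\begin{cases}
1 - \epsilon, & k = y \\
\epsilon / (V - 1), & k \neq y
\end{cases}
$$

and computes the KL divergence between q and the model's output distribution, scaled by weight.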
weight
The weight for the loss function.
- Type: float
loss
The label smoothing loss instance.
- Type: LabelSmoothingLoss
Parameters:
- vocab_size (int) – The size of the vocabulary.
- padding_idx (int , optional) – The index used for padding. Defaults to -1.
- weight (float , optional) – The weight for the loss function. Defaults to 1.0.
- smoothing (float , optional) – The label smoothing factor. Defaults to 0.0.
- normalize_length (str2bool , optional) – Whether to normalize the loss by the number of tokens rather than by batch size. Defaults to False.
- criterion (torch.nn.Module , optional) – The criterion used to compute the loss. Defaults to torch.nn.KLDivLoss(reduction="none").
Returns: The computed loss value if weight is greater than 0; otherwise, returns None.
Return type: torch.Tensor or None
#### Examples
>>> import torch
>>> loss_fn = S2STAttentionLoss(vocab_size=1000)
>>> dense_y = torch.randn(4, 10, 1000)  # (batch, length, vocab) raw logits
>>> token_y = torch.randint(0, 1000, (4, 10))  # (batch, length) target token ids
>>> loss = loss_fn(dense_y, token_y)
>>> print(loss)
NOTE
dense_y is expected to contain raw logits of shape (batch, length, vocab_size); the underlying LabelSmoothingLoss applies log_softmax internally before computing the KL divergence.
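Positions whose target equals padding_idx are excluded from the loss. A minimal sketch (hypothetical shapes; assumes the underlying LabelSmoothingLoss masks padded targets):

>>> import torch
>>> from espnet2.s2st.losses.attention_loss import S2STAttentionLoss
>>> loss_fn = S2STAttentionLoss(vocab_size=1000, padding_idx=-1, smoothing=0.1)
>>> dense_y = torch.randn(4, 25, 1000)  # (batch, length, vocab) raw logits
>>> token_y = torch.randint(0, 1000, (4, 25))
>>> token_y[:, 20:] = -1  # trailing positions marked as padding
>>> loss = loss_fn(dense_y, token_y)  # padded positions contribute no loss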
forward(dense_y: Tensor, token_y: Tensor)
Forward method for calculating the attention-based label smoothing loss.
This method computes the loss using the provided dense and token labels. The loss is calculated only if the weight is greater than zero; otherwise, it returns None.
Parameters:
- dense_y (torch.Tensor) – The predicted dense outputs (logits) from the model.
- token_y (torch.Tensor) – The ground-truth token labels.
Returns: The computed loss value if weight > 0, otherwise None.
Return type: torch.Tensor or None
#### Examples
>>> import torch
>>> loss_fn = S2STAttentionLoss(vocab_size=100, smoothing=0.1)
>>> dense_y = torch.randn(2, 10, 100)  # (batch, length, vocab) raw logits
>>> token_y = torch.randint(0, 100, (2, 10))  # (batch, length) target token ids
>>> loss = loss_fn(dense_y, token_y)
>>> print(loss)  # scalar loss tensor
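The weight gate described under Returns can be exercised directly; a small sketch:

>>> import torch
>>> from espnet2.s2st.losses.attention_loss import S2STAttentionLoss
>>> disabled = S2STAttentionLoss(vocab_size=100, weight=0.0)
>>> dense_y = torch.randn(2, 5, 100)
>>> token_y = torch.randint(0, 100, (2, 5))
>>> disabled(dense_y, token_y) is None  # weight <= 0 skips computation
True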