espnet2.s2st.losses.guided_attention_loss.S2STGuidedAttentionLoss
class espnet2.s2st.losses.guided_attention_loss.S2STGuidedAttentionLoss(weight: float = 1.0, sigma: float = 0.4, alpha: float = 1.0)
Bases: AbsS2STLoss
Tacotron-based loss for S2ST.
This class implements a guided attention loss for speech-to-speech translation (S2ST) models with Tacotron-style attention. It wraps GuidedAttentionLoss from ESPnet, which penalizes attention weights that deviate from the diagonal alignment between input and output steps, and computes the loss from the attention weights, input lengths, and output lengths.
weight
Weight for the loss. If it is not greater than 0, the loss is not computed.
- Type: float
loss
Instance of GuidedAttentionLoss used to compute the guided attention loss.
- Type: GuidedAttentionLoss
Parameters:
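- weight (float, optional) – The weight for the loss; the loss is computed only if it is greater than 0. Defaults to 1.0.
- sigma (float, optional) – Standard deviation of the Gaussian-shaped guided attention mask; smaller values confine the allowed attention to a narrower band around the diagonal. Defaults to 0.4.
- alpha (float, optional) – Scaling factor applied to the guided attention loss value. Defaults to 1.0.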
Returns: The guided attention loss, or (None, None, None, None) if weight is not greater than 0.
Return type: Tensor or Tuple[None, None, None, None]
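The computation delegated to GuidedAttentionLoss follows the guided attention idea used in Tacotron-style TTS: each attention weight A[t][n] (output step t attending to input step n) is multiplied by a penalty W[t][n] = 1 - exp(-(n/N - t/T)^2 / (2 * sigma^2)), where N and T are the input and output lengths. The penalty is 0 on the diagonal and approaches 1 far from it, and the mean of the weighted values is scaled by alpha. The plain-Python sketch below reproduces that formula for a single unpadded utterance; the function name guided_attention_loss is hypothetical and not part of the ESPnet API.

```python
import math
from typing import List


def guided_attention_loss(
    att: List[List[float]], sigma: float = 0.4, alpha: float = 1.0
) -> float:
    """Hypothetical helper: guided attention loss for one unpadded utterance.

    att[t][n] is the attention weight of output step t on input step n,
    so the matrix has olen rows and ilen columns.
    """
    olen = len(att)
    ilen = len(att[0])
    total = 0.0
    for t in range(olen):
        for n in range(ilen):
            # Penalty is 0 where n/ilen == t/olen and approaches 1 far from the diagonal.
            penalty = 1.0 - math.exp(-((n / ilen - t / olen) ** 2) / (2 * sigma**2))
            total += penalty * att[t][n]
    # Mean over all positions (all are valid here), scaled by alpha.
    return alpha * total / (olen * ilen)


# Diagonal attention on a square matrix incurs no penalty ...
diag = [[1.0 if t == n else 0.0 for n in range(4)] for t in range(4)]
# ... while attention running against the diagonal is penalized.
anti = [[1.0 if t + n == 3 else 0.0 for n in range(4)] for t in range(4)]
print(guided_attention_loss(diag))  # 0.0
print(guided_attention_loss(anti) > 0)  # True
```

For a padded batch, ESPnet applies the same penalty per utterance and averages only over positions that lie inside each utterance's true lengths.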
Examples
>>> import torch
>>> loss_fn = S2STGuidedAttentionLoss(weight=1.0, sigma=0.4, alpha=1.0)
>>> ilens = torch.randint(1, 21, (10,))  # input lengths, batch size B = 10
>>> olens_in = torch.randint(1, 21, (10,))  # output lengths
>>> # Attention weights must have shape (B, max(olens_in), max(ilens))
>>> att_ws = torch.rand(10, int(olens_in.max()), int(ilens.max()))
>>> loss = loss_fn(att_ws, ilens, olens_in)
>>> print(loss)  # 0-dim tensor holding the loss value
NOTE
Ensure that att_ws has shape (B, T_out, T_in), where T_out = max(olens_in) and T_in = max(ilens), and that ilens and olens_in are 1-D tensors of shape (B,).
- Raises: No explicit shape validation is performed; incompatible shapes surface as a RuntimeError from PyTorch's element-wise operations.
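Because att_ws must have shape (B, max(olens_in), max(ilens)), a mismatch can be caught before the loss is called. Below is a hypothetical pre-flight check in plain Python that works on shape tuples and length lists; with PyTorch you could pass tuple(att_ws.shape), ilens.tolist(), and olens_in.tolist(). Neither expected_att_shape nor check_att_shape is part of ESPnet.

```python
from typing import Sequence, Tuple


def expected_att_shape(
    ilens: Sequence[int], olens: Sequence[int]
) -> Tuple[int, int, int]:
    """Hypothetical helper: the (B, T_out, T_in) shape att_ws should have."""
    if len(ilens) != len(olens):
        raise ValueError("ilens and olens_in must have the same batch size")
    return (len(ilens), max(olens), max(ilens))


def check_att_shape(
    att_shape: Sequence[int], ilens: Sequence[int], olens: Sequence[int]
) -> None:
    """Hypothetical helper: fail early with a readable message."""
    expected = expected_att_shape(ilens, olens)
    if tuple(att_shape) != expected:
        raise ValueError(
            f"att_ws has shape {tuple(att_shape)}, expected {expected} (B, T_out, T_in)"
        )


# Batch of 3: the longest output has 40 steps, the longest input 50.
check_att_shape((3, 40, 50), ilens=[50, 45, 30], olens=[40, 38, 20])  # passes

# A 2-D attention tensor is rejected.
try:
    check_att_shape((3, 50), ilens=[50, 45, 30], olens=[40, 38, 20])
except ValueError as err:
    print(err)
```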
forward(att_ws: Tensor, ilens: Tensor, olens_in: Tensor)
Forward pass for calculating the guided attention loss.
This method computes the guided attention loss based on the provided attention weights, input lengths, and output lengths. If the weight for the loss is greater than zero, it will invoke the loss calculation from the GuidedAttentionLoss class.
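The control flow described above can be sketched in plain Python, with a zero-argument callable standing in for the GuidedAttentionLoss call; gated_loss is a hypothetical name used only for illustration. In this sketch, as in the description, the weight only decides whether the loss is computed; the returned value is not scaled by it.

```python
from typing import Callable, Tuple, Union


def gated_loss(
    weight: float, compute: Callable[[], float]
) -> Union[float, Tuple[None, None, None, None]]:
    """Hypothetical stand-in for the weight check in forward()."""
    if weight > 0:
        # Only in this branch is the loss actually computed.
        return compute()
    # A non-positive weight skips the computation entirely.
    return (None, None, None, None)


print(gated_loss(1.0, lambda: 0.25))  # 0.25
print(gated_loss(0.0, lambda: 0.25))  # (None, None, None, None)
```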
- Parameters:
- att_ws (torch.Tensor) – Batch of attention weights of shape (B, T_out, T_in).
- ilens (torch.Tensor) – Batch of input lengths of shape (B,).
- olens_in (torch.Tensor) – Batch of output lengths of shape (B,).
- Returns: The guided attention loss if weight > 0; otherwise (None, None, None, None).
- Return type: Tensor or Tuple[None, None, None, None]
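Because att_ws is padded to the longest output and input in the batch, only positions inside each utterance's true lengths should contribute to the loss. The plain-Python sketch below builds such a validity mask, assuming the (B, T_out, T_in) layout listed above; make_valid_mask is a hypothetical helper, while ESPnet's GuidedAttentionLoss builds a comparable mask internally with tensor operations.

```python
from typing import List, Sequence


def make_valid_mask(
    ilens: Sequence[int], olens: Sequence[int]
) -> List[List[List[bool]]]:
    """Hypothetical helper: True where (output t, input n) lies inside utterance b."""
    t_out, t_in = max(olens), max(ilens)
    return [
        [[t < olen and n < ilen for n in range(t_in)] for t in range(t_out)]
        for ilen, olen in zip(ilens, olens)
    ]


mask = make_valid_mask(ilens=[3, 2], olens=[2, 1])
# Batch of 2, padded to T_out = 2 and T_in = 3.
print(len(mask), len(mask[0]), len(mask[0][0]))  # 2 2 3
# Number of valid positions per utterance equals olen * ilen.
print([sum(v for row in m for v in row) for m in mask])  # [6, 2]
```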
Examples
>>> import torch
>>> att_ws = torch.rand(10, 40, 50)  # (B, T_out, T_in) attention weights
>>> ilens = torch.tensor([50] * 10)  # input lengths
>>> olens_in = torch.tensor([40] * 10)  # output lengths
>>> loss = S2STGuidedAttentionLoss(weight=1.0)
>>> result = loss(att_ws, ilens, olens_in)
>>> print(result)  # 0-dim tensor with the guided attention loss
>>> S2STGuidedAttentionLoss(weight=0.0)(att_ws, ilens, olens_in)
(None, None, None, None)