espnet2.speechlm.net_utils.ce_loss
espnet2.speechlm.net_utils.ce_loss(logits: Tensor, target: Tensor, lengths: Tensor, prefix_len: Tensor | None = None, first_layer_weight: int = 1.0) → Tuple[Tensor, Tensor]
Computes the cross-entropy loss and accuracy for the given logits and targets.
This function calculates the cross-entropy loss between the predicted logits and the target values, applying a length mask to ignore padding tokens. It also computes accuracy for each layer in the output.
espnet2.speechlm.net_utils.logits
The predicted logits of shape (B, T, N, C) where B is the batch size, T is the sequence length, N is the number of layers, and C is the number of classes.
- Type: torch.Tensor
espnet2.speechlm.net_utils.target
The target tensor of shape (B, T, N) containing the correct class indices for each token.
- Type: torch.Tensor
espnet2.speechlm.net_utils.lengths
A tensor of shape (B,) indicating the actual lengths of each sequence in the batch.
- Type: torch.Tensor
espnet2.speechlm.net_utils.prefix_len
A tensor indicating the lengths of the prefixes to mask out in the loss computation. Defaults to None.
- Type: torch.Tensor, optional
espnet2.speechlm.net_utils.first_layer_weight
Weight to apply to the first layer’s gradients. Defaults to 1.0.
- Type: float, optional
Parameters:
- logits (torch.Tensor) – Predicted logits from the model.
- target (torch.Tensor) – Ground truth target values.
- lengths (torch.Tensor) – Lengths of the sequences for masking.
- prefix_len (torch.Tensor , optional) – Prefix lengths for masking.
- first_layer_weight (float , optional) – Weight for the first layer.
Returns: A tuple containing the computed loss, a dictionary with accuracy statistics for each layer, and the total weight used in the loss computation.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Raises: AssertionError – If logits does not have the expected number of dimensions, or if its sizes do not match those of target.
Examples
>>> import torch
>>> from espnet2.speechlm.net_utils import ce_loss
>>> logits = torch.randn(2, 5, 3, 10)  # (B, T, N, C) example logits
>>> target = torch.randint(0, 10, (2, 5, 3))  # (B, T, N) example targets
>>> lengths = torch.tensor([5, 3])  # Actual length of each sequence
>>> loss, stats, weight = ce_loss(logits, target, lengths)
NOTE
This function assumes that the input tensors are correctly shaped and that the required dimensions are as specified.
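The length masking and per-layer accuracy described above can be illustrated with a minimal standalone sketch. It is not the ESPnet implementation: the name masked_ce_sketch is hypothetical, first_layer_weight is omitted, and the handling of prefix_len (assumed here to have shape (B,), with positions before the prefix length excluded from the loss) is an assumption made for illustration only.

```python
import torch
import torch.nn.functional as F


def masked_ce_sketch(logits, target, lengths, prefix_len=None):
    """Illustrative length-masked cross-entropy (hypothetical helper)."""
    B, T, N, C = logits.shape
    assert target.shape == (B, T, N)

    # mask[b, t] is True for real tokens (t < lengths[b]), False for padding.
    positions = torch.arange(T, device=lengths.device).unsqueeze(0)  # (1, T)
    mask = positions < lengths.unsqueeze(1)  # (B, T)
    if prefix_len is not None:
        # Assumption: prefix_len has shape (B,) and prefix frames are skipped.
        mask = mask & (positions >= prefix_len.unsqueeze(1))
    mask = mask.unsqueeze(-1).expand(B, T, N)  # same mask for every layer

    # Per-position loss, computed on flattened tensors and reshaped back.
    elem_loss = F.cross_entropy(
        logits.reshape(-1, C), target.reshape(-1), reduction="none"
    ).view(B, T, N)

    weight = mask.sum()  # number of unmasked (frame, layer) positions
    loss = (elem_loss * mask).sum() / weight

    # Accuracy per layer, counted over unmasked positions only.
    correct = (logits.argmax(dim=-1) == target) & mask
    acc_per_layer = correct.sum(dim=(0, 1)).float() / mask.sum(dim=(0, 1)).float()
    return loss, acc_per_layer, weight


logits = torch.randn(2, 5, 3, 10)
target = torch.randint(0, 10, (2, 5, 3))
lengths = torch.tensor([5, 3])
loss, acc_per_layer, weight = masked_ce_sketch(logits, target, lengths)
```

With these inputs, weight counts (5 + 3) valid frames across 3 layers, so 24 positions contribute to the loss, and changing the logits at padded frames leaves the loss unchanged.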