espnet2.hubert.hubert_loss.HubertPretrainLoss
class espnet2.hubert.hubert_loss.HubertPretrainLoss(pred_masked_weight: float = 1.0, pred_nomask_weight: float = 0.0, loss_weights: float = 10.0)
Bases: Module
Hubert criterion module.
This module implements the Hubert pretraining loss. It computes the predictive loss for both masked and unmasked frames and, where applicable, adds the additional loss terms provided by the model.
pred_masked_weight
Weight for predictive loss for masked frames.
- Type: float
pred_nomask_weight
Weight for predictive loss for unmasked frames.
- Type: float
loss_weights
Weight for additional loss terms (all terms after the first, predictive one).
- Type: float
Parameters:
- pred_masked_weight (float) – Weight for predictive loss for masked frames. Defaults to 1.0.
- pred_nomask_weight (float) – Weight for predictive loss for unmasked frames. Defaults to 0.0.
- loss_weights (float) – Weight for additional loss terms. Defaults to 10.0.
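To make the role of the three weights concrete, here is a minimal pure-Python sketch of how per-term losses might be combined into a total. The function name combine_hubert_losses is hypothetical, and the sketch assumes each term is already a scalar; it omits any normalization or rescaling the actual implementation may apply to the extra term.

```python
from typing import Sequence


def combine_hubert_losses(
    masked_losses: Sequence[float],
    unmasked_losses: Sequence[float],
    extra_loss: float,
    pred_masked_weight: float = 1.0,
    pred_nomask_weight: float = 0.0,
    loss_weights: float = 10.0,
) -> float:
    """Hypothetical helper: weighted sum of Hubert-style loss terms."""
    total = pred_masked_weight * sum(masked_losses)      # masked-frame prediction terms
    total += pred_nomask_weight * sum(unmasked_losses)   # unmasked-frame prediction terms
    total += loss_weights * extra_loss                   # additional loss term
    return total


# With the default weights, unmasked frames contribute nothing (weight 0.0).
total = combine_hubert_losses([2.0, 1.0], [4.0], extra_loss=0.5)
print(total)  # 8.0
```

With the defaults, only the masked-frame terms and the extra term affect the total: 1.0 × 3.0 + 0.0 × 4.0 + 10.0 × 0.5 = 8.0.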
Returns: A tuple containing:
- loss (float): The computed loss value.
- logp_m_list (List[Tensor]): List of logits for masked predictions.
- logp_u_list (List[Tensor]): List of logits for unmasked predictions.
Return type: Tuple[float, List[Tensor], List[Tensor]]
Raises: NotImplementedError – If the model’s extra losses are not a list containing a single element.
Examples:
>>> loss_module = HubertPretrainLoss()
>>> loss, logp_m, logp_u = loss_module(model, enc_outputs)
NOTE
This implementation utilizes code from Fairseq and is based on the work of Abdelrahman Mohamed and Wei-Ning Hsu.
References:
- Paper: https://arxiv.org/pdf/2106.07447.pdf
- Code in Fairseq: https://github.com/pytorch/fairseq/tree/master/examples/hubert
forward(model, enc_outputs, reduce=True)
Computes the forward pass of the Hubert pretraining loss.
This method calculates the total loss from the model’s predictions for masked and unmasked frames. It applies cross-entropy to the masked and unmasked predictions and, if enabled, adds the model’s additional loss terms. Each term is scaled by its configured weight (pred_masked_weight, pred_nomask_weight, or loss_weights) before being added to the total.
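The following is a minimal sketch of the computation described above, written as a standalone function rather than ESPnet’s actual code. It assumes that get_logits and get_targets take the encoder outputs plus a boolean selecting masked (True) or unmasked (False) frames and return lists of tensors, and that get_extra_losses returns a (losses, names) pair; these signatures, and the names hubert_pretrain_loss_sketch and ToyModel, are hypothetical. The sketch always sums the cross-entropy (the reduce=True case) and does not apply any rescaling the real implementation may use for the extra term.

```python
import torch
import torch.nn.functional as F


def hubert_pretrain_loss_sketch(model, enc_outputs, pred_masked_weight=1.0,
                                pred_nomask_weight=0.0, loss_weights=10.0):
    """Sketch of a weighted Hubert-style pretraining loss (reduce=True case only)."""
    loss = torch.tensor(0.0)

    # Cross-entropy for masked frames, one term per target set (assumed interface).
    logp_m_list = model.get_logits(enc_outputs, True)
    targ_m_list = model.get_targets(enc_outputs, True)
    loss_m = [F.cross_entropy(lp, tg, reduction="sum")
              for lp, tg in zip(logp_m_list, targ_m_list)]
    if pred_masked_weight > 0:
        loss = loss + pred_masked_weight * sum(loss_m)

    # Cross-entropy for unmasked frames (ignored with the default weight 0.0).
    logp_u_list = model.get_logits(enc_outputs, False)
    targ_u_list = model.get_targets(enc_outputs, False)
    loss_u = [F.cross_entropy(lp, tg, reduction="sum")
              for lp, tg in zip(logp_u_list, targ_u_list)]
    if pred_nomask_weight > 0:
        loss = loss + pred_nomask_weight * sum(loss_u)

    # Additional loss term, only when it carries a positive weight.
    if loss_weights > 0:
        extra_losses, _names = model.get_extra_losses(enc_outputs)
        if not isinstance(extra_losses, list) or len(extra_losses) != 1:
            raise NotImplementedError("only a single extra loss is supported")
        loss = loss + loss_weights * extra_losses[0].float()

    return loss, logp_m_list, logp_u_list


class ToyModel:
    """Hypothetical stand-in exposing the methods the loss relies on."""

    def __init__(self, num_classes=5):
        g = torch.Generator().manual_seed(0)
        self.logits_m = torch.randn(6, num_classes, generator=g)
        self.targets_m = torch.randint(0, num_classes, (6,), generator=g)
        self.logits_u = torch.randn(4, num_classes, generator=g)
        self.targets_u = torch.randint(0, num_classes, (4,), generator=g)

    def get_logits(self, enc_outputs, is_masked):
        return [self.logits_m] if is_masked else [self.logits_u]

    def get_targets(self, enc_outputs, is_masked):
        return [self.targets_m] if is_masked else [self.targets_u]

    def get_extra_losses(self, enc_outputs):
        return [torch.tensor(0.01)], ["extra"]


model = ToyModel()
loss, logp_m, logp_u = hubert_pretrain_loss_sketch(model, enc_outputs=None)
print(loss.item())
```

With the default weights, the result equals the summed masked-frame cross-entropy plus 10.0 times the toy extra loss; the unmasked logits are still returned even though they do not contribute.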
- Parameters:
- model – The model used to obtain predictions and targets.
- enc_outputs – The encoded outputs from the model that are used to compute the loss.
- reduce – A boolean indicating whether to reduce the loss. If True, the loss will be summed; otherwise, it will not be reduced.
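As a rough illustration of what reduce controls, the snippet below contrasts summed and unreduced cross-entropy using PyTorch’s torch.nn.functional.cross_entropy. Mapping reduce=True to reduction="sum" and reduce=False to reduction="none" is an assumption based on the description above; the tensor values are arbitrary.

```python
import torch
import torch.nn.functional as F

# Three frames, two classes: logits of shape (frames, classes) and integer targets.
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
targets = torch.tensor([0, 1, 0])

per_frame = F.cross_entropy(logits, targets, reduction="none")  # reduce=False: one loss per frame
summed = F.cross_entropy(logits, targets, reduction="sum")      # reduce=True: a single scalar

print(per_frame.shape)  # torch.Size([3])
print(torch.isclose(per_frame.sum(), summed).item())  # True
```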
- Returns: A tuple containing:
- loss (float): The computed loss value.
- logp_m_list (list): The list of logits for masked frames.
- logp_u_list (list): The list of logits for unmasked frames.
- Return type: Tuple[float, List[Tensor], List[Tensor]]
Examples:
>>> model = HubertModel()
>>> enc_outputs = model.encode(inputs)
>>> loss_fn = HubertPretrainLoss()
>>> loss, logp_m, logp_u = loss_fn(model, enc_outputs)
NOTE
This method assumes that the model provides get_logits and get_targets methods. When additional loss terms are weighted in, it also checks that the model provides get_extra_losses.
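The interface this note describes can be written down as a typing.Protocol, which makes it easy to check a candidate model before training. This is a minimal sketch: the protocol names are hypothetical, and the parameter name is_masked and the return types are assumptions, not signatures confirmed by this documentation.

```python
from typing import Any, List, Protocol, Tuple, runtime_checkable


@runtime_checkable
class HubertLossModel(Protocol):
    """Hypothetical protocol: methods forward() always calls."""

    def get_logits(self, enc_outputs: Any, is_masked: bool) -> List[Any]: ...

    def get_targets(self, enc_outputs: Any, is_masked: bool) -> List[Any]: ...


@runtime_checkable
class HubertLossModelWithExtras(HubertLossModel, Protocol):
    """Hypothetical protocol: also needed when extra loss terms are weighted in."""

    def get_extra_losses(self, enc_outputs: Any) -> Tuple[List[Any], List[str]]: ...


class LogitsOnlyModel:
    """Toy model that provides logits and targets but no extra losses."""

    def get_logits(self, enc_outputs, is_masked):
        return []

    def get_targets(self, enc_outputs, is_masked):
        return []


print(isinstance(LogitsOnlyModel(), HubertLossModel))            # True
print(isinstance(LogitsOnlyModel(), HubertLossModelWithExtras))  # False
```

An isinstance check against a runtime-checkable protocol only verifies that the methods exist, not their signatures or return values, so it is a quick sanity check rather than a full validation.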
- Raises:
- NotImplementedError – If the model’s extra losses are not a list containing a single element.