espnet2.tts2.fastspeech2.loss.FastSpeech2LossDiscrete
class espnet2.tts2.fastspeech2.loss.FastSpeech2LossDiscrete(use_masking: bool = True, use_weighted_masking: bool = False, ignore_id: int = -1)
Bases: Module
Loss function module for FastSpeech2 with discrete targets. It computes the loss components used to train the model: a cross-entropy loss over the discrete output features, plus separate losses for the duration, pitch, and energy predictors.
use_masking
Indicates whether to apply masking for padded parts in loss calculations.
- Type: bool
use_weighted_masking
Indicates whether to use weighted masking in loss calculations.
- Type: bool
ce_criterion
Cross-entropy loss criterion.
- Type: torch.nn.CrossEntropyLoss
mse_criterion
Mean squared error loss criterion.
- Type: torch.nn.MSELoss
duration_criterion
Duration prediction loss criterion.
Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation. Default is True.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation. Default is False.
- ignore_id (int) – Index to ignore in the loss calculation. Default is -1.
Returns: CrossEntropy loss value, Duration predictor loss value, Pitch predictor loss value, Energy predictor loss value.
Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
####### Examples
>>> loss_module = FastSpeech2LossDiscrete()
>>> ce_loss, duration_loss, pitch_loss, energy_loss = loss_module(
... after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es, ilens, olens
... )
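The effect of ignore_id can be illustrated with plain PyTorch. The following is a minimal sketch, assuming ignore_id is passed to torch.nn.CrossEntropyLoss as its ignore_index argument (the text above only states that the index is ignored); the tensor sizes and values are hypothetical.

```python
import torch

ignore_id = -1  # same value as the documented default
ce = torch.nn.CrossEntropyLoss(ignore_index=ignore_id)

torch.manual_seed(0)
logits = torch.randn(2, 4, 5)  # hypothetical (B, T_feats, odim)
# Discrete targets (B, T_feats); the last two frames of the second
# sequence are padding and carry ignore_id.
ys = torch.tensor([[0, 3, 1, 2], [4, 2, ignore_id, ignore_id]])

# CrossEntropyLoss expects the class dimension second: (B, odim, T_feats).
loss_with_ignore = ce(logits.transpose(1, 2), ys)

# Equivalent computation that averages over the six valid frames only.
valid = ys != ignore_id
loss_valid_only = torch.nn.functional.cross_entropy(logits[valid], ys[valid])

print(loss_with_ignore.item(), loss_valid_only.item())
```

Both values should agree, since targets equal to ignore_id contribute neither to the sum nor to the averaging denominator.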
NOTE
This class is part of the ESPnet2 toolkit and is specifically designed for FastSpeech2, which utilizes discrete speech targets.
- Raises: AssertionError – If both use_masking and use_weighted_masking are set to True.
Initialize the FastSpeech2 loss module for discrete targets.
- Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
- ignore_id (int) – Index to ignore in the loss calculation.
forward(after_outs: Tensor, before_outs: Tensor, d_outs: Tensor, p_outs: Tensor, e_outs: Tensor, ys: Tensor, ds: Tensor, ps: Tensor, es: Tensor, ilens: Tensor, olens: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
Calculate forward propagation of the loss module, computing the cross-entropy, duration, pitch, and energy loss values from the FastSpeech2 model outputs and their targets.
- Parameters:
- after_outs (torch.Tensor) – Batch of outputs after postnets (B, T_feats, odim).
- before_outs (torch.Tensor) – Batch of outputs before postnets (B, T_feats, odim).
- d_outs (torch.LongTensor) – Batch of outputs of duration predictor (B, T_text).
- p_outs (torch.Tensor) – Batch of outputs of pitch predictor (B, T_text, 1).
- e_outs (torch.Tensor) – Batch of outputs of energy predictor (B, T_text, 1).
- ys (torch.Tensor) – Batch of target features in discrete space (B, T_feats).
- ds (torch.LongTensor) – Batch of durations (B, T_text).
- ps (torch.Tensor) – Batch of target token-averaged pitch (B, T_text, 1).
- es (torch.Tensor) – Batch of target token-averaged energy (B, T_text, 1).
- ilens (torch.LongTensor) – Batch of the lengths of each input (B,).
- olens (torch.LongTensor) – Batch of the lengths of each target (B,).
- Returns: A tuple containing the following loss values:
- CrossEntropy loss value.
- Duration predictor loss value.
- Pitch predictor loss value.
- Energy predictor loss value.
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
####### Examples
>>> loss_function = FastSpeech2LossDiscrete()
>>> ce_loss, duration_loss, pitch_loss, energy_loss = loss_function(
... after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es, ilens, olens
... )
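The example above uses inputs that are not defined. The following sketch builds dummy tensors with the shapes listed under Parameters, so that the call can be tried out. All sizes and values are hypothetical, and the float dtype for d_outs is an assumption (predictor outputs are typically real-valued). The call itself is left commented out because it requires ESPnet to be installed.

```python
import torch

# Hypothetical sizes: batch, text length, feature length, vocabulary size.
B, T_text, T_feats, odim = 2, 4, 10, 8

ilens = torch.tensor([4, 3])   # (B,) input (text) lengths
olens = torch.tensor([10, 7])  # (B,) target (feature) lengths

# Model outputs
after_outs = torch.randn(B, T_feats, odim)
before_outs = torch.randn(B, T_feats, odim)
d_outs = torch.randn(B, T_text)  # dtype is an assumption
p_outs = torch.randn(B, T_text, 1)
e_outs = torch.randn(B, T_text, 1)

# Targets
ys = torch.randint(0, odim, (B, T_feats))        # discrete feature indices
ds = torch.tensor([[3, 2, 4, 1], [2, 3, 2, 0]])  # durations sum to olens
ps = torch.randn(B, T_text, 1)
es = torch.randn(B, T_text, 1)

# With ESPnet installed, the documented call would then be:
# from espnet2.tts2.fastspeech2.loss import FastSpeech2LossDiscrete
# loss_function = FastSpeech2LossDiscrete()
# ce_loss, duration_loss, pitch_loss, energy_loss = loss_function(
#     after_outs, before_outs, d_outs, p_outs, e_outs,
#     ys, ds, ps, es, ilens, olens,
# )
```

The durations are chosen so that each row of ds sums to the corresponding entry of olens, which keeps the text-level and feature-level lengths consistent.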
NOTE
If use_masking is True, the method masks out padded positions before computing the losses. If use_weighted_masking is True instead, weighted masking is applied. The two options are mutually exclusive: enabling both raises an AssertionError when the module is constructed.
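To make the two masking modes concrete, here is a minimal sketch in plain PyTorch for a pitch-style input of shape (B, T_text, 1). It is an illustration under stated assumptions, not the module's actual code: non_pad_mask is a hypothetical helper, and the weighting scheme (each valid position weighted by one over its sequence length, then divided by the batch size) is assumed.

```python
import torch


def non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True at valid positions, shape (B, T_max)."""
    positions = torch.arange(int(lengths.max())).unsqueeze(0)  # (1, T_max)
    return positions < lengths.unsqueeze(1)                    # (B, T_max)


ilens = torch.tensor([3, 2])  # second sequence has one padded position
p_outs = torch.tensor([[[1.0], [2.0], [3.0]], [[1.0], [4.0], [9.0]]])
ps = torch.tensor([[[1.0], [2.0], [3.0]], [[1.0], [2.0], [0.0]]])

mask = non_pad_mask(ilens)   # (B, T_text)
mask3 = mask.unsqueeze(-1)   # (B, T_text, 1)

# No masking: the padded position (9.0 vs 0.0) pollutes the loss.
unmasked = torch.nn.MSELoss()(p_outs, ps)

# Masking: drop padded positions, then average over what is left.
masked = torch.nn.MSELoss()(
    p_outs.masked_select(mask3), ps.masked_select(mask3)
)

# Weighted masking (assumed scheme): every sequence contributes equally,
# regardless of its length.
elementwise = torch.nn.MSELoss(reduction="none")(p_outs, ps)
weights = mask.float() / mask.sum(dim=1, keepdim=True).float()
weights = weights / ps.size(0)
weighted = (elementwise * weights.unsqueeze(-1)).masked_select(mask3).sum()

print(unmasked.item(), masked.item(), weighted.item())
```

In this toy batch the only valid error is a squared difference of 4 in the shorter sequence, so plain masking averages it over five valid positions while the weighted variant averages it within its own sequence first.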