espnet2.tts.fastspeech2.loss.FastSpeech2Loss
class espnet2.tts.fastspeech2.loss.FastSpeech2Loss(use_masking: bool = True, use_weighted_masking: bool = False)
Bases: Module
Loss function module for FastSpeech2.
This class implements the loss calculation for the FastSpeech2 model. Padded elements can either be excluded from the loss through masking (use_masking) or handled through weighted masking (use_weighted_masking).
use_masking
Whether to apply masking for padded parts in loss calculation.
- Type: bool
use_weighted_masking
Whether to apply weighted masking in loss calculation.
- Type: bool
l1_criterion
L1 loss criterion.
- Type: torch.nn.Module
mse_criterion
Mean Squared Error loss criterion.
- Type: torch.nn.Module
duration_criterion
Duration predictor loss criterion.
- Type: torch.nn.Module
Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
Returns:
- L1 loss value.
- Duration predictor loss value.
- Pitch predictor loss value.
- Energy predictor loss value.
Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
####### Examples
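>>> import torch
>>> from espnet2.tts.fastspeech2.loss import FastSpeech2Loss
>>> loss_fn = FastSpeech2Loss(use_masking=True, use_weighted_masking=False)
>>> after_outs = torch.randn(8, 100, 80)
>>> before_outs = torch.randn(8, 100, 80)
>>> d_outs = torch.randn(8, 50)  # float predictor outputs, not integer durations
>>> p_outs = torch.randn(8, 50, 1)
>>> e_outs = torch.randn(8, 50, 1)
>>> ys = torch.randn(8, 100, 80)
>>> ds = torch.randint(1, 10, (8, 50))
>>> ps = torch.randn(8, 50, 1)
>>> es = torch.randn(8, 50, 1)
>>> # Lengths must not exceed the padded sizes; the longest entry is set to
>>> # the padded size so that length-derived masks line up with the tensors.
>>> ilens = torch.randint(10, 51, (8,))
>>> ilens[0] = 50
>>> olens = torch.randint(50, 101, (8,))
>>> olens[0] = 100
>>> l1_loss, duration_loss, pitch_loss, energy_loss = loss_fn(
...     after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es, ilens, olens
... )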
Initialize feed-forward Transformer loss module.
- Parameters:
- use_masking (bool) – Whether to apply masking for padded part in loss calculation.
- use_weighted_masking (bool) – Whether to apply weighted masking in loss calculation.
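The effect of masking padded parts can be illustrated with plain PyTorch. This is a minimal sketch, not the module's own code: `non_pad_mask` is a hypothetical helper written for this example, and the tensor values are made up so that the padded region visibly distorts the unmasked loss.

```python
import torch


def non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, T_max) boolean mask, True at non-padded positions."""
    positions = torch.arange(int(lengths.max()))
    return positions.unsqueeze(0) < lengths.unsqueeze(1)


# Two utterances padded to 4 frames with 2 feature dims; the second has 2 real frames.
olens = torch.tensor([4, 2])
ys = torch.zeros(2, 4, 2)
outs = torch.ones(2, 4, 2)
outs[1, 2:] = 100.0  # large values sitting in the padded region

mask = non_pad_mask(olens).unsqueeze(-1)  # (B, T_feats, 1), broadcasts over the feature dim
l1 = torch.nn.L1Loss()

# Without masking, the padded frames dominate the average.
unmasked = l1(outs, ys)

# With masking, only the 12 real elements enter the average.
masked = l1(outs.masked_select(mask), ys.masked_select(mask))

print(unmasked.item(), masked.item())
```

In this toy batch the padded frames pull the unmasked loss far away from the error on the real frames, which is the situation use_masking is meant to avoid.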
forward(after_outs: Tensor, before_outs: Tensor, d_outs: Tensor, p_outs: Tensor, e_outs: Tensor, ys: Tensor, ds: Tensor, ps: Tensor, es: Tensor, ilens: Tensor, olens: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
Calculate forward propagation for FastSpeech2 loss computation.
This method computes the loss values for various outputs generated by the FastSpeech2 model, including L1 loss, duration predictor loss, pitch predictor loss, and energy predictor loss. It applies masking to handle padded parts of the input tensors when specified.
- Parameters:
- after_outs (Tensor) – Batch of outputs after postnets (B, T_feats, odim).
- before_outs (Tensor) – Batch of outputs before postnets (B, T_feats, odim).
- d_outs (Tensor) – Batch of outputs of duration predictor (B, T_text).
- p_outs (Tensor) – Batch of outputs of pitch predictor (B, T_text, 1).
- e_outs (Tensor) – Batch of outputs of energy predictor (B, T_text, 1).
- ys (Tensor) – Batch of target features (B, T_feats, odim).
- ds (LongTensor) – Batch of durations (B, T_text).
- ps (Tensor) – Batch of target token-averaged pitch (B, T_text, 1).
- es (Tensor) – Batch of target token-averaged energy (B, T_text, 1).
- ilens (LongTensor) – Batch of the lengths of each input (B,).
- olens (LongTensor) – Batch of the lengths of each target (B,).
- Returns: A tuple containing:
- L1 loss value.
- Duration predictor loss value.
- Pitch predictor loss value.
- Energy predictor loss value.
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
####### Examples
>>> l1_loss, duration_loss, pitch_loss, energy_loss = loss_fn(
... after_outs, before_outs, d_outs, p_outs, e_outs, ys, ds, ps, es, ilens, olens
... )
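To make the four returned terms concrete, the following is a simplified sketch of how such losses can be computed and combined with plain PyTorch, without any masking. It is not the library's implementation. The assumptions are that the duration predictor is compared against log-domain targets with an offset of 1.0, and that the terms are combined as an unweighted sum. All shapes and values are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T_text, T_feats, odim = 2, 3, 5, 4

# Model outputs (random stand-ins that require gradients).
before_outs = torch.randn(B, T_feats, odim, requires_grad=True)
after_outs = torch.randn(B, T_feats, odim, requires_grad=True)
d_outs = torch.randn(B, T_text, requires_grad=True)  # assumed log-domain predictions
p_outs = torch.randn(B, T_text, 1, requires_grad=True)
e_outs = torch.randn(B, T_text, 1, requires_grad=True)

# Targets.
ys = torch.randn(B, T_feats, odim)
ds = torch.randint(1, 4, (B, T_text))  # integer durations in frames
ps = torch.randn(B, T_text, 1)
es = torch.randn(B, T_text, 1)

# L1 loss on the features, summed over the outputs before and after the postnet.
l1_loss = F.l1_loss(before_outs, ys) + F.l1_loss(after_outs, ys)

# MSE between predictions and log-domain durations (the offset of 1.0 is an assumption).
duration_loss = F.mse_loss(d_outs, torch.log(ds.float() + 1.0))

# MSE for the token-averaged pitch and energy.
pitch_loss = F.mse_loss(p_outs, ps)
energy_loss = F.mse_loss(e_outs, es)

# Assumed combination: an unweighted sum used as the training objective.
total_loss = l1_loss + duration_loss + pitch_loss + energy_loss
total_loss.backward()
```

Returning the terms separately, as forward does, lets the caller log each one and choose how to combine them.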
NOTE
Masking in this function is controlled by the use_masking and use_weighted_masking flags passed when the FastSpeech2Loss instance is constructed; forward itself takes no masking arguments. Choose the flags at initialization according to how padded positions should be treated in the loss.
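The difference between the two flags can be sketched on a toy batch. The weighting scheme below is an assumption about what weighted masking means here: each frame is weighted by the inverse of its utterance length and of the batch size, so every utterance contributes equally. The error values are made up for illustration.

```python
import torch

olens = torch.tensor([4, 1])  # number of real frames per utterance
err = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                    [3.0, 0.0, 0.0, 0.0]])  # per-frame absolute error (B, T)

# Boolean non-pad mask of shape (B, T).
mask = torch.arange(err.size(1)).unsqueeze(0) < olens.unsqueeze(1)

# Plain masking: average over all real frames, so longer utterances weigh more.
masked_mean = err.masked_select(mask).mean()

# Weighted masking (assumed scheme): normalize per utterance, then per batch.
weights = mask.float() / mask.sum(dim=1, keepdim=True).float()
weights = weights / err.size(0)
weighted = (err * weights).masked_select(mask).sum()

print(masked_mean.item(), weighted.item())
```

With plain masking the short utterance's single frame is one of five averaged values, whereas under the assumed weighting it accounts for half of the total.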