espnet2.gan_tts.vits.loss.KLDivergenceLoss
class espnet2.gan_tts.vits.loss.KLDivergenceLoss(*args, **kwargs)
Bases: Module
KL divergence loss.
This module computes the Kullback-Leibler (KL) divergence loss between the posterior and prior distributions, a standard objective in variational inference. In VITS, the posterior comes from the posterior encoder (transformed by the normalizing flow) and the prior is the Gaussian predicted by the text encoder; the KL term measures how far the posterior diverges from that prior.
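In the standard VITS formulation (an assumption about this module's internals, mapped to the forward() arguments below rather than quoted from the ESPnet source), the loss is a masked Monte-Carlo estimate of the Gaussian KL, with z_p playing the role of a sample from the flow-transformed posterior:

```latex
\mathcal{L}_{\mathrm{KL}}
  = \frac{\sum_{b,h,t} M_{b,1,t}
      \Big[\log\sigma_{p} - \log\sigma_{q} - \tfrac{1}{2}
        + \tfrac{1}{2}\,(z_{p} - \mu_{p})^{2}\, e^{-2\log\sigma_{p}}\Big]_{b,h,t}}
    {\sum_{b,t} M_{b,1,t}}
```

where mu_p = m_p, log sigma_p = logs_p, log sigma_q = logs_q, and M = z_mask; the mask zeroes out padded frames and the denominator normalizes by the number of valid frames.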
Examples
>>> import torch
>>> from espnet2.gan_tts.vits.loss import KLDivergenceLoss
>>> kl_loss = KLDivergenceLoss()
>>> z_p = torch.randn(32, 64, 100)
>>> logs_q = torch.randn(32, 64, 100)
>>> m_p = torch.randn(32, 64, 100)
>>> logs_p = torch.randn(32, 64, 100)
>>> z_mask = torch.ones(32, 1, 100)
>>> loss = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
>>> print(loss)
NOTE
This loss is useful for training generative models like VITS.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(z_p: Tensor, logs_q: Tensor, m_p: Tensor, logs_p: Tensor, z_mask: Tensor) → Tensor
Calculate KL divergence loss.
- Parameters:
- z_p (Tensor) – Flow hidden representation (B, H, T_feats).
- logs_q (Tensor) – Posterior encoder projected scale (B, H, T_feats).
- m_p (Tensor) – Expanded text encoder projected mean (B, H, T_feats).
- logs_p (Tensor) – Expanded text encoder projected scale (B, H, T_feats).
- z_mask (Tensor) – Mask tensor (B, 1, T_feats).
- Returns: KL divergence loss.
- Return type: Tensor
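For reference, a minimal standalone sketch of this computation, assuming the standard VITS formulation (the function name is illustrative; this is not a verbatim copy of the ESPnet implementation):

```python
import torch


def kl_divergence_loss_sketch(
    z_p: torch.Tensor,     # flow hidden representation (B, H, T_feats)
    logs_q: torch.Tensor,  # posterior log-scale (B, H, T_feats)
    m_p: torch.Tensor,     # prior mean from the text encoder (B, H, T_feats)
    logs_p: torch.Tensor,  # prior log-scale from the text encoder (B, H, T_feats)
    z_mask: torch.Tensor,  # mask over valid frames (B, 1, T_feats)
) -> torch.Tensor:
    # Pointwise Gaussian KL term, estimated with z_p as a posterior sample.
    kl = logs_p - logs_q - 0.5
    kl = kl + 0.5 * (z_p - m_p) ** 2 * torch.exp(-2.0 * logs_p)
    # Mask out padded frames and normalize by the number of valid frames.
    return torch.sum(kl * z_mask) / torch.sum(z_mask)
```

With the shapes from the example above, this returns a scalar tensor; the (B, 1, T_feats) mask broadcasts across the hidden dimension H.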