espnet2.gan_tts.vits.loss.KLDivergenceLossWithoutFlow
class espnet2.gan_tts.vits.loss.KLDivergenceLossWithoutFlow(*args, **kwargs)
Bases: Module
KL divergence loss without flow.
This class computes the Kullback-Leibler (KL) divergence loss between the posterior and prior distributions in a variational inference framework, for the case where no flow is applied to the posterior representation.
- Parameters:
- m_q (Tensor) – Posterior encoder projected mean (B, H, T_feats).
- logs_q (Tensor) – Posterior encoder projected scale (B, H, T_feats).
- m_p (Tensor) – Expanded text encoder projected mean (B, H, T_feats).
- logs_p (Tensor) – Expanded text encoder projected scale (B, H, T_feats).
- Returns: KL divergence loss, averaged over the batch.
- Return type: Tensor
Examples
>>> import torch
>>> kl_loss = KLDivergenceLossWithoutFlow()
>>> m_q = torch.randn(32, 64, 100) # Example tensor for m_q
>>> logs_q = torch.randn(32, 64, 100) # Example tensor for logs_q
>>> m_p = torch.randn(32, 64, 100) # Example tensor for m_p
>>> logs_p = torch.randn(32, 64, 100) # Example tensor for logs_p
>>> loss = kl_loss(m_q, logs_q, m_p, logs_p)
>>> print(loss)  # Scalar tensor: KL divergence averaged over the batch
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(m_q: Tensor, logs_q: Tensor, m_p: Tensor, logs_p: Tensor) → Tensor
Calculate KL divergence loss without flow.
- Parameters:
- m_q (Tensor) – Posterior encoder projected mean (B, H, T_feats).
- logs_q (Tensor) – Posterior encoder projected scale (B, H, T_feats).
- m_p (Tensor) – Expanded text encoder projected mean (B, H, T_feats).
- logs_p (Tensor) – Expanded text encoder projected scale (B, H, T_feats).
- Returns: KL divergence loss, averaged over the batch.
- Return type: Tensor
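The quantity computed here is the closed-form KL divergence between two diagonal Gaussians: the posterior q (parameterized by m_q, logs_q) and the expanded text-encoder prior p (parameterized by m_p, logs_p), where each `logs_*` tensor holds log standard deviations. The following is a minimal sketch of that computation using `torch.distributions`, not the ESPnet implementation itself; the `kl_without_flow` helper name and the reduction over all elements are illustrative assumptions:

```python
import torch
import torch.distributions as D


def kl_without_flow(m_q: torch.Tensor, logs_q: torch.Tensor,
                    m_p: torch.Tensor, logs_p: torch.Tensor) -> torch.Tensor:
    """KL(q || p) for diagonal Gaussians, averaged to a scalar (sketch).

    The projected scales are log standard deviations, so exp() recovers
    the standard deviation expected by torch.distributions.Normal.
    """
    q = D.Normal(m_q, torch.exp(logs_q))  # posterior q(z | x)
    p = D.Normal(m_p, torch.exp(logs_p))  # prior p(z | text)
    return D.kl_divergence(q, p).mean()


# Equivalent closed form per element:
#   KL = logs_p - logs_q - 0.5
#        + 0.5 * (exp(2 * logs_q) + (m_q - m_p) ** 2) * exp(-2 * logs_p)
```

When q and p coincide (identical means and log scales), the loss is exactly zero; it grows with the squared mean difference scaled by the prior variance, which is what pulls the posterior encoder toward the text-encoder prior during VITS training.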