espnet2.gan_tts.vits.flow.LogFlow
class espnet2.gan_tts.vits.flow.LogFlow(*args, **kwargs)
Bases: Module
LogFlow module for calculating forward propagation in flow-based models.
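This class implements a log flow transformation used in the VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech) model. In the forward direction it takes the element-wise log of the input, zeroes the masked positions, and returns the log-determinant of the Jacobian needed for negative log-likelihood (NLL) computation. In the inverse direction it applies the element-wise exponential and the mask.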
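The transformation can be illustrated without PyTorch. The sketch below mirrors the forward and inverse directions described above on nested Python lists laid out as (B, channels, T). The helper names `log_flow_forward` and `log_flow_inverse` are hypothetical, and clamping x at eps before the log is an assumption about how eps is applied; this is not the ESPnet source.

```python
import math
from typing import List, Tuple

# Nested lists with layout [batch][channel][time], mirroring (B, channels, T).
Tensor3 = List[List[List[float]]]


def log_flow_forward(
    x: Tensor3, x_mask: Tensor3, eps: float = 1e-5
) -> Tuple[Tensor3, List[float]]:
    """Forward direction: y = log(max(x, eps)) * mask, logdet = sum(-y) per item."""
    y: Tensor3 = []
    logdet: List[float] = []
    for xb, mb in zip(x, x_mask):
        # x_mask has a single channel, so it is broadcast over all channels.
        m = mb[0]
        # Clamping at eps (assumed) keeps log() finite for x <= 0.
        yb = [
            [math.log(max(v, eps)) * mt for v, mt in zip(channel, m)]
            for channel in xb
        ]
        y.append(yb)
        # Masked positions are 0 in y, so they add nothing to the log-determinant.
        logdet.append(sum(-v for channel in yb for v in channel))
    return y, logdet


def log_flow_inverse(y: Tensor3, x_mask: Tensor3) -> Tensor3:
    """Inverse direction: x = exp(y) * mask."""
    return [
        [[math.exp(v) * mt for v, mt in zip(channel, mb[0])] for channel in yb]
        for yb, mb in zip(y, x_mask)
    ]
```

For example, with x = [[[0.1, 0.2, 0.3]]] and a mask that zeroes the last frame, logdet[0] equals -(log 0.1 + log 0.2) = log 50, and the inverse returns 0.1 and 0.2 (up to floating-point error) with 0.0 in the masked frame.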
- Parameters:
- x (Tensor) – Input tensor of shape (B, channels, T).
- x_mask (Tensor) – Mask tensor of shape (B, 1, T).
- inverse (bool) – Whether to compute the inverse of the flow.
- eps (float) – A small value to prevent log(0). Default is 1e-5.
- Returns:
- If inverse is False: the output tensor of shape (B, channels, T) and the log-determinant tensor for NLL of shape (B,).
- If inverse is True: the output tensor of shape (B, channels, T).
- Return type: Union[Tensor, Tuple[Tensor, Tensor]]
Examples
>>> import torch
>>> from espnet2.gan_tts.vits.flow import LogFlow
>>> log_flow = LogFlow()
>>> x = torch.tensor([[[0.1, 0.2, 0.3]]]) # Shape: (1, 1, 3)
>>> x_mask = torch.tensor([[[1, 1, 1]]]) # Shape: (1, 1, 3)
>>> y, logdet = log_flow(x, x_mask) # Forward propagation
>>> x_inv = log_flow(y, x_mask, inverse=True) # Inverse propagation
NOTE
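The log transformation is applied element-wise, so the input is expected to be positive; eps guards against log(0). Because dy/dx = 1/x for y = log(x), the log-determinant is the sum of -log(x), i.e. the negated output, over the channel and time dimensions. Masked positions are zero in the output and therefore contribute nothing.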
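The note can be written out as a short change-of-variables derivation. Here $m_t \in \{0, 1\}$ denotes the mask value at frame $t$ (an assumed notation), and only unmasked positions are treated as transformed. Since each output element depends only on its own input element, the Jacobian is diagonal:

```latex
y_{c,t} = m_t \log x_{c,t}, \qquad
\frac{\partial y_{c,t}}{\partial x_{c,t}} = \frac{1}{x_{c,t}} \quad (m_t = 1)

\log\left|\det \frac{\partial y}{\partial x}\right|
  = \sum_{c}\sum_{t} m_t \log\frac{1}{x_{c,t}}
  = -\sum_{c}\sum_{t} m_t \log x_{c,t}
  = -\sum_{c}\sum_{t} y_{c,t}
```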
forward(x: Tensor, x_mask: Tensor, inverse: bool = False, eps: float = 1e-05, **kwargs) → Tensor | Tuple[Tensor, Tensor]
Calculate forward propagation.
- Parameters:
- x (Tensor) – Input tensor (B, channels, T).
- x_mask (Tensor) – Mask tensor (B, 1, T).
- inverse (bool) – Whether to compute the inverse of the flow.
- eps (float) – Small value to prevent log(0).
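- Returns: If inverse is False, the output tensor (B, channels, T) and the log-determinant tensor for NLL (B,). If inverse is True, only the output tensor (B, channels, T).
- Return type: Union[Tensor, Tuple[Tensor, Tensor]]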
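The log-determinant returned in the forward direction enters an NLL through the change-of-variables formula, log p(x) = log p(y) + logdet. The sketch below assumes, purely for illustration, a standard-normal prior on y for a single (channels, T) item; `log_flow_nll` is a hypothetical helper, not part of espnet2. In a full flow-based model, the prior and the other flow steps would differ.

```python
import math
from typing import List


def log_flow_nll(x: List[List[float]], mask: List[float], eps: float = 1e-5) -> float:
    """NLL of one (channels, T) item when y = log(x) has a standard-normal prior.

    Change of variables: log p(x) = log N(y; 0, I) + logdet, where
    logdet = sum(-y) is the quantity LogFlow returns in the forward direction.
    """
    # Forward log flow on a single item (clamping at eps is an assumption).
    y = [[math.log(max(v, eps)) * m for v, m in zip(channel, mask)] for channel in x]
    logdet = sum(-v for channel in y for v in channel)
    # Standard-normal log-density, counted only at unmasked positions.
    log_prior = sum(
        -0.5 * (math.log(2.0 * math.pi) + v * v) * m
        for channel in y
        for v, m in zip(channel, mask)
    )
    return -(log_prior + logdet)
```

With this prior, x follows a log-normal distribution, so the result equals log x + 0.5 log(2π) + 0.5 (log x)^2 summed over unmasked positions. This offers a quick sanity check that the sign of logdet is right.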