espnet2.gan_svs.vits.modules.Projection
class espnet2.gan_svs.vits.modules.Projection(hidden_channels, out_channels)
Bases: Module
Projection is a PyTorch module that projects a hidden representation to a specified number of output channels using a 1D convolutional layer. It returns the projected statistics as two tensors: the mean m_p and the log variance logs_p. This module is typically used in generative models.
hidden_channels
The number of input channels for the projection.
- Type: int
out_channels
The number of output channels for the projection.
- Type: int
proj
A 1D convolutional layer that performs the projection.
- Type: torch.nn.Conv1d
Parameters:
- hidden_channels (int) – The number of input channels.
- out_channels (int) – The number of output channels.
Returns: A tuple containing:
- m_p (torch.Tensor): The mean parameters of the projected output.
- logs_p (torch.Tensor): The log variance parameters of the projected output.
Return type: Tuple[torch.Tensor, torch.Tensor]
Raises:
- ValueError – If the input tensor x and the mask x_mask are not of compatible shapes.
####### Examples
>>> import torch
>>> from espnet2.gan_svs.vits.modules import Projection
>>> projection = Projection(hidden_channels=64, out_channels=32)
>>> x = torch.randn(8, 64, 100) # (B, attention_dim, T_text)
>>> x_mask = torch.ones(8, 1, 100) # Mask with shape (B, 1, T_text)
>>> m_p, logs_p = projection(x, x_mask)
>>> m_p.shape # Should be (8, 32, 100)
>>> logs_p.shape # Should be (8, 32, 100)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x, x_mask)
Forward pass for the Projection module.
This method applies a 1D convolution to the input tensor x and computes the masked output statistics. The output consists of two tensors: m_p and logs_p, which represent the mean and log variance of the projected input, respectively.
- Parameters:
  - x (torch.Tensor) – Input tensor of shape (B, hidden_channels, T_text).
  - x_mask (torch.Tensor) – Mask tensor of shape (B, 1, T_text) to apply masking during projection.
- Returns: A tuple containing:
  - m_p (torch.Tensor): The mean tensor of shape (B, out_channels, T_text).
  - logs_p (torch.Tensor): The log variance tensor of shape (B, out_channels, T_text).
- Return type: Tuple[torch.Tensor, torch.Tensor]
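The computation described for forward can be pictured with the minimal sketch below. It is not the ESPnet source: the class name ProjectionSketch, the kernel size of 1, and the choice to emit 2 * out_channels channels and split them in half are assumptions chosen to match the documented input and output shapes.

```python
import torch


class ProjectionSketch(torch.nn.Module):
    """Hypothetical stand-in for Projection (illustration only)."""

    def __init__(self, hidden_channels: int, out_channels: int):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        # Assumption: a single pointwise convolution produces both statistics,
        # so it outputs 2 * out_channels channels.
        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, kernel_size=1)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        # x: (B, hidden_channels, T_text), x_mask: (B, 1, T_text)
        # The mask broadcasts over the channel dimension.
        stats = self.proj(x) * x_mask
        # Split the channel dimension into two chunks of out_channels each.
        m_p, logs_p = torch.split(stats, self.out_channels, dim=1)
        return m_p, logs_p


sketch = ProjectionSketch(hidden_channels=64, out_channels=32)
x = torch.randn(8, 64, 100)
x_mask = torch.ones(8, 1, 100)
m_p, logs_p = sketch(x, x_mask)
```

Under these assumptions, both returned tensors have shape (B, out_channels, T_text), which agrees with the shapes listed in the Returns entry.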
####### Examples
>>> import torch
>>> from espnet2.gan_svs.vits.modules import Projection
>>> projection = Projection(hidden_channels=128, out_channels=64)
>>> x = torch.randn(32, 128, 50) # Batch of 32, 128 features, 50 time steps
>>> x_mask = torch.ones(32, 1, 50) # Mask with all ones
>>> m_p, logs_p = projection(x, x_mask)
>>> m_p.shape
torch.Size([32, 64, 50])
>>> logs_p.shape
torch.Size([32, 64, 50])
NOTE
The input tensor x must have shape (B, hidden_channels, T_text). The mask x_mask, typically of shape (B, 1, T_text), must be broadcastable across the channel dimension.
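To illustrate the broadcasting requirement, the sketch below builds a (B, 1, T_text) mask from per-utterance lengths and applies it to a stand-in tensor. The helper lengths_to_mask is hypothetical and written in plain PyTorch; it is not an ESPnet function, and the all-ones stats tensor merely stands in for a projected output.

```python
import torch


def lengths_to_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Hypothetical helper: return a float mask of shape (B, 1, max_len)."""
    positions = torch.arange(max_len).unsqueeze(0)  # (1, T)
    valid = positions < lengths.unsqueeze(1)  # (B, T), True on non-padded frames
    return valid.unsqueeze(1).float()  # (B, 1, T)


lengths = torch.tensor([5, 3])  # two sequences, the second padded to length 5
x_mask = lengths_to_mask(lengths, max_len=5)

stats = torch.ones(2, 4, 5)  # stand-in for a projected output of shape (B, C, T)
masked = stats * x_mask  # the singleton channel axis broadcasts over C
```

Because the mask has a singleton channel axis, one mask row zeroes the padded frames in every channel at once.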