espnet2.gan_svs.vits.prior_decoder.PriorDecoder
class espnet2.gan_svs.vits.prior_decoder.PriorDecoder(out_channels: int = 384, attention_dim: int = 192, attention_heads: int = 2, linear_units: int = 768, blocks: int = 6, positionwise_layer_type: str = 'conv1d', positionwise_conv_kernel_size: int = 3, positional_encoding_layer_type: str = 'rel_pos', self_attention_layer_type: str = 'rel_selfattn', activation_type: str = 'swish', normalize_before: bool = True, use_macaron_style: bool = False, use_conformer_conv: bool = False, conformer_kernel_size: int = 7, dropout_rate: float = 0.1, positional_dropout_rate: float = 0.0, attention_dropout_rate: float = 0.0, global_channels: int = 0)
Bases: Module
PriorDecoder is a neural network module that implements the prior decoder of the VITS-based singing voice synthesis (SVS) model. It passes the input features through a convolutional prenet, a stack of self-attention blocks with optional Conformer convolutions, and a final projection that produces out_channels output features.
prenet
A convolutional layer that processes input features before passing them to the decoder.
- Type: torch.nn.Conv1d
decoder
The attention-based Encoder stack that transforms the prenet output.
- Type: Encoder
proj
A projection layer that maps the decoder output to the desired output channel size.
- Type: torch.nn.Conv1d
conv
An optional convolutional layer that projects the global conditioning tensor g. It is used only when global_channels is greater than 0.
- Type: torch.nn.Conv1d, optional
Parameters:
- out_channels (int) – Output channels of the prior decoder. Defaults to 384.
- attention_dim (int) – Dimension of the attention mechanism. Defaults to 192.
- attention_heads (int) – Number of attention heads. Defaults to 2.
- linear_units (int) – Number of units in the linear layer. Defaults to 768.
- blocks (int) – Number of blocks in the encoder. Defaults to 6.
- positionwise_layer_type (str) – Type of the positionwise layer. Defaults to “conv1d”.
- positionwise_conv_kernel_size (int) – Kernel size of the positionwise convolutional layer. Defaults to 3.
- positional_encoding_layer_type (str) – Type of positional encoding layer. Defaults to “rel_pos”.
- self_attention_layer_type (str) – Type of self-attention layer. Defaults to “rel_selfattn”.
- activation_type (str) – Type of activation. Defaults to “swish”.
- normalize_before (bool) – Flag for normalization. Defaults to True.
- use_macaron_style (bool) – Flag for macaron style. Defaults to False.
- use_conformer_conv (bool) – Flag for using conformer convolution. Defaults to False.
- conformer_kernel_size (int) – Kernel size for conformer convolution. Defaults to 7.
- dropout_rate (float) – Dropout rate. Defaults to 0.1.
- positional_dropout_rate (float) – Dropout rate for positional encoding. Defaults to 0.0.
- attention_dropout_rate (float) – Dropout rate for attention. Defaults to 0.0.
- global_channels (int) – Number of global channels. Defaults to 0.
Returns: A tuple returned by forward(), containing:
- Output tensor of shape (B, out_channels, T).
- Output mask tensor of shape (B, 1, T).
Return type: Tuple[Tensor, Tensor]
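The constructor arguments above can be collected and sanity-checked before the module is built. The following is a minimal sketch in plain Python, not part of ESPnet: `DEFAULTS` and `build_prior_decoder_kwargs` are hypothetical names, the default values are copied from the parameter list, and the divisibility check reflects the usual multi-head attention requirement, which is assumed to apply here.

```python
from typing import Any, Dict

# Defaults copied from the documented parameter list (hypothetical helper table).
DEFAULTS: Dict[str, Any] = {
    "out_channels": 384,
    "attention_dim": 192,
    "attention_heads": 2,
    "linear_units": 768,
    "blocks": 6,
    "positionwise_layer_type": "conv1d",
    "positionwise_conv_kernel_size": 3,
    "positional_encoding_layer_type": "rel_pos",
    "self_attention_layer_type": "rel_selfattn",
    "activation_type": "swish",
    "normalize_before": True,
    "use_macaron_style": False,
    "use_conformer_conv": False,
    "conformer_kernel_size": 7,
    "dropout_rate": 0.1,
    "positional_dropout_rate": 0.0,
    "attention_dropout_rate": 0.0,
    "global_channels": 0,
}


def build_prior_decoder_kwargs(**overrides: Any) -> Dict[str, Any]:
    """Merge overrides into the documented defaults and run basic checks."""
    unknown = set(overrides) - set(DEFAULTS)
    if unknown:
        raise TypeError(f"unknown PriorDecoder arguments: {sorted(unknown)}")
    kwargs = {**DEFAULTS, **overrides}
    # Assumption: the attention dimension is split evenly across the heads.
    if kwargs["attention_dim"] % kwargs["attention_heads"] != 0:
        raise ValueError("attention_dim must be divisible by attention_heads")
    if kwargs["global_channels"] < 0:
        raise ValueError("global_channels must be non-negative")
    return kwargs


kwargs = build_prior_decoder_kwargs(attention_heads=4, global_channels=256)
print(kwargs["attention_heads"], kwargs["global_channels"])
# With ESPnet installed, the result could be passed on as PriorDecoder(**kwargs).
```

Keeping the checks outside the module means a misconfigured experiment fails before any weights are allocated.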
####### Examples
Example of initializing the PriorDecoder
>>> import torch
>>> from espnet2.gan_svs.vits.prior_decoder import PriorDecoder
>>> decoder = PriorDecoder(out_channels=384, attention_dim=192)
Example of a forward pass
>>> x = torch.randn(32, 194, 100)  # (B, attention_dim + 2, T)
>>> x_lengths = torch.randint(1, 101, (32,))  # (B,)
>>> x_lengths[0] = 100  # keep the longest length equal to T
>>> output, mask = decoder(x, x_lengths)  # global_channels is 0, so g is omitted
NOTE
Ensure that the input tensor x has shape (B, attention_dim + 2, T) and that x_lengths holds the length of each input sequence.
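The relationship between x_lengths and the (B, 1, T) output mask can be illustrated in plain Python. This is a sketch rather than the library's code: `non_pad_mask` is a hypothetical helper, and it assumes that the mask marks valid frames with 1 and padding with 0, and that T is taken from the longest entry in x_lengths.

```python
from typing import List


def non_pad_mask(lengths: List[int]) -> List[List[List[int]]]:
    """Return a (B, 1, T) nested list: 1 for valid frames, 0 for padding.

    T is taken to be max(lengths), which is an assumption of this sketch.
    """
    max_len = max(lengths)
    return [[[1 if t < n else 0 for t in range(max_len)]] for n in lengths]


mask = non_pad_mask([4, 2, 3])
for row in mask:
    print(row)
# [[1, 1, 1, 1]]
# [[1, 1, 0, 0]]
# [[1, 1, 1, 0]]
```

Under that assumption, the mask only lines up with x when the longest length equals the time dimension of x.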
Initialize prior decoder module.
forward(x, x_lengths, g=None)
Forward pass of the PriorDecoder module.
This method applies the prenet to the input and, when g is provided for a multi-singer model, combines the result with the projected global conditioning. The features then pass through the decoder and the projection layer. The method returns the projected tensor together with the mask derived from x_lengths.
- Parameters:
- x (Tensor) – Input tensor with shape (B, attention_dim + 2, T).
- x_lengths (Tensor) – Length tensor with shape (B,).
- g (Tensor, optional) – Global conditioning tensor for multi-singer models, with shape (B, global_channels, 1). Defaults to None.
- Returns: A tuple containing:
- Tensor: Output tensor with shape (B, out_channels, T).
- Tensor: Output mask tensor with shape (B, 1, T).
- Return type: Tuple[Tensor, Tensor]
####### Examples
>>> import torch
>>> from espnet2.gan_svs.vits.prior_decoder import PriorDecoder
>>> decoder = PriorDecoder()
>>> input_tensor = torch.randn(8, 194, 10) # B=8, attention_dim=192+2, T=10
>>> lengths = torch.tensor([10, 10, 10, 10, 10, 10, 10, 10])
>>> output, mask = decoder(input_tensor, lengths)
>>> print(output.shape) # Should output: torch.Size([8, 384, 10])
>>> print(mask.shape) # Should output: torch.Size([8, 1, 10])
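The documented shapes can also be checked before calling forward, without building the model. The helper below is a hypothetical, plain-Python sketch based only on the shapes listed above. It additionally assumes that every length lies between 1 and T and that g is only meaningful when global_channels is greater than 0.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[int, ...]


def expected_forward_shapes(
    x_shape: Sequence[int],
    x_lengths: Sequence[int],
    g_shape: Optional[Sequence[int]] = None,
    attention_dim: int = 192,
    out_channels: int = 384,
    global_channels: int = 0,
) -> Tuple[Shape, Shape]:
    """Validate input shapes and return the documented (output, mask) shapes."""
    if len(x_shape) != 3:
        raise ValueError("x must have shape (B, attention_dim + 2, T)")
    batch, channels, frames = x_shape
    if channels != attention_dim + 2:
        raise ValueError(f"x needs {attention_dim + 2} channels, got {channels}")
    if len(x_lengths) != batch:
        raise ValueError("x_lengths must contain one length per batch item")
    # Assumption of this sketch: lengths are valid frame counts within T.
    if any(n < 1 or n > frames for n in x_lengths):
        raise ValueError("each length must lie between 1 and T")
    if g_shape is not None:
        # Assumption of this sketch: g is only used with global_channels > 0.
        if global_channels <= 0:
            raise ValueError("g requires global_channels > 0")
        if tuple(g_shape) != (batch, global_channels, 1):
            raise ValueError("g must have shape (B, global_channels, 1)")
    return (batch, out_channels, frames), (batch, 1, frames)


out_shape, mask_shape = expected_forward_shapes((8, 194, 10), [10] * 8)
print(out_shape, mask_shape)  # (8, 384, 10) (8, 1, 10)
```

A check like this gives a readable error message for a wrong channel count or a mismatched g, instead of a failure deep inside a convolution.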