espnet2.tts.gst.style_encoder.ReferenceEncoder
class espnet2.tts.gst.style_encoder.ReferenceEncoder(idim=80, conv_layers: int = 6, conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), conv_kernel_size: int = 3, conv_stride: int = 2, gru_layers: int = 1, gru_units: int = 128)
Bases: Module
Reference encoder module.
This module is the reference encoder introduced in Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End Speech Synthesis.
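Concretely, the encoder applies a stack of strided 2-D convolutions over the input mel-spectrogram and feeds the flattened feature maps to a GRU, whose final hidden state serves as the reference embedding. The following is a minimal structural sketch using the default hyperparameters, not the module's exact implementation (which also includes normalization and general shape handling):
>>> import torch
>>> # Illustrative only: strided Conv2d stack -> GRU -> last hidden state.
>>> convs = torch.nn.Sequential(
...     torch.nn.Conv2d(1, 32, 3, 2, 1), torch.nn.ReLU(),
...     torch.nn.Conv2d(32, 32, 3, 2, 1), torch.nn.ReLU(),
...     torch.nn.Conv2d(32, 64, 3, 2, 1), torch.nn.ReLU(),
...     torch.nn.Conv2d(64, 64, 3, 2, 1), torch.nn.ReLU(),
...     torch.nn.Conv2d(64, 128, 3, 2, 1), torch.nn.ReLU(),
...     torch.nn.Conv2d(128, 128, 3, 2, 1), torch.nn.ReLU(),
... )
>>> hs = convs(torch.randn(16, 1, 100, 80))  # (B, C, L', D')
>>> hs = hs.permute(0, 2, 1, 3).reshape(16, hs.size(2), -1)  # (B, L', C * D')
>>> gru = torch.nn.GRU(hs.size(-1), 128, num_layers=1, batch_first=True)
>>> _, h = gru(hs)  # final hidden state: (gru_layers, B, gru_units)
>>> h[-1].shape
torch.Size([16, 128])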
conv_layers
The number of conv layers in the reference encoder.
- Type: int
kernel_size
Kernel size of conv layers in the reference encoder.
- Type: int
stride
Stride size of conv layers in the reference encoder.
- Type: int
padding
Padding size used in convolution layers.
- Type: int
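Because each conv layer downsamples both the time and frequency axes, the sequence length seen by the GRU shrinks with every layer. A quick sketch of the standard Conv2d output-length arithmetic, assuming padding = (kernel_size - 1) // 2 (the exact padding used by the module is an assumption here):
>>> length = 100  # Lmax, input frames
>>> for _ in range(6):  # conv_layers, each with kernel 3 and stride 2
...     length = (length + 2 * 1 - 3) // 2 + 1  # assumes padding = 1
...
>>> length  # time steps seen by the GRU
2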
Parameters:
- idim (int, optional) – Dimension of the input mel-spectrogram.
- conv_layers (int, optional) – The number of conv layers in the reference encoder.
- conv_chans_list (Sequence[int], optional) – List of the numbers of channels of conv layers in the reference encoder.
- conv_kernel_size (int, optional) – Kernel size of conv layers in the reference encoder.
- conv_stride (int, optional) – Stride size of conv layers in the reference encoder.
- gru_layers (int, optional) – The number of GRU layers in the reference encoder.
- gru_units (int, optional) – The number of GRU units in the reference encoder.
####### Examples
>>> import torch
>>> encoder = ReferenceEncoder(idim=80, conv_layers=6)
>>> input_tensor = torch.randn(16, 100, 80)  # (B, Lmax, idim)
>>> output = encoder(input_tensor)
>>> print(output.shape)  # torch.Size([16, 128]) -> (B, gru_units)
Initialize reference encoder module.
forward(speech: Tensor) → Tensor
Calculate forward propagation.
This method performs forward propagation through the reference encoder, taking a batch of padded target features and returning a reference embedding.
- Parameters: speech (Tensor) – Batch of padded target features (B, Lmax, idim).
- Returns: Reference embedding (B, gru_units).
- Return type: Tensor
####### Examples
>>> import torch
>>> encoder = ReferenceEncoder(idim=80)
>>> speech_input = torch.randn(8, 100, 80)  # (B, Lmax, idim)
>>> ref_embs = encoder(speech_input)
>>> print(ref_embs.shape)  # torch.Size([8, 128]) -> (B, gru_units)
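In practice, ReferenceEncoder is typically used inside StyleEncoder from the same module, which feeds the reference embedding through a style-token attention layer. A brief usage sketch, assuming the StyleEncoder defaults in espnet2.tts.gst.style_encoder (e.g. a 256-dimensional style token embedding):
>>> import torch
>>> from espnet2.tts.gst.style_encoder import StyleEncoder
>>> style_encoder = StyleEncoder(idim=80)
>>> style_embs = style_encoder(torch.randn(8, 100, 80))
>>> style_embs.shape  # (B, gst_token_dim)
torch.Size([8, 256])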