espnet2.gan_tts.joint.joint_text2wav.JointText2Wav
class espnet2.gan_tts.joint.joint_text2wav.JointText2Wav(idim: int, odim: int, segment_size: int = 32, sampling_rate: int = 22050, text2mel_type: str = 'fastspeech2', text2mel_params: Dict[str, Any] = {'adim': 384, 'aheads': 2, 'conformer_activation_type': 'swish', 'conformer_dec_kernel_size': 31, 'conformer_enc_kernel_size': 7, 'conformer_pos_enc_layer_type': 'rel_pos', 'conformer_rel_pos_type': 'latest', 'conformer_self_attn_layer_type': 'rel_selfattn', 'decoder_concat_after': False, 'decoder_normalize_before': True, 'decoder_type': 'conformer', 'dlayers': 4, 'dunits': 1536, 'duration_predictor_chans': 384, 'duration_predictor_dropout_rate': 0.1, 'duration_predictor_kernel_size': 3, 'duration_predictor_layers': 2, 'elayers': 4, 'encoder_concat_after': False, 'encoder_normalize_before': True, 'encoder_type': 'conformer', 'energy_embed_dropout': 0.5, 'energy_embed_kernel_size': 1, 'energy_predictor_chans': 384, 'energy_predictor_dropout': 0.5, 'energy_predictor_kernel_size': 3, 'energy_predictor_layers': 2, 'eunits': 1536, 'gst_conv_chans_list': [32, 32, 64, 64, 128, 128], 'gst_conv_kernel_size': 3, 'gst_conv_layers': 6, 'gst_conv_stride': 2, 'gst_gru_layers': 1, 'gst_gru_units': 128, 'gst_heads': 4, 'gst_tokens': 10, 'init_dec_alpha': 1.0, 'init_enc_alpha': 1.0, 'init_type': 'xavier_uniform', 'langs': -1, 'pitch_embed_dropout': 0.5, 'pitch_embed_kernel_size': 1, 'pitch_predictor_chans': 384, 'pitch_predictor_dropout': 0.5, 'pitch_predictor_kernel_size': 5, 'pitch_predictor_layers': 5, 'positionwise_conv_kernel_size': 1, 'positionwise_layer_type': 'conv1d', 'postnet_chans': 512, 'postnet_dropout_rate': 0.5, 'postnet_filts': 5, 'postnet_layers': 5, 'reduction_factor': 1, 'spk_embed_dim': None, 'spk_embed_integration_type': 'add', 'spks': -1, 'stop_gradient_from_energy_predictor': False, 'stop_gradient_from_pitch_predictor': True, 'transformer_dec_attn_dropout_rate': 0.1, 'transformer_dec_dropout_rate': 0.1, 'transformer_dec_positional_dropout_rate': 0.1, 
'transformer_enc_attn_dropout_rate': 0.1, 'transformer_enc_dropout_rate': 0.1, 'transformer_enc_positional_dropout_rate': 0.1, 'use_batch_norm': True, 'use_cnn_in_conformer': True, 'use_gst': False, 'use_macaron_style_in_conformer': True, 'use_masking': False, 'use_scaled_pos_enc': True, 'use_weighted_masking': False, 'zero_triu': False}, vocoder_type: str = 'hifigan_generator', vocoder_params: Dict[str, Any] = {'bias': True, 'channels': 512, 'global_channels': -1, 'kernel_size': 7, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'resblock_dilations': [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 'resblock_kernel_sizes': [3, 7, 11], 'upsample_kernel_sizes': [16, 16, 4, 4], 'upsample_scales': [8, 8, 2, 2], 'use_additional_convs': True, 'use_weight_norm': True}, use_pqmf: bool = False, pqmf_params: Dict[str, Any] = {'beta': 9.0, 'cutoff_ratio': 0.142, 'subbands': 4, 'taps': 62}, discriminator_type: str = 'hifigan_multi_scale_multi_period_discriminator', discriminator_params: Dict[str, Any] = {'follow_official_norm': False, 'period_discriminator_params': {'bias': True, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'periods': [2, 3, 5, 7, 11], 'scale_discriminator_params': {'bias': True, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'scale_downsample_pooling': 'AvgPool1d', 'scale_downsample_pooling_params': {'kernel_size': 4, 'padding': 2, 'stride': 2}, 'scales': 1}, generator_adv_loss_params: 
Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, discriminator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, use_feat_match_loss: bool = True, feat_match_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'average_by_layers': False, 'include_final_outputs': True}, use_mel_loss: bool = True, mel_loss_params: Dict[str, Any] = {'fmax': None, 'fmin': 0, 'fs': 22050, 'hop_length': 256, 'log_base': None, 'n_fft': 1024, 'n_mels': 80, 'win_length': None, 'window': 'hann'}, lambda_text2mel: float = 1.0, lambda_adv: float = 1.0, lambda_feat_match: float = 2.0, lambda_mel: float = 45.0, cache_generator_outputs: bool = False)
Bases: AbsGANTTS
Joint text-to-wav module for end-to-end training.
This class serves as a general framework to jointly train the text-to-mel and vocoder components in a GAN-based TTS system. It combines the functionalities of text-to-mel generation and vocoder processing to synthesize high-quality speech waveforms from input text.
segment_size
Segment size for random windowed inputs.
- Type: int
use_pqmf
Whether to use PQMF for multi-band vocoder.
- Type: bool
generator
Contains the text2mel and vocoder models.
- Type: torch.nn.ModuleDict
discriminator
Discriminator model for adversarial training.
- Type: torch.nn.Module
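generator_adv_loss
Adversarial loss function for the generator.
- Type: GeneratorAdversarialLoss
discriminator_adv_loss
Adversarial loss function for the discriminator.
- Type: DiscriminatorAdversarialLoss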
use_feat_match_loss
Whether to use feature matching loss.
- Type: bool
feat_match_loss
Loss function for feature matching.
- Type: FeatureMatchLoss
use_mel_loss
Whether to use mel loss.
- Type: bool
mel_loss
Loss function for mel spectrogram.
- Type: MelSpectrogramLoss
lambda_text2mel
Loss scaling coefficient for text2mel model loss.
- Type: float
lambda_adv
Loss scaling coefficient for adversarial loss.
- Type: float
lambda_feat_match
Loss scaling coefficient for feature matching loss.
- Type: float
lambda_mel
Loss scaling coefficient for mel loss.
- Type: float
cache_generator_outputs
Whether to cache generator outputs.
- Type: bool
_cache
Cached outputs from the generator.
- Type: Any
fs
Sampling rate for saving wav files.
- Type: int
spks
Number of speakers.
- Type: int
langs
Number of languages.
- Type: int
spk_embed_dim
Speaker embedding dimension.
- Type: int
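The four `lambda_*` coefficients scale the individual generator-side loss terms before they are summed. The following is a minimal sketch of that weighting, assuming the terms are combined as a plain weighted sum and that the optional terms are skipped when `use_feat_match_loss` or `use_mel_loss` is off. The function name `combine_generator_loss` and the scalar inputs are illustrative, not part of ESPnet.

```python
def combine_generator_loss(
    text2mel_loss: float,
    adv_loss: float,
    feat_match_loss: float = 0.0,
    mel_loss: float = 0.0,
    lambda_text2mel: float = 1.0,
    lambda_adv: float = 1.0,
    lambda_feat_match: float = 2.0,
    lambda_mel: float = 45.0,
    use_feat_match_loss: bool = True,
    use_mel_loss: bool = True,
) -> float:
    """Weighted sum of generator-side loss terms (illustrative helper)."""
    # The text2mel and adversarial terms are always present.
    total = lambda_text2mel * text2mel_loss + lambda_adv * adv_loss
    # Optional terms contribute only when their flag is enabled.
    if use_feat_match_loss:
        total += lambda_feat_match * feat_match_loss
    if use_mel_loss:
        total += lambda_mel * mel_loss
    return total


# Made-up scalar values, using the default coefficients from the signature.
example_total = combine_generator_loss(
    text2mel_loss=0.5, adv_loss=2.0, feat_match_loss=1.0, mel_loss=0.1
)
example_without_mel = combine_generator_loss(
    text2mel_loss=0.5, adv_loss=2.0, feat_match_loss=1.0, mel_loss=0.1,
    use_mel_loss=False,
)
```

With the defaults, the mel term carries by far the largest weight, so small changes in `lambda_mel` shift the balance between spectral accuracy and the adversarial terms.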
Parameters:
- idim (int) – Input vocabulary size.
- odim (int) – Acoustic feature dimension.
- segment_size (int) – Segment size for random windowed inputs.
- sampling_rate (int) – Sampling rate for saving waveform during inference.
- text2mel_type (str) – The type of text2mel model to use.
- text2mel_params (Dict[str, Any]) – Parameter dictionary for text2mel model.
- vocoder_type (str) – The type of vocoder model to use.
- vocoder_params (Dict[str, Any]) – Parameter dictionary for vocoder model.
- use_pqmf (bool) – Whether to use PQMF for multi-band vocoder.
- pqmf_params (Dict[str, Any]) – Parameter dictionary for PQMF module.
- discriminator_type (str) – Type of discriminator to use.
- discriminator_params (Dict[str, Any]) – Parameter dictionary for discriminator.
- generator_adv_loss_params (Dict[str, Any]) – Parameters for generator adversarial loss.
- discriminator_adv_loss_params (Dict[str, Any]) – Parameters for discriminator adversarial loss.
- use_feat_match_loss (bool) – Whether to use feature matching loss.
- feat_match_loss_params (Dict[str, Any]) – Parameters for feature matching loss.
- use_mel_loss (bool) – Whether to use mel loss.
- mel_loss_params (Dict[str, Any]) – Parameters for mel loss.
- lambda_text2mel (float) – Coefficient for text2mel loss.
- lambda_adv (float) – Coefficient for adversarial loss.
- lambda_feat_match (float) – Coefficient for feature matching loss.
- lambda_mel (float) – Coefficient for mel loss.
- cache_generator_outputs (bool) – Whether to cache generator outputs.
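`segment_size` is counted in feature frames, while the vocoder and discriminator operate on waveform samples. The sketch below shows how a random frame window could map onto a waveform window, assuming the vocoder's upsampling factor is the product of the default `upsample_scales` and that it matches the default `hop_length`. The helper `pick_segment` is hypothetical and only mirrors the idea of "random windowed inputs".

```python
import math
import random

upsample_scales = [8, 8, 2, 2]  # default vocoder_params["upsample_scales"]
hop_length = 256                # default mel_loss_params["hop_length"]
segment_size = 32               # default segment_size, in frames

# One feature frame expands to this many waveform samples.
upsample_factor = math.prod(upsample_scales)


def pick_segment(num_frames, segment_size, upsample_factor, rng):
    """Pick a random frame window and the matching waveform window."""
    if num_frames < segment_size:
        raise ValueError("utterance is shorter than segment_size frames")
    start_frame = rng.randint(0, num_frames - segment_size)
    frame_span = (start_frame, start_frame + segment_size)
    # The same window expressed in waveform samples.
    sample_span = (
        start_frame * upsample_factor,
        (start_frame + segment_size) * upsample_factor,
    )
    return frame_span, sample_span


frame_span, sample_span = pick_segment(
    num_frames=80,
    segment_size=segment_size,
    upsample_factor=upsample_factor,
    rng=random.Random(0),
)
```

Under these assumptions a 32-frame segment covers 32 × 256 = 8192 samples, so utterances shorter than `segment_size` frames cannot supply a full window.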
####### Examples
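>>> import torch
>>> from espnet2.gan_tts.joint.joint_text2wav import JointText2Wav
>>> # Initialize the JointText2Wav model
>>> model = JointText2Wav(idim=256, odim=80)
>>> # Forward pass with text and features
>>> # (40 frames >= segment_size=32; 40 frames * hop_length 256 = 10240 samples)
>>> output = model.forward(
...     text=torch.randint(0, 256, (1, 10)),
...     text_lengths=torch.tensor([10]),
...     feats=torch.randn(1, 40, 80),
...     feats_lengths=torch.tensor([40]),
...     speech=torch.randn(1, 10240),
...     speech_lengths=torch.tensor([10240]),
... )
>>> # Run inference
>>> inference_output = model.inference(text=torch.randint(0, 256, (10,)))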
- Raises: ValueError – If the input dimensions are inconsistent or invalid.
Initialize JointText2Wav module.
- Parameters:
- idim (int) – Input vocabulary size.
- odim (int) – Acoustic feature dimension. The model's actual output has one channel because it is an end-to-end text-to-waveform model; odim is kept for compatibility and indicates the acoustic feature dimension.
- segment_size (int) – Segment size for random windowed inputs.
- sampling_rate (int) – Sampling rate. Not used during training; it is referenced when saving the waveform during inference.
- text2mel_type (str) – The text2mel model type.
- text2mel_params (Dict[str, Any]) – Parameter dict for text2mel model.
- use_pqmf (bool) – Whether to use PQMF for multi-band vocoder.
- pqmf_params (Dict[str, Any]) – Parameter dict for PQMF module.
- vocoder_type (str) – The vocoder model type.
- vocoder_params (Dict[str, Any]) – Parameter dict for vocoder model.
- discriminator_type (str) – Discriminator type.
- discriminator_params (Dict[str, Any]) – Parameter dict for discriminator.
- generator_adv_loss_params (Dict[str, Any]) – Parameter dict for generator adversarial loss.
- discriminator_adv_loss_params (Dict[str, Any]) – Parameter dict for discriminator adversarial loss.
- use_feat_match_loss (bool) – Whether to use feature matching loss.
- feat_match_loss_params (Dict[str, Any]) – Parameter dict for feature matching loss.
- use_mel_loss (bool) – Whether to use mel loss.
- mel_loss_params (Dict[str, Any]) – Parameter dict for mel loss.
- lambda_text2mel (float) – Loss scaling coefficient for text2mel model loss.
- lambda_adv (float) – Loss scaling coefficient for adversarial loss.
- lambda_feat_match (float) – Loss scaling coefficient for feature matching loss.
- lambda_mel (float) – Loss scaling coefficient for mel loss.
- cache_generator_outputs (bool) – Whether to cache generator outputs.
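`cache_generator_outputs` lets the discriminator step reuse what the generator step already computed for the same batch. The sketch below illustrates that pattern only; the `CachedGenerator` class, its `step` method, and the call counter are hypothetical and are not taken from the ESPnet implementation, whose exact cache handling may differ.

```python
from typing import Any, Callable, Optional


class CachedGenerator:
    """Hypothetical wrapper illustrating the caching idea; not ESPnet code."""

    def __init__(
        self,
        generate: Callable[[Any], Any],
        cache_generator_outputs: bool = False,
    ):
        self.generate = generate
        self.cache_generator_outputs = cache_generator_outputs
        self._cache: Optional[Any] = None
        self.num_generator_calls = 0

    def step(self, batch: Any) -> Any:
        """Return generator outputs, reusing a cached result when available."""
        reuse_cache = self.cache_generator_outputs and self._cache is not None
        if reuse_cache:
            outs = self._cache
            self._cache = None  # each cached result is consumed once
        else:
            outs = self.generate(batch)
            self.num_generator_calls += 1
            if self.cache_generator_outputs:
                self._cache = outs
        return outs


def double(batch):
    # Stand-in for an expensive generator forward pass.
    return [2 * x for x in batch]


cached = CachedGenerator(double, cache_generator_outputs=True)
generator_step_outs = cached.step([1, 2, 3])      # computes and stores
discriminator_step_outs = cached.step([1, 2, 3])  # reuses the stored result

uncached = CachedGenerator(double, cache_generator_outputs=False)
uncached.step([1, 2, 3])
uncached.step([1, 2, 3])
```

Reuse only makes sense when the generator and discriminator steps see the same batch back to back, which is why the sketch clears the cache after one reuse.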
forward(text: Tensor, text_lengths: Tensor, feats: Tensor, feats_lengths: Tensor, speech: Tensor, speech_lengths: Tensor, forward_generator: bool = True, **kwargs) → Dict[str, Any]
Perform a generator or discriminator forward pass, depending on forward_generator.
- Parameters:
- text (Tensor) – Text index tensor (B, T_text).
- text_lengths (Tensor) – Text length tensor (B,).
- feats (Tensor) – Feature tensor (B, T_feats, aux_channels).
- feats_lengths (Tensor) – Feature length tensor (B,).
- speech (Tensor) – Speech waveform tensor (B, T_wav).
- speech_lengths (Tensor) – Speech length tensor (B,).
- forward_generator (bool) – Whether to run the generator forward pass (True) or the discriminator forward pass (False).
- Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
- Return type: Dict[str, Any]
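The returned `optim_idx` tells the caller which optimizer the loss belongs to (0 for the generator, 1 for the discriminator). The sketch below shows one way a training loop could use it, assuming the generator step runs before the discriminator step. `StubJointModel` and `train_one_batch` are hypothetical stand-ins that mimic only the structure of the returned dictionary; they do not run any real model.

```python
from typing import Any, Dict, List


class StubJointModel:
    """Hypothetical stand-in that mimics only the return structure of forward()."""

    def forward(self, batch: List[int], forward_generator: bool = True) -> Dict[str, Any]:
        return {
            "loss": 1.0 if forward_generator else 0.5,  # placeholder scalar
            "stats": {"placeholder_loss": 0.0},
            "weight": len(batch),
            "optim_idx": 0 if forward_generator else 1,
        }


def train_one_batch(model, batch, optimizers):
    """Run one generator step and one discriminator step on the same batch."""
    used = []
    for forward_generator in (True, False):
        out = model.forward(batch, forward_generator=forward_generator)
        # optim_idx selects the optimizer that owns this loss.
        optimizer = optimizers[out["optim_idx"]]
        # A real loop would backpropagate out["loss"] and step the optimizer here.
        used.append(optimizer)
    return used


used_optimizers = train_one_batch(
    StubJointModel(),
    batch=[0, 1],
    optimizers=["generator_optimizer", "discriminator_optimizer"],
)
```

In a real setup the list would hold optimizer objects rather than strings, and the step order is a training-configuration choice.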
####### Examples
>>> import torch
>>> model = JointText2Wav(idim=100, odim=80)
>>> # Note: the default fastspeech2 text2mel model may also need duration,
>>> # pitch, and energy targets passed through **kwargs.
>>> text = torch.randint(0, 100, (2, 10)) # Batch of 2, length 10, indices < idim
>>> text_lengths = torch.tensor([10, 9])
>>> feats = torch.randn(2, 80, 80) # Batch of 2, 80 frames, odim=80 features
>>> feats_lengths = torch.tensor([80, 70])
>>> speech = torch.randn(2, 20480) # Batch of 2, 80 frames * hop_length 256 samples
>>> speech_lengths = torch.tensor([20480, 17920])
>>> output = model.forward(text, text_lengths, feats, feats_lengths,
... speech, speech_lengths)
>>> print(output.keys())
dict_keys(['loss', 'stats', 'weight', 'optim_idx'])
inference(text: Tensor, **kwargs) → Dict[str, Tensor]
Run inference.
- Parameters: text (Tensor) – Input text index tensor (T_text,).
- Returns:
- wav (Tensor): Generated waveform tensor (T_wav,).
- feat_gan (Tensor): Generated feature tensor (T_text, C).
- Return type: Dict[str, Tensor]
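The `wav` entry is a one-dimensional waveform, and `sampling_rate` (stored as `fs`) is the rate to use when writing it out. The sketch below writes such a waveform as 16-bit mono PCM using only the standard library, assuming the samples have already been converted to Python floats in [-1, 1] (for a tensor, something like `wav.tolist()`). The helper `write_wav_int16` is hypothetical, and a synthetic sine tone stands in for real model output.

```python
import io
import math
import struct
import wave


def write_wav_int16(samples, fs, target):
    """Write float samples in [-1, 1] as 16-bit mono PCM to a path or file object."""
    pcm = bytearray()
    for s in samples:
        s = max(-1.0, min(1.0, s))  # clip to the valid range
        pcm += struct.pack("<h", int(round(s * 32767)))
    with wave.open(target, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)  # 2 bytes = 16-bit samples
        f.setframerate(fs)
        f.writeframes(bytes(pcm))


fs = 22050  # default sampling_rate in the signature
# Synthetic 0.1 s, 440 Hz tone standing in for the generated waveform.
samples = [0.5 * math.sin(2 * math.pi * 440 * n / fs) for n in range(fs // 10)]

buffer = io.BytesIO()  # a filename such as "out.wav" works the same way
write_wav_int16(samples, fs, buffer)
```

Writing with a rate other than the one the model was trained on changes the playback speed and pitch, which is why the class keeps `fs` alongside the model.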
property require_raw_speech
Return whether or not speech is required.
property require_vocoder
Return whether or not vocoder is required.