espnet2.gan_tts.espnet_model.ESPnetGANTTSModel
class espnet2.gan_tts.espnet_model.ESPnetGANTTSModel(feats_extract: AbsFeatsExtract | None, normalize: InversibleInterface | None, pitch_extract: AbsFeatsExtract | None, pitch_normalize: InversibleInterface | None, energy_extract: AbsFeatsExtract | None, energy_normalize: InversibleInterface | None, tts: AbsGANTTS)
Bases: AbsGANESPnetModel
ESPnet model for GAN-based text-to-speech task.
This class implements a GAN-based model for generating speech from text. It uses various feature extraction and normalization layers to process input data, and it relies on a generator and discriminator for training and inference.
feats_extract
Feature extraction module.
- Type: Optional[AbsFeatsExtract]
normalize
Normalization module for features.
- Type: Optional[AbsNormalize and InversibleInterface]
pitch_extract
Feature extraction module for pitch.
- Type: Optional[AbsFeatsExtract]
pitch_normalize
Normalization module for pitch features.
- Type: Optional[AbsNormalize and InversibleInterface]
energy_extract
Feature extraction module for energy.
- Type: Optional[AbsFeatsExtract]
energy_normalize
Normalization module for energy features.
- Type: Optional[AbsNormalize and InversibleInterface]
tts
Text-to-speech module that includes the generator and discriminator.
- Type: AbsGANTTS
Parameters:
- feats_extract (Optional[AbsFeatsExtract]) – Feature extraction module.
- normalize (Optional[AbsNormalize and InversibleInterface]) – Normalization module for features.
- pitch_extract (Optional[AbsFeatsExtract]) – Feature extraction module for pitch.
- pitch_normalize (Optional[AbsNormalize and InversibleInterface]) – Normalization module for pitch features.
- energy_extract (Optional[AbsFeatsExtract]) – Feature extraction module for energy.
- energy_normalize (Optional[AbsNormalize and InversibleInterface]) – Normalization module for energy features.
- tts (AbsGANTTS) – Text-to-speech module.
Raises: AssertionError – If the generator or discriminator is not properly registered in the TTS module.
######### Examples
>>> model = ESPnetGANTTSModel(feats_extract=some_feats_extract,
... normalize=some_normalize,
... pitch_extract=some_pitch_extract,
... pitch_normalize=some_pitch_normalize,
... energy_extract=some_energy_extract,
... energy_normalize=some_energy_normalize,
... tts=some_tts_module)
>>> output = model.forward(text_tensor, text_lengths_tensor,
... speech_tensor, speech_lengths_tensor)
>>> print(output)
NOTE
Ensure that the tts parameter contains the required attributes generator and discriminator.
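The registration check noted above can be illustrated with a minimal sketch. This is a hypothetical stand-in using plain Python objects, not the real ESPnet classes; it only mirrors the attribute assertions that the constructor performs on the tts module.

```python
class DummyGANTTS:
    """Hypothetical stand-in for an AbsGANTTS implementation."""

    def __init__(self):
        # A real implementation registers torch.nn.Module instances here.
        self.generator = object()
        self.discriminator = object()


def check_gan_tts(tts):
    """Mirror the AssertionError described in the Raises section."""
    assert hasattr(tts, "generator"), "generator module must be registered."
    assert hasattr(tts, "discriminator"), "discriminator module must be registered."
    return True


print(check_gan_tts(DummyGANTTS()))  # True
```

A TTS module missing either attribute fails this check at construction time, which surfaces wiring mistakes before training starts.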
Initialize ESPnetGANTTSModel module.
collect_feats(text: Tensor, text_lengths: Tensor, speech: Tensor, speech_lengths: Tensor, durations: Tensor | None = None, durations_lengths: Tensor | None = None, pitch: Tensor | None = None, pitch_lengths: Tensor | None = None, energy: Tensor | None = None, energy_lengths: Tensor | None = None, spembs: Tensor | None = None, sids: Tensor | None = None, lids: Tensor | None = None, **kwargs) → Dict[str, Tensor]
Calculate features and return them as a dict.
This method extracts various features from the input speech waveform, including pitch and energy, and organizes them into a dictionary format.
- Parameters:
- text (Tensor) – Text index tensor (B, T_text).
- text_lengths (Tensor) – Text length tensor (B,).
- speech (Tensor) – Speech waveform tensor (B, T_wav).
- speech_lengths (Tensor) – Speech length tensor (B,).
- durations (Optional[Tensor]) – Duration tensor.
- durations_lengths (Optional[Tensor]) – Duration length tensor (B,).
- pitch (Optional[Tensor]) – Pitch tensor.
- pitch_lengths (Optional[Tensor]) – Pitch length tensor (B,).
- energy (Optional[Tensor]) – Energy tensor.
- energy_lengths (Optional[Tensor]) – Energy length tensor (B,).
- spembs (Optional[Tensor]) – Speaker embedding tensor (B, D).
- sids (Optional[Tensor]) – Speaker index tensor (B, 1).
- lids (Optional[Tensor]) – Language ID tensor (B, 1).
- Returns: Dictionary containing the extracted features, including:
- feats: Extracted features tensor.
- feats_lengths: Lengths of the extracted features.
- pitch: Extracted pitch tensor.
- pitch_lengths: Lengths of the extracted pitch.
- energy: Extracted energy tensor.
- energy_lengths: Lengths of the extracted energy.
- Return type: Dict[str, Tensor]
######### Examples
>>> model.collect_feats(
... text=text_tensor,
... text_lengths=text_lengths_tensor,
... speech=speech_tensor,
... speech_lengths=speech_lengths_tensor
... )
{'feats': tensor(...), 'feats_lengths': tensor(...),
'pitch': tensor(...), 'pitch_lengths': tensor(...),
'energy': tensor(...), 'energy_lengths': tensor(...)}
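The paired feats/feats_lengths convention in the returned dict can be sketched without ESPnet: batched features are padded to the longest utterance, and each *_lengths entry records the valid frame count per utterance. Plain lists stand in for tensors here; pad_batch is a hypothetical helper, not an ESPnet function.

```python
def pad_batch(feature_seqs, pad_value=0.0):
    """Pad variable-length sequences to a rectangle; return (padded, lengths)."""
    lengths = [len(seq) for seq in feature_seqs]
    max_len = max(lengths)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in feature_seqs]
    return padded, lengths


# Two utterances with 3 and 1 feature frames respectively.
feats, feats_lengths = pad_batch([[0.1, 0.2, 0.3], [0.4]])
print(feats)          # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(feats_lengths)  # [3, 1]
```

Downstream modules use the lengths to mask out the padded frames, so the padding value itself never influences the loss.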
forward(text: Tensor, text_lengths: Tensor, speech: Tensor, speech_lengths: Tensor, durations: Tensor | None = None, durations_lengths: Tensor | None = None, pitch: Tensor | None = None, pitch_lengths: Tensor | None = None, energy: Tensor | None = None, energy_lengths: Tensor | None = None, spembs: Tensor | None = None, sids: Tensor | None = None, lids: Tensor | None = None, forward_generator: bool = True, **kwargs) → Dict[str, Any]
Return generator or discriminator loss with dict format.
This method processes the input text and corresponding features to compute the generator or discriminator loss in a GAN-based text-to-speech system. It can also extract and normalize features if required.
- Parameters:
- text (Tensor) – Text index tensor of shape (B, T_text).
- text_lengths (Tensor) – Text length tensor of shape (B,).
- speech (Tensor) – Speech waveform tensor of shape (B, T_wav).
- speech_lengths (Tensor) – Speech length tensor of shape (B,).
- durations (Optional[Tensor]) – Duration tensor of shape (B, T).
- durations_lengths (Optional[Tensor]) – Duration length tensor of shape (B,).
- pitch (Optional[Tensor]) – Pitch tensor of shape (B, T).
- pitch_lengths (Optional[Tensor]) – Pitch length tensor of shape (B,).
- energy (Optional[Tensor]) – Energy tensor of shape (B, T).
- energy_lengths (Optional[Tensor]) – Energy length tensor of shape (B,).
- spembs (Optional[Tensor]) – Speaker embedding tensor of shape (B, D).
- sids (Optional[Tensor]) – Speaker ID tensor of shape (B, 1).
- lids (Optional[Tensor]) – Language ID tensor of shape (B, 1).
- forward_generator (bool) – Flag to determine if the generator should be used. Default is True.
- kwargs – Additional arguments, where “utt_id” may be included.
- Returns: A dictionary containing:
  - loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
- Return type: Dict[str, Any]
######### Examples
>>> model.forward(text_tensor, text_lengths_tensor, speech_tensor,
... speech_lengths_tensor)
{'loss': tensor(0.1234), 'stats': {'accuracy': 0.98},
'weight': tensor(1.0), 'optim_idx': 0}
NOTE
Ensure that the input tensors have the correct shapes as specified in the arguments section.
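The optim_idx field returned by forward() drives alternating GAN optimization: the trainer steps the generator optimizer when optim_idx is 0 and the discriminator optimizer when it is 1. The following is a hedged sketch with a mock forward function (the real loop lives in ESPnet's GAN trainer, and the loss values here are placeholders).

```python
def mock_forward(forward_generator):
    """Mimic the return-dict contract documented above (placeholder values)."""
    return {
        "loss": 0.5 if forward_generator else 0.7,
        "stats": {},
        "weight": 1.0,
        "optim_idx": 0 if forward_generator else 1,  # 0 for G, 1 for D
    }


# Hypothetical optimizer registry, indexed by optim_idx.
optimizers = ["optim_G", "optim_D"]


def train_step(forward_generator):
    out = mock_forward(forward_generator)
    chosen = optimizers[out["optim_idx"]]  # select which optimizer to step
    return out["loss"], chosen


print(train_step(True))   # (0.5, 'optim_G')
print(train_step(False))  # (0.7, 'optim_D')
```

A trainer typically alternates the forward_generator flag across iterations so that generator and discriminator updates interleave.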