espnet2.gan_tts.abs_gan_tts.AbsGANTTS
class espnet2.gan_tts.abs_gan_tts.AbsGANTTS(*args, **kwargs)
Bases: AbsTTS, ABC
Abstract class for GAN-based Text-to-Speech (TTS) models.
This class serves as a blueprint for implementing GAN-based TTS models. It extends AbsTTS and requires subclasses to implement the forward method, which computes either the generator loss or the discriminator loss, depending on the forward_generator argument.
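Python's `abc` machinery is what enforces the "must implement forward" requirement. The sketch below uses a hypothetical pure-Python stand-in, `GANTTSBase`, instead of the real ESPnet class (which also derives from `torch.nn.Module` through AbsTTS), so it runs without ESPnet or PyTorch. The class names `Incomplete`, `Delegating`, and `Complete` and both helper functions are illustrative only. A subclass that leaves `forward` unimplemented cannot be instantiated, and one that delegates to the abstract body gets `NotImplementedError`.

```python
from abc import ABC, abstractmethod


class GANTTSBase(ABC):
    """Hypothetical stand-in for AbsGANTTS (no torch.nn.Module, no AbsTTS)."""

    @abstractmethod
    def forward(self, forward_generator, *args, **kwargs):
        """Return the generator or discriminator loss."""
        raise NotImplementedError


class Incomplete(GANTTSBase):
    """Does not override forward, so abc refuses to instantiate it."""


class Delegating(GANTTSBase):
    def forward(self, forward_generator, *args, **kwargs):
        # Falls through to the abstract body, which raises NotImplementedError.
        return super().forward(forward_generator, *args, **kwargs)


class Complete(GANTTSBase):
    def forward(self, forward_generator, *args, **kwargs):
        # Dummy float losses stand in for real tensors.
        return {"loss": 1.0 if forward_generator else 2.0}


def can_instantiate(cls):
    """Return True if cls() succeeds, False if abc raises TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True


def forward_raises(model):
    """Return True if model.forward raises NotImplementedError."""
    try:
        model.forward(True)
    except NotImplementedError:
        return True
    return False


print(can_instantiate(Incomplete))  # False
print(can_instantiate(Complete))  # True
print(forward_raises(Delegating()))  # True
```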
- Parameters:
- forward_generator – Flag selecting which loss to compute: the generator loss when true, the discriminator loss otherwise.
- *args – Additional positional arguments (for example, model inputs) passed to forward.
- **kwargs – Additional keyword arguments passed to forward.
- Returns: A dictionary whose values are one of:
- A tensor, such as the generator or discriminator loss.
- A dictionary of tensors, such as loss statistics for logging.
- An integer, such as the index of the optimizer to update.
- Return type: Dict[str, Tensor | Dict[str, Tensor] | int]
- Raises: NotImplementedError – If the forward method is not implemented in the subclass.
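The three kinds of values in the returned dictionary can be illustrated with a toy model and a trainer step that alternates a generator turn and a discriminator turn. This is a hypothetical, dependency-free sketch: `ToyGANTTS` and `run_turns` are illustrative names, floats stand in for tensors, and the key names `loss`, `stats`, and `optim_idx` are assumptions chosen for illustration rather than a statement of what every ESPnet model returns.

```python
class ToyGANTTS:
    """Hypothetical model following the forward(forward_generator, ...) contract."""

    def forward(self, forward_generator, *args, **kwargs):
        # Floats stand in for torch.Tensor losses.
        loss = 0.5 if forward_generator else 0.25
        turn = "generator" if forward_generator else "discriminator"
        return {
            "loss": loss,  # tensor-like value to backpropagate
            "stats": {f"{turn}_loss": loss},  # nested dict for logging
            "optim_idx": 0 if forward_generator else 1,  # optimizer to step
        }


def run_turns(model, batch):
    """Hypothetical trainer step: a generator turn, then a discriminator turn."""
    log = []
    for forward_generator in (True, False):
        out = model.forward(forward_generator, **batch)
        log.append((out["optim_idx"], out["loss"], out["stats"]))
    return log


for optim_idx, loss, stats in run_turns(ToyGANTTS(), {"text": [1, 2, 3]}):
    print(optim_idx, loss, stats)
# 0 0.5 {'generator_loss': 0.5}
# 1 0.25 {'discriminator_loss': 0.25}
```

In this sketch, carrying the optimizer index in the returned dictionary lets one generic loop route each loss to the matching optimizer without knowing the model's internals.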
####### Examples
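>>> import torch
>>> from espnet2.gan_tts.abs_gan_tts import AbsGANTTS
>>> class MyGANTTS(AbsGANTTS):
...     def forward(self, forward_generator, *args, **kwargs):
...         # Dummy loss; a real model computes the G or D loss here.
...         loss = torch.tensor(0.0)
...         return {"loss": loss, "stats": {"loss": loss}, "optim_idx": 0 if forward_generator else 1}
...     def inference(self, text, **kwargs):
...         return {"wav": torch.zeros(1)}  # inference() is abstract in AbsTTS
>>> gan_tts = MyGANTTS()
>>> out = gan_tts(forward_generator=True)
>>> out["loss"]
tensor(0.)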
NOTE
Subclasses must implement the forward method to define the specific behavior of the GAN-based TTS model, as well as the inference method declared as abstract in AbsTTS.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
abstract forward(forward_generator, *args, **kwargs) → Dict[str, Tensor | Dict[str, Tensor] | int]
Returns the generator or discriminator loss for the GAN-based TTS model.
Every subclass of AbsGANTTS must implement this abstract method. The implementation defines how the generator or discriminator loss, as selected by forward_generator, is computed from the inputs passed through *args and **kwargs.
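As one concrete possibility, the sketch below assumes a least-squares GAN (LSGAN) objective without the optional 1/2 factors: the discriminator pushes its scores for real samples toward 1 and for generated samples toward 0, while the generator pushes the scores of its own samples toward 1. This is an illustrative choice, not necessarily what a given ESPnet model uses; discriminator scores are plain floats, and `compute_loss` is a hypothetical function standing in for the body of `forward`.

```python
def mean(values):
    """Arithmetic mean of a non-empty sequence."""
    return sum(values) / len(values)


def lsgan_generator_loss(fake_scores):
    # The generator wants the discriminator to score fakes as real (1).
    return mean([(s - 1.0) ** 2 for s in fake_scores])


def lsgan_discriminator_loss(real_scores, fake_scores):
    # The discriminator wants real scores -> 1 and fake scores -> 0.
    real_term = mean([(s - 1.0) ** 2 for s in real_scores])
    fake_term = mean([s ** 2 for s in fake_scores])
    return real_term + fake_term


def compute_loss(forward_generator, real_scores, fake_scores):
    """Hypothetical forward body: pick the loss according to the flag."""
    if forward_generator:
        return lsgan_generator_loss(fake_scores)
    return lsgan_discriminator_loss(real_scores, fake_scores)


real, fake = [1.0, 0.5], [0.0, 0.5]
print(compute_loss(True, real, fake))  # 0.625
print(compute_loss(False, real, fake))  # 0.25
```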
- Parameters:
- forward_generator – Flag selecting which loss to compute: the generator loss when true, the discriminator loss otherwise.
- *args – Variable-length positional arguments carrying the model inputs.
- **kwargs – Arbitrary keyword arguments carrying the model inputs.
- Returns: A dictionary whose values are one of:
- A torch.Tensor, such as the loss value.
- A dictionary of torch.Tensor objects, such as statistics for logging.
- An integer, such as the index of the optimizer to update.
- Return type: Dict[str, Tensor | Dict[str, Tensor] | int]
- Raises: NotImplementedError – If the method is not implemented in a subclass.
####### Examples
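Here is an example of how a subclass can implement the forward method. The loss helpers are hypothetical placeholders for the model's real adversarial losses, and a complete subclass must also implement inference() from AbsTTS before it can be instantiated:
```python
import torch

from espnet2.gan_tts.abs_gan_tts import AbsGANTTS


class MyGANTTS(AbsGANTTS):
    def forward(self, forward_generator, *args, **kwargs):
        # Compute the generator loss if forward_generator is true,
        # otherwise the discriminator loss.
        if forward_generator:
            loss = self.generator_loss(*args, **kwargs)
        else:
            loss = self.discriminator_loss(*args, **kwargs)
        return {"loss": loss, "stats": {"loss": loss.detach()},
                "optim_idx": 0 if forward_generator else 1}

    def generator_loss(self, *args, **kwargs):  # hypothetical helper
        return torch.tensor(0.0)

    def discriminator_loss(self, *args, **kwargs):  # hypothetical helper
        return torch.tensor(0.0)
```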
NOTE
Subclasses must provide an implementation of this method to function correctly as a GAN-based TTS model.