espnet2.uasr.discriminator.abs_discriminator.AbsDiscriminator
class espnet2.uasr.discriminator.abs_discriminator.AbsDiscriminator(*args, **kwargs)
Bases: Module, ABC
Abstract base class for implementing discriminators in the ESPnet2 framework.
This class defines the interface for all discriminator implementations. It inherits from torch.nn.Module and requires subclasses to implement the forward method. The forward method is responsible for processing input tensors and producing output tensors, typically used in adversarial training settings.
- Parameters:
- xs_pad (torch.Tensor) – A padded input tensor containing the features.
- padding_mask (torch.Tensor) – A tensor indicating the padding positions in xs_pad.
- Returns: The output tensor produced by the discriminator after processing the input features.
- Return type: torch.Tensor
- Raises: NotImplementedError – If the forward method is not implemented in a subclass.
####### Examples
To create a custom discriminator, inherit from this class and implement the forward method as follows:
```python
class MyDiscriminator(AbsDiscriminator):
    def forward(self, xs_pad, padding_mask):
        # Custom processing logic here
        return output_tensor
```
NOTE
This class is intended to be subclassed, and should not be instantiated directly.
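The following is a minimal sketch of why direct instantiation fails and what a concrete subclass might look like. To keep it runnable without ESPnet installed, it uses `AbsDiscriminatorSketch`, a hypothetical stand-in that only mirrors the documented interface (a `torch.nn.Module` combined with `ABC` and an abstract `forward`). `LinearDiscriminator`, its `feature_dim` and `output_dim` arguments, and the mean-over-time pooling are illustrative assumptions, not part of ESPnet2.

```python
from abc import ABC, abstractmethod

import torch


class AbsDiscriminatorSketch(torch.nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented AbsDiscriminator interface."""

    @abstractmethod
    def forward(self, xs_pad: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class LinearDiscriminator(AbsDiscriminatorSketch):
    """Toy subclass (assumption): average over time, then project to a score."""

    def __init__(self, feature_dim: int, output_dim: int = 1):
        super().__init__()
        self.proj = torch.nn.Linear(feature_dim, output_dim)

    def forward(self, xs_pad: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
        # padding_mask is accepted to satisfy the interface but ignored in this toy.
        return self.proj(xs_pad.mean(dim=1))  # (batch_size, output_dim)


# A class with an unimplemented abstract method cannot be instantiated.
try:
    AbsDiscriminatorSketch()
    instantiated = True
except TypeError:
    instantiated = False

# The concrete subclass can be constructed and called as usual.
disc = LinearDiscriminator(feature_dim=64)
out = disc(torch.rand(4, 10, 64), torch.zeros(4, 10, dtype=torch.bool))
```

With Python's `abc` machinery, the failure surfaces as a `TypeError` at construction time, before `forward` is ever called.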
Initialize internal Module state, shared by both nn.Module and ScriptModule.
abstract forward(xs_pad: Tensor, padding_mask: Tensor) → Tensor
Computes the forward pass of the discriminator model.
This method takes padded input sequences and their corresponding padding masks to produce a tensor output. It must be implemented by any subclass of AbsDiscriminator.
xs_pad
Padded input sequences for the discriminator.
- Type: torch.Tensor
padding_mask
Mask indicating the positions of padding in the input sequences.
- Type: torch.Tensor
- Parameters:
- xs_pad (torch.Tensor) – A tensor of shape (batch_size, seq_length, feature_dim) containing the padded input data.
- padding_mask (torch.Tensor) – A tensor of shape (batch_size, seq_length) indicating which elements of xs_pad are valid.
- Returns: A tensor containing the output from the discriminator, with shape (batch_size, output_dim).
- Return type: torch.Tensor
- Raises: NotImplementedError – If the method is not implemented in a subclass.
####### Examples
>>> import torch
>>> discriminator = MyDiscriminator() # MyDiscriminator should extend AbsDiscriminator
>>> xs_pad = torch.rand(32, 10, 64) # Example input (32 samples, seq_len=10, features=64)
>>> padding_mask = torch.ones(32, 10) # Example mask (no padding)
>>> output = discriminator(xs_pad, padding_mask)
>>> print(output.shape) # Output shape will depend on the implementation
NOTE
This method must be overridden in any subclass of AbsDiscriminator.
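As an illustration of how a `forward` implementation might consume `padding_mask`, the sketch below builds a mask from sequence lengths and averages only the unpadded frames. The helper names `make_padding_mask` and `masked_mean` are hypothetical, and the convention that `True` marks a padded position is an assumption; invert the mask if a given implementation treats `True` as valid.

```python
import torch


def make_padding_mask(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (batch_size, max_len) bool mask; True marks padding (assumed convention)."""
    positions = torch.arange(max_len).unsqueeze(0)  # (1, max_len)
    return positions >= lengths.unsqueeze(1)        # (batch_size, max_len)


def masked_mean(xs_pad: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    """Average over the time axis, skipping padded frames."""
    valid = (~padding_mask).unsqueeze(-1).to(xs_pad.dtype)  # (batch_size, seq_length, 1)
    summed = (xs_pad * valid).sum(dim=1)                    # (batch_size, feature_dim)
    counts = valid.sum(dim=1).clamp(min=1.0)                # guard against empty sequences
    return summed / counts


lengths = torch.tensor([10, 7, 3])   # number of valid frames per sequence
xs_pad = torch.rand(3, 10, 64)       # (batch_size, seq_length, feature_dim)
padding_mask = make_padding_mask(lengths, max_len=10)
pooled = masked_mean(xs_pad, padding_mask)
```

Pooling this way keeps padded frames from diluting the per-sequence summary, which matters when sequence lengths within a batch differ widely.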