espnet2.uasr.espnet_model.ESPnetUASRModel
class espnet2.uasr.espnet_model.ESPnetUASRModel(frontend: AbsFrontend | None, segmenter: AbsSegmenter | None, generator: AbsGenerator, discriminator: AbsDiscriminator, losses: Dict[str, AbsUASRLoss], kenlm_path: str | None, token_list: list | None, max_epoch: int | None, vocab_size: int, cfg: Dict | None = None, pad: int = 1, sil_token: str = '<SIL>', sos_token: str = '<s>', eos_token: str = '</s>', skip_softmax: str2bool = False, use_gumbel: str2bool = False, use_hard_gumbel: str2bool = True, min_temperature: float = 0.1, max_temperature: float = 2.0, decay_temperature: float = 0.99995, use_collected_training_feats: str2bool = False)
Bases: AbsESPnetModel
Unsupervised ASR model.
This model is designed for unsupervised automatic speech recognition (ASR) tasks. The implementation is based on the work from FAIRSEQ: https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec/unsupervised
frontend
The frontend module for feature extraction.
- Type: Optional[AbsFrontend]
segmenter
The segmenter module for processing features.
- Type: Optional[AbsSegmenter]
generator
The generator module for producing output.
- Type: AbsGenerator
discriminator
The discriminator module for classification.
- Type: AbsDiscriminator
losses
A dictionary containing various loss functions.
- Type: Dict[str, AbsUASRLoss]
kenlm_path
Path to the KenLM language model.
- Type: Optional[str]
token_list
List of tokens for text representation.
- Type: Optional[list]
max_epoch
Maximum number of training epochs.
- Type: Optional[int]
vocab_size
Size of the vocabulary.
- Type: int
cfg
Configuration options.
- Type: Optional[Dict], optional
pad
Padding index.
- Type: int
sil_token
Silence token representation.
- Type: str
sos_token
Start-of-sequence token representation.
- Type: str
eos_token
End-of-sequence token representation.
- Type: str
skip_softmax
Whether to skip softmax in generation.
- Type: str2bool
use_gumbel
Whether to use Gumbel softmax.
- Type: str2bool
use_hard_gumbel
Whether to use hard Gumbel softmax.
- Type: str2bool
min_temperature
Minimum temperature for Gumbel softmax.
- Type: float
max_temperature
Maximum temperature for Gumbel softmax.
- Type: float
decay_temperature
Decay factor for temperature.
- Type: float
use_collected_training_feats
Whether to use collected features.
- Type: str2bool
Parameters:
- frontend (Optional[AbsFrontend]) – The frontend module for feature extraction.
- segmenter (Optional[AbsSegmenter]) – The segmenter module for processing features.
- generator (AbsGenerator) – The generator module for producing output.
- discriminator (AbsDiscriminator) – The discriminator module for classification.
- losses (Dict[str, AbsUASRLoss]) – A dictionary containing various loss functions.
- kenlm_path (Optional[str]) – Path to the KenLM language model.
- token_list (Optional[list]) – List of tokens for text representation.
- max_epoch (Optional[int]) – Maximum number of training epochs.
- vocab_size (int) – Size of the vocabulary.
- cfg (Optional[Dict], optional) – Configuration options.
- pad (int) – Padding index (default: 1).
- sil_token (str) – Silence token representation (default: “<SIL>”).
- sos_token (str) – Start-of-sequence token representation (default: “<s>”).
- eos_token (str) – End-of-sequence token representation (default: “</s>”).
- skip_softmax (str2bool) – Whether to skip softmax in generation (default: False).
- use_gumbel (str2bool) – Whether to use Gumbel softmax (default: False).
- use_hard_gumbel (str2bool) – Whether to use hard Gumbel softmax (default: True).
- min_temperature (float) – Minimum temperature for Gumbel softmax (default: 0.1).
- max_temperature (float) – Maximum temperature for Gumbel softmax (default: 2.0).
- decay_temperature (float) – Decay factor for temperature (default: 0.99995).
- use_collected_training_feats (str2bool) – Whether to use collected features (default: False).
Raises: AssertionError – If KenLM is not installed or if invalid parameters are provided.
################# Examples
>>> model = ESPnetUASRModel(
...     frontend=my_frontend,
...     segmenter=my_segmenter,
...     generator=my_generator,
...     discriminator=my_discriminator,
...     losses=my_losses,
...     kenlm_path='path/to/kenlm',
...     token_list=my_token_list,
...     max_epoch=50,
...     vocab_size=1000,
...     pad=1,
...     sil_token='<SIL>',
...     sos_token='<s>',
...     eos_token='</s>',
...     skip_softmax=False,
...     use_gumbel=True,
...     use_hard_gumbel=True,
...     min_temperature=0.1,
...     max_temperature=2.0,
...     decay_temperature=0.99995,
...     use_collected_training_feats=False,
... )
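The Gumbel-related parameters (use_gumbel, use_hard_gumbel, min_temperature, max_temperature, decay_temperature) can be illustrated with a small standalone sketch. It assumes the temperature decays exponentially with the number of updates and is clamped at min_temperature, as in the fairseq wav2vec-U recipe this model is based on; whether ESPnetUASRModel applies exactly this schedule is not confirmed by this page. The function names `current_temperature` and `gumbel_softmax` are hypothetical helpers written in plain Python, not ESPnet APIs, and the sketch covers only forward sampling (no straight-through gradient).

```python
import math
import random


def current_temperature(num_updates, max_temperature=2.0,
                        min_temperature=0.1, decay_temperature=0.99995):
    """Assumed schedule: exponential decay, clamped below at min_temperature."""
    return max(max_temperature * decay_temperature ** num_updates,
               min_temperature)


def gumbel_softmax(logits, temperature, hard=True, rng=random):
    """Draw one Gumbel-softmax sample from a list of logits (forward pass only)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise is -log(-log(U)) with U ~ Uniform(0, 1);
        # clamp U away from 0 and 1 to keep the logarithms finite.
        u = min(max(rng.random(), 1e-10), 1.0 - 1e-10)
        noisy.append((logit - math.log(-math.log(u))) / temperature)

    # Numerically stable softmax over the perturbed, temperature-scaled logits.
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    soft = [e / total for e in exps]
    if not hard:
        return soft

    # "Hard" variant: replace the soft distribution with a one-hot vector.
    best = soft.index(max(soft))
    return [1.0 if i == best else 0.0 for i in range(len(soft))]


if __name__ == "__main__":
    for step in (0, 10_000, 100_000):
        print(step, round(current_temperature(step), 4))
    print(gumbel_softmax([0.1, 0.2, 0.3], current_temperature(0), hard=True))
```

A lower temperature makes the soft sample closer to one-hot, which is why a decaying schedule gradually sharpens the generator output over training.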
Initialize internal Module state, shared by both nn.Module and ScriptModule.
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, **kwargs) → Dict[str, Tensor]
Collects features from the input speech tensor.
This method processes the input speech tensor through a frontend if available, applying necessary transformations to extract features. If no frontend is defined, the original speech tensor is returned as features.
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (Batch, NSamples).
- speech_lengths (torch.Tensor) – Lengths of the input speech tensor of shape (Batch,).
- text (Optional[torch.Tensor], optional) – Input text tensor of shape (Batch, NText). Defaults to None.
- text_lengths (Optional[torch.Tensor], optional) – Lengths of the input text tensor of shape (Batch,). Defaults to None.
- **kwargs – Additional keyword arguments.
- Returns: A dictionary containing:
- 'feats': Extracted features tensor of shape (Batch, NFrames, Dim).
- 'feats_lengths': Lengths of the extracted features tensor of shape (Batch,).
- Return type: Dict[str, torch.Tensor]
################# Examples
>>> import torch
>>> model = ESPnetUASRModel(...)
>>> speech = torch.randn(8, 16000)  # Batch of 8 utterances
>>> speech_lengths = torch.tensor([16000] * 8)  # Every utterance has 16000 samples
>>> features = model.collect_feats(speech, speech_lengths)
>>> print(features['feats'].shape)  # Expected shape: (8, NFrames, Dim)
####### NOTE The frontend processing may include operations such as STFT or other feature extraction methods.
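To make the relation between NSamples and NFrames concrete, here is a minimal sketch of the standard frame-count arithmetic for an STFT-style frontend. The helper name `num_frames` and the values n_fft=512 and hop_length=160 are illustrative assumptions, not the defaults of any particular ESPnet frontend; a frontend that loads pre-extracted features would not follow this formula at all.

```python
def num_frames(num_samples, n_fft=512, hop_length=160, center=True):
    """Number of STFT frames for a signal of num_samples samples.

    With center=True the signal is padded by n_fft // 2 on both sides
    before framing, which is the usual "centered" STFT convention.
    """
    if center:
        num_samples += 2 * (n_fft // 2)
    if num_samples < n_fft:
        return 0  # Too short to fit a single analysis window.
    return 1 + (num_samples - n_fft) // hop_length


if __name__ == "__main__":
    print(num_frames(16000, center=True))
    print(num_frames(16000, center=False))
```

Under these assumed settings, a 16000-sample utterance yields 101 frames with centering and 97 without, so 'feats_lengths' is generally much smaller than speech_lengths.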
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Encode the input speech tensor into features and create a padding mask.
This method extracts features from the input speech tensor and applies a segmentation process if a segmenter is provided. It returns the extracted features along with a padding mask that indicates the valid elements in the feature tensor.
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (batch_size, num_samples).
- speech_lengths (torch.Tensor) – Tensor containing the lengths of each speech sample in the batch.
- Returns: A tuple containing:
- feats (torch.Tensor): Extracted feature tensor of shape (batch_size, num_frames, feature_dim).
- padding_mask (torch.Tensor): Boolean tensor of shape (batch_size, num_frames) indicating valid frames.
- Return type: Tuple[torch.Tensor, torch.Tensor]
################# Examples
>>> import torch
>>> speech = torch.randn(8, 16000)  # 8 utterances, 1 second each at 16 kHz
>>> speech_lengths = torch.tensor([16000] * 8)  # All utterances are 1 s long
>>> model = ESPnetUASRModel(...)
>>> feats, padding_mask = model.encode(speech, speech_lengths)
####### NOTE The input speech tensor is expected to be of shape (batch_size, num_samples) where num_samples can vary. The lengths tensor should be a 1D tensor containing the actual lengths of each sample in the batch.
- Raises: AssertionError – If speech_lengths does not have the expected dimensions.
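The relation between per-utterance frame lengths and a padding mask can be sketched with plain Python lists in place of tensors. The helper name `pad_mask_from_lengths` is hypothetical, and the polarity chosen here (True marks padded positions) is an assumption; check the implementation to confirm which polarity `encode` actually returns, and invert the mask if valid frames should be True.

```python
def pad_mask_from_lengths(lengths, max_len=None):
    """Build a (batch, max_len) boolean mask from per-utterance lengths.

    Assumed convention: True marks a padded position, False a valid frame.
    """
    if max_len is None:
        max_len = max(lengths)
    return [[t >= n for t in range(max_len)] for n in lengths]


def invert(mask):
    """Flip the convention so that True marks valid frames instead."""
    return [[not flag for flag in row] for row in mask]


if __name__ == "__main__":
    mask = pad_mask_from_lengths([3, 1])
    print(mask)
    print(invert(mask))
```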
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor | None = None, text_lengths: Tensor | None = None, pseudo_labels: Tensor | None = None, pseudo_labels_lengths: Tensor | None = None, do_validation: str2bool | None = False, print_hyp: str2bool | None = False, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Processes input speech data through the model components.
The forward method performs the following operations:
- Extracts features from the input speech.
- Generates fake samples using the generator.
- Optionally applies segmentation to the generated samples.
- Calculates losses based on discriminator predictions.
- If validation is enabled, computes validation statistics.
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (batch_size, sequence_length).
- speech_lengths (torch.Tensor) – Lengths of the input speech sequences of shape (batch_size,).
- text (Optional[torch.Tensor]) – Ground truth text tensor of shape (batch_size, max_text_length). Default is None.
- text_lengths (Optional[torch.Tensor]) – Lengths of the text sequences of shape (batch_size,). Default is None.
- pseudo_labels (Optional[torch.Tensor]) – Pseudo labels for training of shape (batch_size, max_label_length). Default is None.
- pseudo_labels_lengths (Optional[torch.Tensor]) – Lengths of pseudo labels of shape (batch_size,). Default is None.
- do_validation (Optional[str2bool]) – Whether to perform validation during training. Default is False.
- print_hyp (Optional[str2bool]) – Whether to print hypotheses during validation. Default is False.
- **kwargs – Additional keyword arguments.
- Returns: A tuple containing:
- loss (torch.Tensor): The computed loss for the batch.
- stats (Dict[str, torch.Tensor]): A dictionary containing various statistics from the forward pass.
- weight (torch.Tensor): The weight for the current batch.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the input dimensions do not match expected shapes.
################# Examples
>>> import torch
>>> model = ESPnetUASRModel(...)
>>> speech = torch.randn(32, 16000)  # 32 utterances, 1 second each at 16 kHz
>>> speech_lengths = torch.tensor([16000] * 32)  # All utterances are 1 s long
>>> text = torch.randint(0, 100, (32, 20))  # Random token IDs
>>> text_lengths = torch.tensor([20] * 32)  # All text lengths are 20
>>> loss, stats, weight = model.forward(speech, speech_lengths, text,
...                                     text_lengths)
get_optim_index()
Get the optimization index based on the number of updates.
This method calculates the optimization index used for alternating updates in the training process. It returns 0 when the number of updates is even and 1 when it is odd. The index can be used to select which optimizer is stepped on a given update.
- Returns: The optimization index (0 or 1) determined by the current number of updates.
- Return type: int
################# Examples
>>> model = ESPnetUASRModel(...)
>>> model.number_updates = 1
>>> model.get_optim_index()
1
>>> model.number_updates = 2
>>> model.get_optim_index()
0
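The parity rule described above can be sketched outside the model to show how updates alternate. The functions below are standalone stand-ins that mirror the documented behaviour of get_optim_index and is_discriminative_step; the role labels ("generator" on even updates, "discriminator" on odd ones) follow from the documented return values, but the mapping from index to a specific optimizer object in a trainer is an assumption.

```python
def optim_index(number_updates):
    """Stand-in for get_optim_index: 0 on even updates, 1 on odd updates."""
    return number_updates % 2


def is_discriminative_step(number_updates):
    """Stand-in for is_discriminative_step: True on odd updates."""
    return number_updates % 2 == 1


def build_schedule(num_steps):
    """List (update, optimizer index, role) for the first num_steps updates."""
    schedule = []
    for number_updates in range(num_steps):
        role = ("discriminator" if is_discriminative_step(number_updates)
                else "generator")
        schedule.append((number_updates, optim_index(number_updates), role))
    return schedule


if __name__ == "__main__":
    for entry in build_schedule(4):
        print(entry)
```

In this sketch the optimizer index equals 1 exactly when the step is discriminative, so the two methods always agree on which side of the GAN is being updated.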
inference(speech: Tensor, speech_lengths: Tensor)
Run inference on the given speech input to generate samples.
This method extracts features from the input speech and uses the generator to create fake samples based on those features.
- Parameters:
- speech (torch.Tensor) – A tensor containing the input speech signal.
- speech_lengths (torch.Tensor) – A tensor containing the lengths of the input speech signals.
- Returns: A tuple containing:
- generated_sample (torch.Tensor): The generated samples.
- generated_sample_padding_mask (torch.Tensor): The padding mask for the generated samples.
- Return type: Tuple[torch.Tensor, torch.Tensor]
################# Examples
>>> import torch
>>> model = ESPnetUASRModel(...)  # Initialize the model
>>> speech = torch.randn(2, 16000)  # Example speech input (batch size 2)
>>> speech_lengths = torch.tensor([16000, 15000])  # Lengths of the inputs
>>> generated_sample, padding_mask = model.inference(speech, speech_lengths)
>>> print(generated_sample.shape)  # Output shape of generated samples
>>> print(padding_mask.shape)  # Output shape of padding mask
####### NOTE Ensure that the model has been properly initialized and trained before calling this method for inference.
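One common way to turn per-frame generator outputs into a token sequence is greedy decoding: take the best-scoring token per frame, merge consecutive repeats, and drop the silence token. The sketch below illustrates that idea on plain Python lists. It is an assumption about typical post-processing, not a description of what ESPnet does after `inference`; the helper name `greedy_decode`, the toy token list, and the convention that True in the mask marks padding are all hypothetical.

```python
from itertools import groupby


def greedy_decode(frame_scores, frame_is_padding, token_list, sil_token="<SIL>"):
    """Greedy per-frame decoding for a single utterance.

    frame_scores: list of per-frame score lists, one score per token.
    frame_is_padding: list of booleans, True where the frame is padding.
    """
    # Best token ID for every non-padded frame.
    ids = [
        max(range(len(scores)), key=scores.__getitem__)
        for scores, is_pad in zip(frame_scores, frame_is_padding)
        if not is_pad
    ]
    # Merge runs of identical consecutive IDs into one.
    collapsed = [token_id for token_id, _ in groupby(ids)]
    # Map IDs to tokens and drop silence.
    return [token_list[i] for i in collapsed if token_list[i] != sil_token]


if __name__ == "__main__":
    tokens = ["<SIL>", "a", "b"]
    scores = [
        [0.90, 0.05, 0.05],
        [0.10, 0.80, 0.10],
        [0.10, 0.70, 0.20],
        [0.10, 0.10, 0.80],
        [0.90, 0.05, 0.05],
        [0.20, 0.20, 0.60],  # padded frame, ignored
    ]
    padding = [False, False, False, False, False, True]
    print(greedy_decode(scores, padding, tokens))
```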
is_discriminative_step()
Determines whether the current update step is a discriminative step.
This method checks the number of updates and returns True if the number of updates is odd, indicating that the current training iteration is a discriminative step, and False otherwise.
- Returns: True if the current update is a discriminative step, otherwise False.
- Return type: bool
################# Examples
>>> model = ESPnetUASRModel(...)  # Initialize the model with necessary args
>>> model.number_updates = 1
>>> model.is_discriminative_step()
True
>>> model.number_updates = 2
>>> model.is_discriminative_step()
False
property number_updates
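The examples above assign to number_updates, which implies the property has a setter. A minimal sketch of that pattern follows, using a hypothetical standalone class `UpdateCounter`; the validation shown (rejecting negative or non-integer values with ValueError) is an assumption for illustration, and the real model may validate differently.

```python
class UpdateCounter:
    """Hypothetical holder for an update counter exposed as a property."""

    def __init__(self):
        self._number_updates = 0

    @property
    def number_updates(self):
        """Number of optimizer updates performed so far."""
        return self._number_updates

    @number_updates.setter
    def number_updates(self, value):
        # Assumed validation: only non-negative integers are accepted.
        if not isinstance(value, int) or value < 0:
            raise ValueError("number_updates must be a non-negative integer")
        self._number_updates = value


if __name__ == "__main__":
    counter = UpdateCounter()
    counter.number_updates = 3
    print(counter.number_updates)
```

Exposing the counter as a property lets a trainer set it once per iteration while methods that depend on step parity read a validated value.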