espnet2.enh.espnet_enh_s2t_model.ESPnetEnhS2TModel
class espnet2.enh.espnet_enh_s2t_model.ESPnetEnhS2TModel(enh_model: ESPnetEnhancementModel, s2t_model: ESPnetASRModel | ESPnetSTModel | ESPnetDiarizationModel, calc_enh_loss: bool = True, bypass_enh_prob: float = 0)
Bases: AbsESPnetModel
Joint model for Enhancement and Speech to Text (S2T).
This class combines an enhancement model and a speech-to-text model, allowing for joint training and inference for tasks involving speech enhancement and transcription. It can handle multiple types of speech-to-text models, including automatic speech recognition (ASR), speech translation (ST), and speaker diarization.
enh_model
The enhancement model used for speech enhancement.
- Type: ESPnetEnhancementModel
s2t_model
The speech-to-text model used for transcribing enhanced speech.
- Type: Union[ESPnetASRModel, ESPnetSTModel, ESPnetDiarizationModel]
bypass_enh_prob
Probability of bypassing the enhancement model during training.
- Type: float
calc_enh_loss
Flag indicating whether to calculate enhancement loss.
- Type: bool
extract_feats_in_collect_stats
Flag to determine if features should be extracted during statistics collection.
- Type: bool
Parameters:
- enh_model (ESPnetEnhancementModel) – The enhancement model.
- s2t_model (Union[ESPnetASRModel, ESPnetSTModel, ESPnetDiarizationModel]) – The speech-to-text model (ASR, ST, or DIAR).
- calc_enh_loss (bool) – Whether to calculate enhancement loss. Default is True.
- bypass_enh_prob (float) – Probability to bypass enhancement. Default is 0.
Raises: NotImplementedError – If the provided speech-to-text model type is not supported.
Examples
>>> enh_model = ESPnetEnhancementModel(...)
>>> s2t_model = ESPnetASRModel(...)
>>> model = ESPnetEnhS2TModel(enh_model, s2t_model)
>>> speech = torch.randn(2, 16000) # (Batch, Length)
>>> lengths = torch.tensor([16000, 16000])
>>> loss, stats, weight = model(speech, speech_lengths=lengths)
NOTE: The model’s forward method performs both enhancement and transcription, calculating the necessary losses based on the specified configurations.
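To make the bypass_enh_prob parameter concrete, here is a minimal sketch of how such a probability can gate enhancement during training. The helper name maybe_enhance and the enhance_fn callable are hypothetical, for illustration only; the actual gating lives inside the model's forward:

import random

import torch

def maybe_enhance(speech: torch.Tensor, enhance_fn, bypass_enh_prob: float = 0.0,
                  training: bool = True) -> torch.Tensor:
    # Hypothetical gate: with probability bypass_enh_prob, skip enhancement
    # and feed the raw mixture directly to the S2T model for this batch.
    if training and random.random() < bypass_enh_prob:
        return speech
    return enhance_fn(speech)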
Initialize internal Module state, shared by both nn.Module and ScriptModule.
asr_pit_loss(speech, speech_lengths, text, text_lengths)
Calculate the permutation-invariant training (PIT) loss for ASR.
This method computes the loss for automatic speech recognition (ASR) using the permutation-invariant training approach. It determines the optimal alignment between the reference and hypothesis sequences based on the calculated CTC (Connectionist Temporal Classification) loss. The function also sorts the speech input according to the optimal permutation.
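As a rough, illustrative sketch of the idea (not the ESPnet implementation): assume a ctc_fn callable that scores one (hypothesis, reference) pair and returns a per-utterance loss of shape (Batch,). The pairwise matrix consumed by permutation_invariant_training can then be assembled as follows:

import torch

def pairwise_ctc_matrix(ctc_fn, hyp_encodings, enc_lens, refs, ref_lens):
    # hyp_encodings: list of n_hyp encoder outputs, each (Batch, T, D)
    # refs, ref_lens: lists of n_ref reference texts and their lengths
    rows = []
    for ref, ref_len in zip(refs, ref_lens):
        # Score every hypothesis channel against this reference speaker.
        row = [ctc_fn(hyp, enc_lens, ref, ref_len) for hyp in hyp_encodings]
        rows.append(torch.stack(row, dim=1))  # (Batch, n_hyp)
    return torch.stack(rows, dim=1)  # (Batch, n_ref, n_hyp)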
- Parameters:
- speech (torch.Tensor) – The enhanced speech signals of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – The lengths of the speech signals of shape (Batch,).
- text (List[torch.Tensor]) – A list of reference text sequences for each speaker, each of shape (Batch, Length).
- text_lengths (List[torch.Tensor]) – A list of lengths of the reference text sequences for each speaker, each of shape (Batch,).
- Returns: A tuple containing:
- loss (torch.Tensor): The calculated loss for the ASR task.
- stats (Dict[str, torch.Tensor]): A dictionary of statistics related to the loss computation.
- weight (torch.Tensor): The weight tensor for the computed loss.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: ValueError – If CTC is not used for determining the permutation.
Examples
>>> speech = torch.randn(2, 100, 80) # Batch of 2, 100 time steps, 80 features
>>> speech_lengths = torch.tensor([100, 90])
>>> text = [torch.randint(0, 100, (2, 20)), torch.randint(0, 100, (2, 25))]
>>> text_lengths = [torch.tensor([20, 18]), torch.tensor([25, 22])]
>>> loss, stats, weight = model.asr_pit_loss(speech, speech_lengths, text, text_lengths)
NOTE: Ensure that self.s2t_model.ctc is initialized before calling this method, as it is required to compute the loss.
batchify_nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor, batch_size: int = 100)
Compute negative log likelihood (NLL) from the transformer decoder.
This function separates the input into batches to avoid out-of-memory (OOM) errors. It processes each batch individually and combines the results before returning.
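A minimal sketch of the chunking idea, assuming an nll_fn with the same signature as the nll method below (illustrative, not the ESPnet source):

import torch

def batchify_nll_sketch(nll_fn, encoder_out, encoder_out_lens,
                        ys_pad, ys_pad_lens, batch_size=100):
    total = encoder_out.size(0)
    nlls = []
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        # Score one slice at a time so peak memory scales with batch_size.
        nlls.append(nll_fn(encoder_out[start:end], encoder_out_lens[start:end],
                           ys_pad[start:end], ys_pad_lens[start:end]))
    return torch.cat(nlls)  # (Batch,), identical to an unchunked call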
- Parameters:
- encoder_out – Tensor of shape (Batch, Length, Dim) representing the output of the encoder.
- encoder_out_lens – Tensor of shape (Batch,) containing the lengths of each output sequence from the encoder.
- ys_pad – Tensor of shape (Batch, Length) representing the padded target sequences.
- ys_pad_lens – Tensor of shape (Batch,) containing the lengths of each target sequence.
- batch_size – Integer specifying the number of samples in each batch during the NLL computation. Adjust this to manage memory usage effectively.
- Returns: Tensor containing the computed negative log likelihood for each sample in the batch, with shape (Batch,).
Examples
>>> encoder_out = torch.randn(300, 50, 256) # Example encoder output
>>> encoder_out_lens = torch.randint(1, 51, (300,))
>>> ys_pad = torch.randint(0, 100, (300, 40)) # Example target
>>> ys_pad_lens = torch.randint(1, 41, (300,))
>>> model = ESPnetASRModel(...) # Initialize model parameters
>>> nll = model.batchify_nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
NOTE: The batch size can be adjusted to accommodate different hardware configurations and avoid memory issues during processing.
collect_feats(speech: Tensor, speech_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the input speech tensor and corresponding lengths.
This method extracts features using the speech-to-text model, depending on whether the model is configured to extract features or generate dummy stats.
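The two behaviors can be sketched as below (illustrative, not the ESPnet source; frontend stands in for whatever feature extractor the S2T model uses):

def collect_feats_sketch(speech, speech_lengths, frontend,
                         extract_feats_in_collect_stats=True):
    if extract_feats_in_collect_stats:
        feats, feats_lengths = frontend(speech, speech_lengths)
    else:
        # Generate dummy stats: pass the raw waveform through unchanged.
        feats, feats_lengths = speech, speech_lengths
    return {"feats": feats, "feats_lengths": feats_lengths}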
extract_feats_in_collect_stats
Flag indicating whether to extract features in the collection process.
- Type: bool
Parameters:
- speech (torch.Tensor) – The input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – The lengths of the input speech of shape (Batch,).
- **kwargs – Additional keyword arguments. Can include:
- text (torch.Tensor): The reference text tensor.
- text_lengths (torch.Tensor): The lengths of the reference text.
Returns: A dictionary containing:
- feats (torch.Tensor): The extracted features.
- feats_lengths (torch.Tensor): The lengths of the extracted features.
Return type: Dict[str, torch.Tensor]
Raises: ValueError – If the model configuration does not support feature extraction.
Examples
>>> model = ESPnetEnhS2TModel(...)
>>> speech = torch.randn(10, 16000) # Example speech input
>>> speech_lengths = torch.tensor([16000] * 10) # Lengths for each input
>>> feats = model.collect_feats(speech, speech_lengths)
>>> print(feats["feats"].shape) # Should print the shape of the extracted features
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by asr_inference.py.
This method processes the input speech through the enhancement model and then encodes the enhanced speech using the speech-to-text model. It returns the encoded outputs and their corresponding lengths.
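Conceptually the pipeline has two stages, sketched below. This is illustrative only: the enhance method name on the enhancement model is hypothetical, while the s2t_model.encode call mirrors the ESPnet ASR model API.

def encode_sketch(model, speech, speech_lengths):
    # Stage 1: enhance / separate the input mixture (hypothetical method name).
    speech_pre, lengths_pre = model.enh_model.enhance(speech, speech_lengths)
    # Stage 2: frontend + encoder of the speech-to-text model.
    encoder_out, encoder_out_lens = model.s2t_model.encode(speech_pre, lengths_pre)
    return encoder_out, encoder_out_lens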
- Parameters:
- speech – Tensor of shape (Batch, Length, …), representing the input speech signals.
- speech_lengths – Tensor of shape (Batch,), representing the lengths of the input speech sequences.
- Returns: A tuple containing:
- encoder_out: Encoded output from the speech-to-text model, with shape (Batch, Length, Dim).
- encoder_out_lens: Tensor of shape (Batch,), representing the lengths of the encoded outputs.
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises: AssertionError – If the number of speakers in the processed speech does not match the expected number of speakers in the enhancement model.
Examples
>>> model = ESPnetEnhS2TModel(...)
>>> speech = torch.randn(8, 16000) # 8 samples of 1 second audio
>>> speech_lengths = torch.tensor([16000] * 8) # lengths for each sample
>>> encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
>>> print(encoder_out.shape) # Should output: (8, Length, Dim)
>>> print(encoder_out_lens.shape) # Should output: (8,)
encode_diar(speech: Tensor, speech_lengths: Tensor, num_spk: int) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by diar_inference.py.
This method processes the input speech tensor through the enhancement model and then encodes the enhanced speech using the speech-to-text model. It is specifically designed for diarization tasks, which involve identifying and segmenting speakers in an audio stream.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the input speech signal.
- speech_lengths – A tensor of shape (Batch,) representing the lengths of each speech sample in the batch.
- num_spk – An integer indicating the number of speakers in the input speech.
- Returns: A tuple containing:
- encoder_out: The output of the encoder after processing the enhanced speech.
- encoder_out_lens: The lengths of the output sequences from the encoder.
- speech_pre: The enhanced speech signals.
- Return type: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Examples
>>> model = ESPnetEnhS2TModel(...)
>>> speech_tensor = torch.randn(2, 16000) # Batch of 2, 1 second of audio
>>> speech_lengths = torch.tensor([16000, 16000]) # Lengths of audio
>>> num_spk = 2 # Assuming there are 2 speakers
>>> encoder_out, encoder_out_lens, speech_pre = model.encode_diar(
... speech_tensor, speech_lengths, num_spk)
- Raises: ValueError – If the input tensor dimensions do not match the expected shapes or if num_spk is inconsistent with the enhancement model.
NOTE: Ensure that the speech tensor is pre-processed and in the correct format before calling this method. The num_spk parameter must match the expected number of speakers for accurate processing.
forward(speech: Tensor, speech_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calculate loss.
This method processes the input speech tensor through the enhancement model and the speech-to-text (S2T) model. It computes the necessary losses for training and returns them along with relevant statistics.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the input speech signals.
- speech_lengths – A tensor of shape (Batch,) indicating the lengths of the input speech signals. Default is None for chunk iterators, since they do not return speech lengths; see espnet2/iterators/chunk_iter_factory.py for details.
- **kwargs – Additional keyword arguments, which may include:
  - For the Enh+ASR task:
    - text_spk1: (Batch, Length) tensor of text sequences for speaker 1.
    - text_spk2: (Batch, Length) tensor of text sequences for speaker 2.
    - …
    - text_spk1_lengths: (Batch,) tensor of lengths for the text sequences of speaker 1.
    - text_spk2_lengths: (Batch,) tensor of lengths for the text sequences of speaker 2.
    - …
  - For other tasks:
    - text: (Batch, Length) tensor of text sequences. Default is None, included to maintain argument order.
    - text_lengths: (Batch,) tensor of lengths for the text sequences. Default is None for the same reason as speech_lengths.
- Returns:
- A tensor representing the computed loss.
- A dictionary containing various statistics related to the forward pass.
- A tensor representing the weight for the loss.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: NotImplementedError – If the provided S2T model type is not supported.
Examples
>>> model = ESPnetEnhS2TModel(enh_model, s2t_model)
>>> speech_input = torch.randn(2, 16000) # Example speech input
>>> lengths = torch.tensor([16000, 16000]) # Lengths of the input
>>> loss, stats, weight = model.forward(speech_input, lengths,
... text_spk1=text1,
... text_spk1_lengths=lengths1)
NOTE: Ensure that the input tensors are properly shaped and that the appropriate keyword arguments are passed based on the task (Enh+ASR or others).
inherite_attributes(inherite_enh_attrs: List[str] = [], inherite_s2t_attrs: List[str] = [])
Inherit attributes from the enhancement and speech-to-text models.
This method allows the user to inherit specified attributes from the enhancement model and the speech-to-text model, enabling the joint model to access properties and methods defined in the respective models without directly exposing them.
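A minimal sketch of the mechanism (illustrative, not the ESPnet source), consistent with the NOTE below that missing attributes default to None:

def inherite_attributes_sketch(joint, enh_model, s2t_model,
                               inherite_enh_attrs=(), inherite_s2t_attrs=()):
    for attr in inherite_enh_attrs:
        # Copy the attribute onto the joint model; missing ones become None.
        setattr(joint, attr, getattr(enh_model, attr, None))
    for attr in inherite_s2t_attrs:
        setattr(joint, attr, getattr(s2t_model, attr, None))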
- Parameters:
- inherite_enh_attrs (List[str]) – A list of attribute names to inherit from the enhancement model.
- inherite_s2t_attrs (List[str]) – A list of attribute names to inherit from the speech-to-text model.
Examples
>>> model = ESPnetEnhS2TModel(enh_model, s2t_model)
>>> model.inherite_attributes(
... inherite_enh_attrs=['some_enh_attr'],
... inherite_s2t_attrs=['some_s2t_attr']
... )
>>> print(model.some_enh_attr) # Access inherited attribute
>>> print(model.some_s2t_attr) # Access inherited attribute
NOTE: If the specified attributes do not exist in the respective models, their values will be set to None.
nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor) → Tensor
Compute negative log likelihood (NLL) from transformer decoder.
This function is typically called within the batchify_nll method. It computes the negative log likelihood for a given batch of encoded outputs and their corresponding target sequences.
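As a hedged sketch of how a per-sample NLL can be reduced from decoder logits (illustrative only; the real method first runs the transformer decoder with teacher forcing to obtain the logits):

import torch
import torch.nn.functional as F

def per_sample_nll(logits, targets, target_lens):
    # logits: (Batch, Length, Vocab); targets: (Batch, Length)
    ce = F.cross_entropy(logits.transpose(1, 2), targets,
                         reduction="none")  # token-level CE, (Batch, Length)
    # Zero out positions beyond each target length before summing.
    mask = torch.arange(targets.size(1))[None, :] < target_lens[:, None]
    return (ce * mask).sum(dim=1)  # (Batch,)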
- Parameters:
- encoder_out – A tensor of shape (Batch, Length, Dim) representing the output from the encoder.
- encoder_out_lens – A tensor of shape (Batch,) representing the lengths of the encoder outputs.
- ys_pad – A tensor of shape (Batch, Length) representing the padded target sequences.
- ys_pad_lens – A tensor of shape (Batch,) representing the lengths of the target sequences.
- Returns: A tensor representing the computed negative log likelihood for the provided inputs.
Examples
>>> encoder_out = torch.randn(32, 50, 256) # Batch of 32, 50 time steps, 256 features
>>> encoder_out_lens = torch.randint(1, 50, (32,)) # Random lengths
>>> ys_pad = torch.randint(0, 100, (32, 40)) # Random target sequences
>>> ys_pad_lens = torch.randint(1, 40, (32,)) # Random lengths
>>> nll_value = model.nll(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
permutation_invariant_training(losses: Tensor)
Compute the Permutation Invariant Training (PIT) loss.
This method applies the Hungarian algorithm to determine the optimal assignment of hypotheses to references based on the provided loss matrix. The goal is to minimize the total loss by finding the best permutation of the hypotheses for each batch.
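A compact sketch of the core loop (illustrative; the shapes follow the docstring below, and the mean reduction over assigned pairs is an assumption):

import torch
from scipy.optimize import linear_sum_assignment

def pit_sketch(losses):
    # losses: (batch, n_ref, n_hyp) pairwise loss matrix
    perms, min_losses = [], []
    for b_loss in losses:
        # Hungarian algorithm: minimal-cost reference-to-hypothesis assignment.
        row_ind, col_ind = linear_sum_assignment(b_loss.detach().cpu().numpy())
        perms.append(col_ind)  # hypothesis index assigned to each reference
        min_losses.append(b_loss[torch.as_tensor(row_ind),
                                 torch.as_tensor(col_ind)].mean())
    return perms, torch.stack(min_losses)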
Parameters: losses (torch.Tensor) – A tensor of shape (batch, nref, nhyp) representing the loss values for each reference and hypothesis pair.
Returns: A tuple containing:
- perm (list): The optimal permutation indices for each batch, where each entry has shape (n_spk,).
- loss (torch.Tensor): A tensor of shape (batch,) holding the minimized loss for each batch item after applying the optimal permutation.
Return type: Tuple[List, torch.Tensor]
Examples
>>> losses = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])  # (batch=1, nref=2, nhyp=2)
>>> perm, loss = model.permutation_invariant_training(losses)
>>> print(perm)
[array([0, 1])]
>>> print(loss)  # mean of the assigned pair losses: (0.1 + 0.2) / 2
tensor([0.1500])
NOTE: This method is primarily used in scenarios where multiple hypotheses can correspond to multiple references, such as multi-speaker speech recognition.
- Raises: ValueError – If the assignment solver fails for a reason other than an infeasible cost matrix. When the cost matrix is infeasible (e.g., all loss values are infinite), a random assignment is used instead of raising an error.