espnet2.s2st.espnet_model.ESPnetS2STModel
class espnet2.s2st.espnet_model.ESPnetS2STModel(s2st_type: str, frontend: AbsFrontend | None, tgt_feats_extract: AbsTgtFeatsExtract | None, specaug: AbsSpecAug | None, src_normalize: AbsNormalize | None, tgt_normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, asr_decoder: AbsDecoder | None, st_decoder: AbsDecoder | None, aux_attention: AbsS2STAuxAttention | None, unit_encoder: AbsEncoder | None, synthesizer: AbsSynthesizer | None, asr_ctc: CTC | None, st_ctc: CTC | None, losses: Dict[str, AbsS2STLoss], tgt_vocab_size: int | None, tgt_token_list: Tuple[str, ...] | List[str] | None, src_vocab_size: int | None, src_token_list: Tuple[str, ...] | List[str] | None, unit_vocab_size: int | None, unit_token_list: Tuple[str, ...] | List[str] | None, ignore_id: int = -1, report_cer: bool = True, report_wer: bool = True, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', extract_feats_in_collect_stats: bool = True)
Bases: AbsESPnetModel
ESPnet speech-to-speech translation model.
This class implements a speech-to-speech translation (S2ST) model that can handle various types of input and output features. The model can be configured with different frontends, encoders, decoders, and loss functions to support diverse speech translation tasks.
sos
Start-of-sequence token index for target vocabulary.
- Type: int
eos
End-of-sequence token index for target vocabulary.
- Type: int
src_sos
Start-of-sequence token index for source vocabulary.
- Type: int
src_eos
End-of-sequence token index for source vocabulary.
- Type: int
unit_sos
Start-of-sequence token index for unit vocabulary.
- Type: int
unit_eos
End-of-sequence token index for unit vocabulary.
- Type: int
tgt_vocab_size
Size of the target vocabulary.
- Type: int
src_vocab_size
Size of the source vocabulary.
- Type: int
unit_vocab_size
Size of the unit vocabulary.
- Type: int
ignore_id
Index to ignore during loss computation.
- Type: int
tgt_token_list
List of tokens for target language.
- Type: list
src_token_list
List of tokens for source language.
- Type: list
unit_token_list
List of tokens for unit representation.
- Type: list
s2st_type
Type of the S2ST model (e.g., "translatotron").
- Type: str
frontend
Frontend processing module.
- Type: AbsFrontend
tgt_feats_extract
Target feature extraction module.
- Type: AbsTgtFeatsExtract
specaug
Spectral augmentation module.
- Type: AbsSpecAug
src_normalize
Normalization module for source features.
- Type: AbsNormalize
tgt_normalize
Normalization module for target features.
- Type: AbsNormalize
preencoder
Pre-encoder module for raw input data.
- Type: AbsPreEncoder
postencoder
Post-encoder module for additional processing.
- Type: AbsPostEncoder
encoder
Encoder module for feature extraction.
- Type: AbsEncoder
asr_decoder
ASR decoder module.
- Type: AbsDecoder
st_decoder
ST decoder module.
- Type: AbsDecoder
aux_attention
Auxiliary attention mechanism.
- Type: AbsS2STAuxAttention
unit_encoder
Encoder module for unit representation.
- Type: AbsEncoder
synthesizer
Synthesizer module for generating output.
- Type: AbsSynthesizer
asr_ctc
CTC loss module for ASR.
- Type: CTC
st_ctc
CTC loss module for ST.
- Type: CTC
losses
Dictionary of loss functions for different tasks.
- Type: dict
extract_feats_in_collect_stats
Flag to indicate feature extraction during statistics collection.
- Type: bool
Parameters:
- s2st_type (str) – Type of the S2ST model.
- frontend (Optional[AbsFrontend]) – Frontend processing module.
- tgt_feats_extract (Optional[AbsTgtFeatsExtract]) – Target feature extraction module.
- specaug (Optional[AbsSpecAug]) – Spectral augmentation module.
- src_normalize (Optional[AbsNormalize]) – Normalization module for source features.
- tgt_normalize (Optional[AbsNormalize]) – Normalization module for target features.
- preencoder (Optional[AbsPreEncoder]) – Pre-encoder module for raw input data.
- encoder (AbsEncoder) – Encoder module for feature extraction.
- postencoder (Optional[AbsPostEncoder]) – Post-encoder module for additional processing.
- asr_decoder (Optional[AbsDecoder]) – ASR decoder module.
- st_decoder (Optional[AbsDecoder]) – ST decoder module.
- aux_attention (Optional[AbsS2STAuxAttention]) – Auxiliary attention mechanism.
- unit_encoder (Optional[AbsEncoder]) – Encoder module for unit representation.
- synthesizer (Optional[AbsSynthesizer]) – Synthesizer module for generating output.
- asr_ctc (Optional[CTC]) – CTC loss module for ASR.
- st_ctc (Optional[CTC]) – CTC loss module for ST.
- losses (Dict[str, AbsS2STLoss]) – Dictionary of loss functions for different tasks.
- tgt_vocab_size (Optional[int]) – Size of the target vocabulary.
- tgt_token_list (Optional[Union[Tuple[str, ...], List[str]]]) – List of tokens for target language.
- src_vocab_size (Optional[int]) – Size of the source vocabulary.
- src_token_list (Optional[Union[Tuple[str, ...], List[str]]]) – List of tokens for source language.
- unit_vocab_size (Optional[int]) – Size of the unit vocabulary.
- unit_token_list (Optional[Union[Tuple[str, ...], List[str]]]) – List of tokens for unit representation.
- ignore_id (int) – Index to ignore during loss computation. Defaults to -1.
- report_cer (bool) – Flag to report character error rate. Defaults to True.
- report_wer (bool) – Flag to report word error rate. Defaults to True.
- report_bleu (bool) – Flag to report BLEU score. Defaults to True.
- sym_space (str) – Symbol representing space. Defaults to "<space>".
- sym_blank (str) – Symbol representing blank. Defaults to "<blank>".
- extract_feats_in_collect_stats (bool) – Flag to indicate feature extraction during statistics collection. Defaults to True.
Raises: AssertionError – If certain configurations are not met.
############# Examples
Create an instance of the ESPnetS2STModel:

model = ESPnetS2STModel(
    s2st_type="translatotron",
    frontend=my_frontend,
    tgt_feats_extract=my_tgt_feats_extract,
    ...
)

Forward pass through the model:

loss, stats, weight = model(
    src_speech=my_src_speech,
    src_speech_lengths=my_src_speech_lengths,
    tgt_speech=my_tgt_speech,
    tgt_speech_lengths=my_tgt_speech_lengths,
    ...
)
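A typical training step then backpropagates the returned loss (a hedged sketch; the optimizer choice and the my_* tensors are placeholders, not part of this API):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss, stats, weight = model(
    src_speech=my_src_speech, src_speech_lengths=my_src_speech_lengths,
    tgt_speech=my_tgt_speech, tgt_speech_lengths=my_tgt_speech_lengths,
)
loss.backward()        # gradients flow through all configured losses
optimizer.step()
optimizer.zero_grad()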
collect_feats(src_speech: Tensor, src_speech_lengths: Tensor, tgt_speech: Tensor, tgt_speech_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collects features from source and target speech tensors for analysis.
This method extracts features from the provided source and target speech tensors. If extract_feats_in_collect_stats is set to True, it performs feature extraction; otherwise, it generates dummy statistics.
extract_feats_in_collect_stats
Determines whether to extract features or generate dummy stats.
- Type: bool
- Parameters:
- src_speech (torch.Tensor) – Source speech tensor of shape (Batch, Length).
- src_speech_lengths (torch.Tensor) – Lengths of source speech tensor of shape (Batch,).
- tgt_speech (torch.Tensor) – Target speech tensor of shape (Batch, Length).
- tgt_speech_lengths (torch.Tensor) – Lengths of target speech tensor of shape (Batch,).
- **kwargs – Additional keyword arguments.
- Returns: A dictionary containing extracted features and their lengths. The keys are:
- "src_feats": Extracted source features.
- "src_feats_lengths": Lengths of the source features.
- "tgt_feats": Extracted target features (if applicable).
- "tgt_feats_lengths": Lengths of the target features (if applicable).
- Return type: Dict[str, torch.Tensor]
############# Examples
>>> src_speech = torch.randn(2, 16000) # Example source speech
>>> src_lengths = torch.tensor([16000, 15000]) # Example lengths
>>> tgt_speech = torch.randn(2, 16000) # Example target speech
>>> tgt_lengths = torch.tensor([16000, 15000]) # Example lengths
>>> features = model.collect_feats(src_speech, src_lengths, tgt_speech, tgt_lengths)
>>> print(features["src_feats"].shape) # Output: torch.Size([2, N, D])
####### NOTE
If extract_feats_in_collect_stats is False, this method will log a warning and return the original speech tensors as dummy statistics.
- Raises: AssertionError – If the input tensors do not match the expected dimensions.
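The fallback described in the note can be illustrated as follows (a sketch assuming the dummy statistics reuse the documented keys, with the raw waveforms from the example above):

>>> model.extract_feats_in_collect_stats = False
>>> feats = model.collect_feats(src_speech, src_lengths, tgt_speech, tgt_lengths)
>>> # With extraction disabled, the returned "features" are the raw speech
>>> # tensors, so shapes match the inputs rather than (Batch, N, D).
>>> feats["src_feats"].shape
torch.Size([2, 16000])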
encode(speech: Tensor, speech_lengths: Tensor, return_all_hs: bool = False, **kwargs) → Tuple[Tensor, Tensor]
Encode the input speech using the frontend and encoder components of the model.
This method performs several preprocessing steps, including feature extraction, data augmentation, and normalization before passing the processed features to the encoder. It can return intermediate hidden states if requested.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the input speech.
- speech_lengths – A tensor of shape (Batch,) representing the lengths of each input sequence in the batch.
- return_all_hs – A boolean indicating whether to return all hidden states from the encoder. Defaults to False.
- **kwargs – Additional keyword arguments to be passed to the encoder.
- Returns: A tuple containing:
    - encoder_out: A tensor of shape (Batch, Length2, Dim2) representing the encoded output.
    - encoder_out_lens: A tensor of shape (Batch,) representing the lengths of the encoded sequences.
    - inter_encoder_out (optional): Intermediate hidden states from the encoder if return_all_hs is True.
- Return type: Tuple[torch.Tensor, torch.Tensor]
############# Examples
>>> model = ESPnetS2STModel(...)
>>> speech = torch.randn(8, 16000) # Example input (8 samples, 1 second each)
>>> speech_lengths = torch.tensor([16000] * 8) # Lengths of the input
>>> encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
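To also retrieve intermediate hidden states, pass return_all_hs=True (a sketch; per the return description above, the result then additionally carries inter_encoder_out, though its exact packing depends on the encoder implementation):

>>> outputs = model.encode(speech, speech_lengths, return_all_hs=True)
>>> # outputs now includes the encoder's intermediate hidden states in
>>> # addition to encoder_out and encoder_out_lens.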
####### NOTE
This method is used by the speech-to-speech translation inference process.
forward(src_speech: Tensor, src_speech_lengths: Tensor, tgt_speech: Tensor, tgt_speech_lengths: Tensor, tgt_text: Tensor | None = None, tgt_text_lengths: Tensor | None = None, src_text: Tensor | None = None, src_text_lengths: Tensor | None = None, spembs: Tensor | None = None, sids: Tensor | None = None, lids: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Perform the forward pass of the speech-to-speech translation model.
This method takes the source and target speech along with optional text representations, processes them through the model, and returns the computed loss, statistics, and batch size.
- Parameters:
- src_speech (torch.Tensor) – Source speech tensor of shape (Batch, Length, …).
- src_speech_lengths (torch.Tensor) – Lengths of source speech sequences of shape (Batch,).
- tgt_speech (torch.Tensor) – Target speech tensor of shape (Batch, Length, …).
- tgt_speech_lengths (torch.Tensor) – Lengths of target speech sequences of shape (Batch,).
- tgt_text (Optional[torch.Tensor], optional) – Target text tensor of shape (Batch, Length, ...). Defaults to None.
- tgt_text_lengths (Optional[torch.Tensor], optional) – Lengths of target text sequences of shape (Batch,). Defaults to None.
- src_text (Optional[torch.Tensor], optional) – Source text tensor of shape (Batch, Length, ...). Defaults to None.
- src_text_lengths (Optional[torch.Tensor], optional) – Lengths of source text sequences of shape (Batch,). Defaults to None.
- spembs (Optional[torch.Tensor], optional) – Speaker embeddings tensor. Defaults to None.
- sids (Optional[torch.Tensor], optional) – Speaker IDs tensor. Defaults to None.
- lids (Optional[torch.Tensor], optional) – Language IDs tensor. Defaults to None.
- **kwargs – Additional keyword arguments.
- Returns: A tuple containing:
    - loss (torch.Tensor): The computed loss for the batch.
    - stats (Dict[str, torch.Tensor]): A dictionary of statistics, including losses and accuracies.
    - weight (torch.Tensor): The batch size for DataParallel compatibility.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: ValueError – If the specified speech lengths do not match the batch size or if an unsupported s2st type is encountered.
############# Examples
>>> model = ESPnetS2STModel(...)
>>> loss, stats, weight = model.forward(
... src_speech, src_speech_lengths, tgt_speech,
... tgt_speech_lengths, tgt_text, tgt_text_lengths
... )
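The returned stats dictionary is convenient for logging; which keys appear depends on s2st_type and the configured losses (a sketch; entries may be None for tasks that are not active):

>>> for name, value in stats.items():
...     if value is not None:
...         print(f"{name}: {float(value):.4f}")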
####### NOTE
The method includes various checks to ensure that the input dimensions are consistent and raises assertions if they are not.
inference(src_speech: Tensor, src_speech_lengths: Tensor | None = None, tgt_speech: Tensor | None = None, tgt_speech_lengths: Tensor | None = None, spembs: Tensor | None = None, sids: Tensor | None = None, lids: Tensor | None = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 10.0, use_att_constraint: bool = False, backward_window: int = 1, forward_window: int = 3, use_teacher_forcing: bool = False) → Dict[str, Tensor]
Run inference for the speech-to-speech translation model.
This method takes input speech and generates the corresponding output speech features. The method utilizes the encoder and synthesizer to produce the output based on the specified model type.
- Parameters:
- src_speech (torch.Tensor) – Input source speech tensor.
- src_speech_lengths (Optional[torch.Tensor]) – Lengths of the source speech tensor.
- tgt_speech (Optional[torch.Tensor]) – Target speech tensor (for feature extraction).
- tgt_speech_lengths (Optional[torch.Tensor]) – Lengths of the target speech tensor.
- spembs (Optional[torch.Tensor]) – Speaker embeddings.
- sids (Optional[torch.Tensor]) – Speaker IDs.
- lids (Optional[torch.Tensor]) – Language IDs.
- threshold (float) – Threshold for synthesizer output. Default is 0.5.
- minlenratio (float) – Minimum length ratio for output. Default is 0.0.
- maxlenratio (float) – Maximum length ratio for output. Default is 10.0.
- use_att_constraint (bool) – Flag to use attention constraint. Default is False.
- backward_window (int) – Number of frames to consider backward. Default is 1.
- forward_window (int) – Number of frames to consider forward. Default is 3.
- use_teacher_forcing (bool) – Flag to use teacher forcing during inference. Default is False.
- Returns: A dictionary containing generated features and any additional output data, including:
    - "feat_gen": Generated features.
    - "feat_gen_denorm": Denormalized generated features if normalization was applied.
- Return type: Dict[str, torch.Tensor]
- Raises: ValueError – If an unsupported s2st type is encountered.
############# Examples
>>> model.inference(src_speech, src_speech_lengths)
{
'feat_gen': <tensor>,
'feat_gen_denorm': <tensor>
}
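Teacher forcing requires reference target speech to be supplied (a sketch based on the parameters above; feat_gen_denorm appears only when target normalization is configured):

>>> output_dict = model.inference(
...     src_speech, src_speech_lengths,
...     tgt_speech=tgt_speech, tgt_speech_lengths=tgt_speech_lengths,
...     use_teacher_forcing=True,
... )
>>> output_dict["feat_gen"]  # generated features, aligned to the reference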
property require_vocoder
Return whether or not a vocoder is required.
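This property can gate vocoder construction in an inference script (a hedged sketch; load_vocoder is a hypothetical helper, not part of this class):

>>> if model.require_vocoder:
...     vocoder = load_vocoder("vocoder.pth")  # hypothetical helper
...     wav = vocoder(output_dict["feat_gen_denorm"])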