espnet2.st.espnet_model.ESPnetSTModel
class espnet2.st.espnet_model.ESPnetSTModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, hier_encoder: AbsEncoder | None, md_encoder: AbsEncoder | None, extra_mt_encoder: AbsEncoder | None, postencoder: AbsPostEncoder | None, decoder: AbsDecoder, extra_asr_decoder: AbsDecoder | None, extra_mt_decoder: AbsDecoder | None, ctc: CTC | None, st_ctc: CTC | None, st_joint_network: Module | None, src_vocab_size: int | None, src_token_list: Tuple[str, ...] | List[str] | None, asr_weight: float = 0.0, mt_weight: float = 0.0, mtlalpha: float = 0.0, st_mtlalpha: float = 0.0, ignore_id: int = -1, tgt_ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', tgt_sym_space: str = '<space>', tgt_sym_blank: str = '<blank>', extract_feats_in_collect_stats: bool = True, ctc_sample_rate: float = 0.0, tgt_sym_sos: str = '<sos/eos>', tgt_sym_eos: str = '<sos/eos>', lang_token_id: int = -1)
Bases: AbsESPnetModel
CTC-attention hybrid Encoder-Decoder model.
This model combines CTC (Connectionist Temporal Classification) and attention mechanisms for sequence-to-sequence tasks, particularly in speech translation. It utilizes various components including frontends, encoders, decoders, and loss functions tailored for automatic speech recognition (ASR) and machine translation (MT).
vocab_size
Size of the target vocabulary.
- Type: int
token_list
List of tokens in the target language.
- Type: List[str]
frontend
Frontend for feature extraction.
- Type: Optional[AbsFrontend]
specaug
SpecAugment for data augmentation.
- Type: Optional[AbsSpecAug]
normalize
Normalization layer.
- Type: Optional[AbsNormalize]
preencoder
Pre-encoder for raw input.
- Type: Optional[AbsPreEncoder]
encoder
Main encoder for the input sequence.
- Type: AbsEncoder
hier_encoder
Hierarchical encoder.
- Type: Optional[AbsEncoder]
md_encoder
Multi-decoder encoder.
- Type: Optional[AbsEncoder]
extra_mt_encoder
Additional encoder for MT.
- Type: Optional[AbsEncoder]
postencoder
Post-encoder for output processing.
- Type: Optional[AbsPostEncoder]
decoder
Main decoder for generating output sequences.
- Type: AbsDecoder
extra_asr_decoder
Additional ASR decoder.
- Type: Optional[AbsDecoder]
extra_mt_decoder
Additional MT decoder.
- Type: Optional[AbsDecoder]
ctc
CTC loss for ASR task.
- Type: Optional[CTC]
st_ctc
CTC loss for ST task.
- Type: Optional[CTC]
st_joint_network
Joint network for ST.
- Type: Optional[torch.nn.Module]
src_vocab_size
Size of the source vocabulary.
- Type: Optional[int]
src_token_list
List of tokens in the source language.
- Type: Optional[List[str]]
asr_weight
Weight for ASR loss.
- Type: float
mt_weight
Weight for MT loss.
- Type: float
mtlalpha
Weight for multi-task learning.
- Type: float
st_mtlalpha
Weight for ST multi-task learning.
- Type: float
ignore_id
ID to ignore during loss calculation.
- Type: int
tgt_ignore_id
Target ignore ID.
- Type: int
lsm_weight
Label smoothing weight.
- Type: float
length_normalized_loss
Whether to normalize loss by length.
- Type: bool
report_cer
Whether to report character error rate.
- Type: bool
report_wer
Whether to report word error rate.
- Type: bool
report_bleu
Whether to report BLEU score.
- Type: bool
sym_space
Symbol for space in source language.
- Type: str
sym_blank
Symbol for blank in source language.
- Type: str
tgt_sym_space
Symbol for space in target language.
- Type: str
tgt_sym_blank
Symbol for blank in target language.
- Type: str
extract_feats_in_collect_stats
Whether to extract features during the collect-stats stage.
- Type: bool
ctc_sample_rate
Sampling rate for CTC.
- Type: float
tgt_sym_sos
Start of sequence symbol for target.
- Type: str
tgt_sym_eos
End of sequence symbol for target.
- Type: str
lang_token_id
Language token ID.
- Type: Optional[torch.Tensor]
Parameters:
- vocab_size (int) – Size of the target vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens in the target language.
- frontend (Optional[AbsFrontend]) – Frontend for feature extraction.
- specaug (Optional[AbsSpecAug]) – SpecAugment for data augmentation.
- normalize (Optional[AbsNormalize]) – Normalization layer.
- preencoder (Optional[AbsPreEncoder]) – Pre-encoder for raw input.
- encoder (AbsEncoder) – Main encoder for the input sequence.
- hier_encoder (Optional[AbsEncoder]) – Hierarchical encoder.
- md_encoder (Optional[AbsEncoder]) – Multi-decoder encoder.
- extra_mt_encoder (Optional[AbsEncoder]) – Additional encoder for MT.
- postencoder (Optional[AbsPostEncoder]) – Post-encoder for output processing.
- decoder (AbsDecoder) – Main decoder for generating output sequences.
- extra_asr_decoder (Optional[AbsDecoder]) – Additional ASR decoder.
- extra_mt_decoder (Optional[AbsDecoder]) – Additional MT decoder.
- ctc (Optional[CTC]) – CTC loss for ASR task.
- st_ctc (Optional[CTC]) – CTC loss for ST task.
- st_joint_network (Optional[torch.nn.Module]) – Joint network for ST.
- src_vocab_size (Optional[int]) – Size of the source vocabulary.
- src_token_list (Optional[Union[Tuple[str, ...], List[str]]]) – List of tokens in the source language.
- asr_weight (float) – Weight for ASR loss.
- mt_weight (float) – Weight for MT loss.
- mtlalpha (float) – Weight for multi-task learning.
- st_mtlalpha (float) – Weight for ST multi-task learning.
- ignore_id (int) – ID to ignore during loss calculation.
- tgt_ignore_id (int) – Target ignore ID.
- lsm_weight (float) – Label smoothing weight.
- length_normalized_loss (bool) – Whether to normalize loss by length.
- report_cer (bool) – Whether to report character error rate.
- report_wer (bool) – Whether to report word error rate.
- report_bleu (bool) – Whether to report BLEU score.
- sym_space (str) – Symbol for space in source language.
- sym_blank (str) – Symbol for blank in source language.
- tgt_sym_space (str) – Symbol for space in target language.
- tgt_sym_blank (str) – Symbol for blank in target language.
- extract_feats_in_collect_stats (bool) – Whether to extract features during the collect-stats stage.
- ctc_sample_rate (float) – Sampling rate for CTC.
- tgt_sym_sos (str) – Start of sequence symbol for target.
- tgt_sym_eos (str) – End of sequence symbol for target.
- lang_token_id (int) – Language token ID.
Returns: None
Examples
Initialize the model (my_encoder, my_decoder, my_ctc, and my_st_ctc are placeholders for modules built elsewhere; the token lists are abbreviated for illustration):
from espnet2.st.espnet_model import ESPnetSTModel
model = ESPnetSTModel(
    vocab_size=1000, token_list=["<blank>", "<sos>", "<eos>", "hello", "world"],
    frontend=None, specaug=None, normalize=None, preencoder=None,
    encoder=my_encoder, hier_encoder=None, md_encoder=None, extra_mt_encoder=None, postencoder=None,
    decoder=my_decoder, extra_asr_decoder=None, extra_mt_decoder=None,
    ctc=my_ctc, st_ctc=my_st_ctc, st_joint_network=None,
    src_vocab_size=500, src_token_list=["<blank>", "<sos>", "<eos>", "bonjour", "monde"],
    asr_weight=0.5, mt_weight=0.5, mtlalpha=0.5, st_mtlalpha=0.5,
    ignore_id=-1, tgt_ignore_id=-1, lsm_weight=0.1, length_normalized_loss=False,
    report_cer=True, report_wer=True, report_bleu=True,
    sym_space="<space>", sym_blank="<blank>", tgt_sym_space="<space>", tgt_sym_blank="<blank>",
    extract_feats_in_collect_stats=True, ctc_sample_rate=0.0,
    tgt_sym_sos="<sos/eos>", tgt_sym_eos="<sos/eos>", lang_token_id=-1,
)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, src_text: Tensor | None = None, src_text_lengths: Tensor | None = None, **kwargs) → Dict[str, Tensor]
Collect features from the input speech data.
This method extracts features from the given speech tensor and its lengths, returning them in a dictionary format. It can be used in various stages of the model, including training and evaluation.
- Parameters:
- speech – A tensor of shape (Batch, Length, …), representing the input speech data.
- speech_lengths – A tensor of shape (Batch,), indicating the lengths of each input sequence in the batch.
- text – A tensor of shape (Batch, Length), representing the target text data.
- text_lengths – A tensor of shape (Batch,), indicating the lengths of each target sequence in the batch.
- src_text – (Optional) A tensor of shape (Batch, Length) for source text data.
- src_text_lengths – (Optional) A tensor of shape (Batch,) for lengths of source text sequences.
- kwargs – Additional keyword arguments for future use.
- Returns: A dictionary containing:
- "feats": A tensor of extracted features.
- "feats_lengths": A tensor of lengths corresponding to the extracted features.
- Return type: Dict[str, Tensor]
Examples
>>> model = ESPnetSTModel(...)
>>> speech_data = torch.randn(2, 16000) # Example speech tensor
>>> speech_lengths = torch.tensor([16000, 16000])
>>> text_data = torch.randint(0, 100, (2, 20)) # Example text tensor
>>> text_lengths = torch.tensor([20, 20])
>>> feats = model.collect_feats(speech_data, speech_lengths, text_data, text_lengths)
>>> print(feats['feats'].shape) # Output shape of the extracted features
NOTE: Ensure that the input tensors are correctly shaped and that lengths match the batch size. This method relies on the _extract_feats method to perform the actual feature extraction.
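The shape requirements in the note can be illustrated with a minimal sketch. It uses plain Python lists so that it runs without torch; `pad_batch` and `check_batch` are hypothetical helpers written for this example and are not part of ESPnet. With real tensors, `torch.nn.utils.rnn.pad_sequence` with `batch_first=True` performs the same right-padding.

```python
from typing import List, Tuple


def pad_batch(
    seqs: List[List[float]], pad_value: float = 0.0
) -> Tuple[List[List[float]], List[int]]:
    """Right-pad variable-length sequences to a common length.

    Returns the padded batch and the original length of each sequence,
    i.e. the (Batch, Length) data and the (Batch,) lengths.
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [s + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


def check_batch(padded: List[List[float]], lengths: List[int]) -> None:
    """Hypothetical sanity check: one length per item, none beyond the padding."""
    assert len(padded) == len(lengths), "need one length per batch item"
    assert all(0 < n <= len(row) for row, n in zip(padded, lengths))


# Three toy "waveforms" of different lengths.
waveforms = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6]]
speech, speech_lengths = pad_batch(waveforms)
check_batch(speech, speech_lengths)
```

Running the same kind of check on `text` and `text_lengths` before calling the model catches most batch-size and length mismatches early.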
encode(speech: Tensor, speech_lengths: Tensor, return_int_enc: bool = False) → Tuple[Tensor, Tensor]
Encodes input speech data using a frontend and an encoder.
This method performs the following steps:
- Extract features from the input speech using the frontend.
- Apply data augmentation (if specified) during training.
- Normalize the features if a normalization method is provided.
- Pass the features through the pre-encoder (if specified).
- Forward the features through the main encoder.
- Optionally pass the encoder output through a hierarchical encoder and/or a post-encoder.
- Parameters:
- speech – A tensor containing the input speech data of shape (Batch, Length, …).
- speech_lengths – A tensor containing the lengths of the input speech sequences of shape (Batch,).
- return_int_enc – A boolean indicating whether to return the internal encoder output.
- Returns: A tuple containing:
- encoder_out: The output from the encoder of shape (Batch, Length2, Dim2).
- encoder_out_lens: The lengths of the encoder outputs of shape (Batch,).
- int_encoder_out: The internal encoder output (only if return_int_enc is True).
- int_encoder_out_lens: The lengths of the internal encoder output (only if return_int_enc is True).
- Return type: Tuple[Tensor, Tensor], extended with the two int_encoder_out items when return_int_enc is True
NOTE: This method is primarily used by the st_inference.py script for speech translation.
Examples
>>> model = ESPnetSTModel(...)
>>> speech_tensor = torch.randn(32, 16000) # 32 samples, 1 second each
>>> speech_lengths = torch.tensor([16000] * 32) # All samples are 1 second
>>> encoder_out, encoder_out_lens = model.encode(speech_tensor, speech_lengths)
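The steps listed for encode can be pictured as a chain of optional stages, where any stage set to None is skipped and augmentation runs only in training. The sketch below is an illustration under those assumptions, not the ESPnet implementation: `run_encode_pipeline`, `toy_frontend`, and `toy_encoder` are hypothetical names, the stages operate on plain lists rather than tensors, and the order of the hierarchical encoder relative to the post-encoder is assumed from the order in the list.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Batch = List[List[float]]
Stage = Optional[Callable[[Batch, List[int]], Tuple[Batch, List[int]]]]


def run_encode_pipeline(
    feats: Batch,
    lengths: List[int],
    stages: Sequence[Tuple[str, Stage]],
    training: bool = False,
) -> Tuple[Batch, List[int], List[str]]:
    """Apply each configured stage in order and record which ones ran."""
    applied = []
    for name, stage in stages:
        if stage is None:
            continue  # optional component not configured
        if name == "specaug" and not training:
            continue  # augmentation is applied during training only
        feats, lengths = stage(feats, lengths)
        applied.append(name)
    return feats, lengths, applied


def toy_frontend(feats: Batch, lengths: List[int]) -> Tuple[Batch, List[int]]:
    # Stand-in that passes features through unchanged.
    return feats, lengths


def toy_encoder(feats: Batch, lengths: List[int]) -> Tuple[Batch, List[int]]:
    # Stand-in that keeps every second frame, mimicking a subsampling encoder.
    return [row[::2] for row in feats], [(n + 1) // 2 for n in lengths]


stages = [
    ("frontend", toy_frontend),
    ("specaug", toy_frontend),
    ("normalize", None),
    ("preencoder", None),
    ("encoder", toy_encoder),
    ("hier_encoder", None),
    ("postencoder", None),
]
out, out_lens, applied = run_encode_pipeline(
    [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 0.0]], [4, 3], stages, training=False
)
```

The output lengths differ from the input lengths here, which is why encode returns encoder_out_lens alongside encoder_out.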
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, src_text: Tensor | None = None, src_text_lengths: Tensor | None = None, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss.
This method processes the input speech and text data through the model’s frontend, encoder, and decoder, and computes the corresponding losses for the tasks (ASR, ST, MT). It performs checks to ensure the input data dimensions are consistent and handles optional source text for ASR.
- Parameters:
- speech – Tensor of shape (Batch, Length, …) representing the input speech data.
- speech_lengths – Tensor of shape (Batch,) representing the lengths of the input speech sequences.
- text – Tensor of shape (Batch, Length) representing the target text sequences.
- text_lengths – Tensor of shape (Batch,) representing the lengths of the target text sequences.
- src_text – Optional; Tensor of shape (Batch, Length) representing the source text sequences for ASR. Defaults to None.
- src_text_lengths – Optional; Tensor of shape (Batch,) representing the lengths of the source text sequences. Defaults to None.
- kwargs – Additional keyword arguments; “utt_id” can be included among the input.
- Returns:
- loss: The total computed loss for the input batch.
- stats: A dictionary containing various loss metrics and accuracy statistics.
- weight: The batch size for normalization purposes.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the dimensions of the input tensors do not match the expected shapes.
Examples
>>> model = ESPnetSTModel(...)
>>> speech = torch.randn(16, 16000) # Example input
>>> speech_lengths = torch.tensor([16000] * 16)
>>> text = torch.randint(0, 100, (16, 20)) # Example target text
>>> text_lengths = torch.tensor([20] * 16)
>>> loss, stats, weight = model.forward(speech, speech_lengths, text, text_lengths)
NOTE: Ensure that the input tensors are correctly padded and that the lengths provided are accurate to avoid dimension mismatch errors.
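The text above does not spell out how asr_weight, mt_weight, mtlalpha, and st_mtlalpha enter the total loss. The sketch below shows one plausible combination, assuming the usual CTC-attention multi-task convention: each alpha interpolates between a CTC loss and an attention loss, and the task weights mix the ST, ASR, and MT losses with the ST share being whatever remains. This is an assumption made for illustration, and `interpolate` and `combine_losses` are hypothetical helpers, not ESPnet functions.

```python
def interpolate(alpha: float, ctc_loss: float, att_loss: float) -> float:
    """Blend a CTC loss and an attention loss; alpha=0 keeps attention only."""
    return alpha * ctc_loss + (1.0 - alpha) * att_loss


def combine_losses(
    st_ctc: float,
    st_att: float,
    asr_ctc: float,
    asr_att: float,
    mt_att: float,
    asr_weight: float = 0.0,
    mt_weight: float = 0.0,
    mtlalpha: float = 0.0,
    st_mtlalpha: float = 0.0,
) -> float:
    """Assumed weighting: the ST task receives 1 - asr_weight - mt_weight."""
    assert 0.0 <= asr_weight + mt_weight <= 1.0, "task weights must leave room for ST"
    loss_st = interpolate(st_mtlalpha, st_ctc, st_att)
    loss_asr = interpolate(mtlalpha, asr_ctc, asr_att)
    st_weight = 1.0 - asr_weight - mt_weight
    return st_weight * loss_st + asr_weight * loss_asr + mt_weight * mt_att


# With the constructor defaults (all weights 0.0) only the ST attention loss remains.
default_total = combine_losses(2.0, 4.0, 1.0, 3.0, 5.0)

# With every weight active, all five loss terms contribute.
mixed_total = combine_losses(
    2.0, 4.0, 1.0, 3.0, 5.0,
    asr_weight=0.25, mt_weight=0.25, mtlalpha=0.5, st_mtlalpha=0.5,
)
```

Under this assumption, setting asr_weight=0.5 and mt_weight=0.5 together would leave no weight for the ST loss itself, so the values are worth checking against the actual model code before training.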