espnet2.s2t.espnet_model.ESPnetS2TModel
class espnet2.s2t.espnet_model.ESPnetS2TModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: AbsDecoder | None, ctc: CTC, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_sos: str = '<sos>', sym_eos: str = '<eos>', sym_sop: str = '<sop>', sym_na: str = '<na>', extract_feats_in_collect_stats: bool = True)
Bases: AbsESPnetModel
CTC-attention hybrid Encoder-Decoder model for speech-to-text tasks.
This model combines Connectionist Temporal Classification (CTC) with attention-based decoding for end-to-end speech-to-text conversion. It is assembled from a frontend for feature extraction, optional augmentation, normalization, pre-encoder, and post-encoder stages, an encoder, and a decoder. The CTC and attention losses are weighted by ctc_weight.
blank_id
The index of the blank token in the token list.
- Type: int
sos
The index of the start-of-sequence token.
- Type: int
eos
The index of the end-of-sequence token.
- Type: int
sop
The index of the start-of-previous token.
- Type: int
na
The index of the not-available token.
- Type: int
vocab_size
The size of the vocabulary.
- Type: int
ignore_id
The index used for padding or ignored tokens.
- Type: int
ctc_weight
Weight for CTC loss in the combined loss calculation.
- Type: float
interctc_weight
Weight for intermediate CTC loss in the combined loss calculation.
- Type: float
token_list
The list of tokens used for the model.
- Type: List[str]
frontend
Frontend component for feature extraction.
- Type: Optional[AbsFrontend]
specaug
SpecAugment component for data augmentation.
- Type: Optional[AbsSpecAug]
normalize
Normalization component for input features.
- Type: Optional[AbsNormalize]
preencoder
Pre-encoder component for raw input data.
- Type: Optional[AbsPreEncoder]
encoder
The main encoder component.
- Type: AbsEncoder
postencoder
Post-encoder component for further processing.
- Type: Optional[AbsPostEncoder]
decoder
The decoder component for generating outputs.
- Type: Optional[AbsDecoder]
ctc
The CTC loss function used for training.
- Type: CTC
criterion_att
The loss function for attention-based decoding.
- Type: LabelSmoothingLoss
error_calculator
An optional calculator for error metrics.
- Type: Optional[ErrorCalculator]
extract_feats_in_collect_stats
Flag to determine if features are extracted during statistics collection.
- Type: bool
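The index attributes above (blank_id, sos, eos, sop, na) are positions of the special symbols in token_list. A minimal sketch of that lookup in plain Python follows; `lookup_special_tokens` is a hypothetical helper written for illustration, not part of ESPnet, and it assumes every special symbol appears in the token list.

```python
from typing import Dict, List


def lookup_special_tokens(
    token_list: List[str],
    sym_blank: str = "<blank>",
    sym_sos: str = "<sos>",
    sym_eos: str = "<eos>",
    sym_sop: str = "<sop>",
    sym_na: str = "<na>",
) -> Dict[str, int]:
    """Map each special symbol to its index in token_list (hypothetical helper)."""
    symbols = {
        "blank_id": sym_blank,
        "sos": sym_sos,
        "eos": sym_eos,
        "sop": sym_sop,
        "na": sym_na,
    }
    # Fail early with a readable message if a symbol is absent.
    missing = [sym for sym in symbols.values() if sym not in token_list]
    if missing:
        raise ValueError(f"token_list is missing special symbols: {missing}")
    return {name: token_list.index(sym) for name, sym in symbols.items()}


# Illustrative vocabulary; the ordering is arbitrary.
tokens = ["<blank>", "<unk>", "<na>", "hello", "world", "<sop>", "<sos>", "<eos>"]
ids = lookup_special_tokens(tokens)
print(ids)
```

A token list that lacks one of the symbols (for example "<sop>") makes the lookup fail, so the vocabulary should be checked before the model is built.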
Parameters:
- vocab_size (int) – Size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens used in the model.
- frontend (Optional[AbsFrontend]) – Frontend for feature extraction; pass None to omit it.
- specaug (Optional[AbsSpecAug]) – Data augmentation component; pass None to omit it.
- normalize (Optional[AbsNormalize]) – Normalization component; pass None to omit it.
- preencoder (Optional[AbsPreEncoder]) – Pre-encoder component; pass None to omit it.
- encoder (AbsEncoder) – Encoder component.
- postencoder (Optional[AbsPostEncoder]) – Post-encoder component; pass None to omit it.
- decoder (Optional[AbsDecoder]) – Decoder component; may be None.
- ctc (CTC) – CTC loss function.
- ctc_weight (float) – Weight for CTC loss (default: 0.5).
- interctc_weight (float) – Weight for intermediate CTC loss (default: 0.0).
- ignore_id (int) – Padding index (default: -1).
- lsm_weight (float) – Label smoothing weight (default: 0.0).
- length_normalized_loss (bool) – Flag for length normalization in loss (default: False).
- report_cer (bool) – Flag to report Character Error Rate (default: True).
- report_wer (bool) – Flag to report Word Error Rate (default: True).
- sym_space (str) – Symbol for space (default: “<space>”).
- sym_blank (str) – Symbol for blank (default: “<blank>”).
- sym_sos (str) – Symbol for start-of-sequence (default: “<sos>”).
- sym_eos (str) – Symbol for end-of-sequence (default: “<eos>”).
- sym_sop (str) – Symbol for start-of-previous (default: “<sop>”).
- sym_na (str) – Symbol for not available (default: “<na>”).
- extract_feats_in_collect_stats (bool) – Flag to extract features during statistics collection (default: True).
Raises: AssertionError – If ctc_weight or interctc_weight is outside the range [0.0, 1.0].
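How ctc_weight and interctc_weight enter the combined loss can be sketched with plain floats. The interpolation below is the usual hybrid CTC/attention formulation and is an assumption about this class, not a copy of its code; `combine_losses` is a hypothetical helper.

```python
from typing import Optional


def combine_losses(
    loss_att: float,
    loss_ctc: float,
    loss_interctc: Optional[float] = None,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
) -> float:
    """Interpolate attention, CTC, and intermediate CTC losses (illustrative)."""
    # Mirrors the documented range check on the weights.
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight <= 1.0, interctc_weight

    # Assumption: intermediate CTC is mixed into the CTC term first.
    if loss_interctc is not None and interctc_weight > 0.0:
        loss_ctc = (1.0 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc

    # ctc_weight = 0.0 gives attention only, 1.0 gives CTC only.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att


print(combine_losses(loss_att=2.0, loss_ctc=4.0))  # 3.0
print(combine_losses(2.0, 4.0, loss_interctc=8.0, interctc_weight=0.25))  # 3.5
```

With the default ctc_weight of 0.5 both branches contribute equally; moving the weight toward 1.0 shifts training toward the CTC objective.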
########### Examples
>>> # my_encoder, my_decoder, and my_ctc are placeholders for pre-built components
>>> token_list = ['<blank>', '<sos>', '<eos>', '<sop>', '<space>', '<na>']
>>> model = ESPnetS2TModel(
... vocab_size=len(token_list),
... token_list=token_list,
... frontend=None,
... specaug=None,
... normalize=None,
... preencoder=None,
... encoder=my_encoder,
... postencoder=None,
... decoder=my_decoder,
... ctc=my_ctc,
... ctc_weight=0.5,
... interctc_weight=0.1,
... ignore_id=-1,
... lsm_weight=0.1,
... length_normalized_loss=True,
... report_cer=True,
... report_wer=True
... )
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, text_prev: Tensor, text_prev_lengths: Tensor, text_ctc: Tensor, text_ctc_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Extract features from the input speech tensor.
This method runs feature extraction on the input speech tensor and returns the extracted features along with their lengths. It is intended for the statistics-collection stage (see extract_feats_in_collect_stats). The text arguments are part of the shared batch signature; feature extraction itself operates on speech and speech_lengths.
- Parameters:
- speech (torch.Tensor) – A tensor of shape (Batch, Length, …) representing the input speech signals.
- speech_lengths (torch.Tensor) – A tensor of shape (Batch,) indicating the lengths of each input speech signal.
- text (torch.Tensor) – A tensor of shape (Batch, Length) containing the target text sequences.
- text_lengths (torch.Tensor) – A tensor of shape (Batch,) indicating the lengths of each target text sequence.
- text_prev (torch.Tensor) – A tensor of shape (Batch, Length) containing the previous text sequences.
- text_prev_lengths (torch.Tensor) – A tensor of shape (Batch,) indicating the lengths of each previous text sequence.
- text_ctc (torch.Tensor) – A tensor of shape (Batch, Length) representing the CTC target text sequences.
- text_ctc_lengths (torch.Tensor) – A tensor of shape (Batch,) indicating the lengths of each CTC target text sequence.
- **kwargs – Additional keyword arguments that may be needed for other processing.
- Returns: A dictionary containing:
  - "feats": A tensor of extracted features of shape (Batch, NFrames, Dim).
  - "feats_lengths": A tensor of lengths for the extracted features of shape (Batch,).
- Return type: Dict[str, torch.Tensor]
########### Examples
>>> model = ESPnetS2TModel(...)
>>> speech_tensor = torch.randn(32, 16000) # Example input tensor for 32 signals
>>> speech_lengths = torch.tensor([16000] * 32) # All signals have the same length
>>> text_tensor = torch.randint(0, 100, (32, 20)) # Example target text tensor
>>> text_lengths = torch.tensor([20] * 32) # All texts have the same length
>>> features = model.collect_feats(speech_tensor, speech_lengths, text_tensor,
... text_lengths, text_tensor, text_lengths,
... text_tensor, text_lengths)
>>> print(features["feats"].shape) # Should output the shape of extracted features
####### NOTE The method assumes that the frontend is set up properly to handle the feature extraction from the raw speech input.
- Raises: AssertionError – If the input dimensions do not match or if there are issues with the speech lengths.
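The (Batch, Length) tensors and their (Batch,) length vectors described above pair a padded batch with the true length of each sequence. A minimal sketch using plain lists follows; `pad_batch` is a hypothetical helper, and padding with -1 assumes the default ignore_id.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_value: int = -1
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad variable-length sequences and return their true lengths."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths, default=0)
    # Every row is extended to max_len with pad_value (assumed to be ignore_id).
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


text, text_lengths = pad_batch([[5, 7, 9], [4], [8, 2]])
print(text)          # [[5, 7, 9], [4, -1, -1], [8, 2, -1]]
print(text_lengths)  # [3, 1, 2]
```

The batch dimension of every padded tensor and every length vector must agree, which is the kind of mismatch the assertion above guards against.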
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Processes the input speech through the frontend and encoder.
This method is primarily responsible for extracting features from the raw speech input and then passing those features through the encoder to produce encoded outputs. This function is also used during inference in s2t_inference.py.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the input speech waveforms.
- speech_lengths – A tensor of shape (Batch,) indicating the lengths of each input sequence in the batch.
- Returns: A tuple containing:
  - encoder_out: A tensor of shape (Batch, Length2, Dim2) representing the output of the encoder.
  - encoder_out_lens: A tensor of shape (Batch,) representing the lengths of the encoder outputs.
- Return type: Tuple[torch.Tensor, torch.Tensor]
####### NOTE This method incorporates optional data augmentation, normalization, and pre-encoding steps, depending on the model configuration.
########### Examples
>>> model = ESPnetS2TModel(...)
>>> speech = torch.randn(2, 16000) # Example batch of 2 audio signals
>>> speech_lengths = torch.tensor([16000, 15000]) # Lengths of each audio
>>> encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
>>> print(encoder_out.shape) # Output shape will depend on encoder configuration
>>> print(encoder_out_lens) # Lengths of encoder outputs
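The optional stages mentioned in the note can be pictured as a chain in which any stage set to None is skipped. The sketch below uses toy list-based stand-ins rather than real ESPnet components. The stage order (frontend, specaug during training only, normalize, preencoder, encoder, postencoder) is an assumption based on the constructor's argument order, and `run_stages` and the `toy_*` functions are hypothetical.

```python
from typing import Callable, List, Optional, Sequence

Features = List[float]
Stage = Callable[[Features], Features]


def run_stages(x: Features, stages: Sequence[Optional[Stage]]) -> Features:
    """Apply each configured stage in order, skipping stages set to None."""
    for stage in stages:
        if stage is not None:
            x = stage(x)
    return x


# Toy stand-ins; real components operate on batched tensors with lengths.
def toy_frontend(x: Features) -> Features:
    return [abs(v) for v in x]


def toy_specaug(x: Features) -> Features:
    return [0.0] + x[1:]  # mask the first frame


def toy_normalize(x: Features) -> Features:
    mean = sum(x) / len(x)
    return [v - mean for v in x]


def toy_encoder(x: Features) -> Features:
    return x[::2]  # subsample by a factor of 2


def encode(x: Features, training: bool) -> Features:
    stages = [
        toy_frontend,
        toy_specaug if training else None,  # augmentation only while training
        toy_normalize,
        None,  # no preencoder configured
        toy_encoder,
        None,  # no postencoder configured
    ]
    return run_stages(x, stages)


print(encode([1.0, -3.0, 2.0, -2.0], training=False))  # [-1.0, 0.0]
```

Because the toy encoder subsamples, the output is shorter than the input, which is why encode returns encoder_out_lens alongside encoder_out.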
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, text_prev: Tensor, text_prev_lengths: Tensor, text_ctc: Tensor, text_ctc_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Process input through the model’s components and compute the loss.
This method orchestrates the flow of input data through the frontend, encoder, and decoder, calculating the loss for both CTC and attention-based branches as necessary. It handles different types of input, computes relevant statistics, and returns the final loss along with statistics.
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – Lengths of the speech inputs of shape (Batch,).
- text (torch.Tensor) – Target text tensor of shape (Batch, Length).
- text_lengths (torch.Tensor) – Lengths of the target text sequences of shape (Batch,).
- text_prev (torch.Tensor) – Previous text sequences (preceding context) of shape (Batch, Length).
- text_prev_lengths (torch.Tensor) – Lengths of the previous text inputs of shape (Batch,).
- text_ctc (torch.Tensor) – CTC-targeted text tensor of shape (Batch, Length).
- text_ctc_lengths (torch.Tensor) – Lengths of the CTC-targeted text inputs of shape (Batch,).
- kwargs – Additional keyword arguments, expected to include “utt_id”.
- Returns: A tuple containing:
  - loss (torch.Tensor): The computed loss for the current batch.
  - stats (Dict[str, torch.Tensor]): A dictionary containing various statistics:
    - loss_ctc: CTC loss.
    - cer_ctc: Character Error Rate for CTC.
    - loss_att: Attention loss.
    - acc: Accuracy for the attention mechanism.
    - cer: Character Error Rate for the attention mechanism.
    - wer: Word Error Rate for the attention mechanism.
    - loss: Total computed loss.
  - weight (torch.Tensor): The batch size, used for loss normalization.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the dimensions of the input tensors do not match.
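The cer and wer statistics are error rates based on edit distance. The sketch below shows the textbook definition (Levenshtein distance divided by reference length) in plain Python; it is not the ErrorCalculator implementation, which may tokenize and filter symbols differently, and both function names are hypothetical.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences (dynamic programming)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(ref: Sequence[str], hyp: Sequence[str]) -> float:
    """Edit distance normalized by the reference length."""
    if len(ref) == 0:
        raise ValueError("reference must not be empty")
    return edit_distance(ref, hyp) / len(ref)


reference, hypothesis = "hello world", "hallo word"
cer = error_rate(list(reference), list(hypothesis))      # character level
wer = error_rate(reference.split(), hypothesis.split())  # word level
print(f"CER={cer:.3f} WER={wer:.3f}")  # CER=0.182 WER=1.000
```

Here two character edits out of eleven reference characters give a low CER, while both words being wrong gives a WER of 1.0, which is why the two metrics are reported separately.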
########### Examples
>>> model = ESPnetS2TModel(...)
>>> speech = torch.randn(4, 16000) # 4 samples of 1 second audio
>>> speech_lengths = torch.tensor([16000, 16000, 16000, 16000])
>>> text = torch.randint(0, 100, (4, 20)) # Random text tensor
>>> text_lengths = torch.tensor([20, 20, 20, 20])
>>> text_prev = torch.randint(0, 100, (4, 20))
>>> text_prev_lengths = torch.tensor([20, 20, 20, 20])
>>> text_ctc = torch.randint(0, 100, (4, 20))
>>> text_ctc_lengths = torch.tensor([20, 20, 20, 20])
>>> loss, stats, weight = model.forward(speech, speech_lengths, text, text_lengths,
... text_prev, text_prev_lengths,
... text_ctc, text_ctc_lengths)
####### NOTE This method is typically called during the training loop, where it is essential to compute both the forward pass and the associated loss for model optimization.
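The roles of text_prev and the sop, sos, eos, and na indices in the attention branch can be illustrated with plain lists. The layout below (previous text prefixed by sop, the target prefixed by sos, and prompt positions excluded from the loss via ignore_id) is an assumption about how a prompt-conditioned decoder is typically fed, not a statement of this class's exact internals; `build_decoder_io` is a hypothetical helper.

```python
from typing import List, Tuple


def build_decoder_io(
    text: List[int],
    text_prev: List[int],
    sos: int,
    eos: int,
    sop: int,
    na: int,
    ignore_id: int = -1,
) -> Tuple[List[int], List[int]]:
    """Build decoder input and target sequences for one utterance (illustrative)."""
    if na in text_prev:
        # No previous context available: plain <sos> text / text <eos>.
        ys_in = [sos] + text
        ys_out = text + [eos]
    else:
        # Previous context acts as a prompt; its positions are masked in the target.
        ys_in = [sop] + text_prev + [sos] + text
        ys_out = [ignore_id] * (len(text_prev) + 1) + text + [eos]
    return ys_in, ys_out


# Hypothetical token ids: sos=1, eos=2, sop=3, na=4.
with_prev = build_decoder_io([10, 11], [20, 21, 22], sos=1, eos=2, sop=3, na=4)
without_prev = build_decoder_io([10, 11], [4], sos=1, eos=2, sop=3, na=4)
print(with_prev)     # ([3, 20, 21, 22, 1, 10, 11], [-1, -1, -1, -1, 10, 11, 2])
print(without_prev)  # ([1, 10, 11], [10, 11, 2])
```

In both cases the input and target have the same length, so each decoder position has exactly one target (or an ignored slot).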