espnet2.s2t.espnet_ctc_model.ESPnetS2TCTCModel
class espnet2.s2t.espnet_ctc_model.ESPnetS2TCTCModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, encoder: AbsEncoder, prompt_encoder: AbsEncoder, ctc: CTC, interctc_weight: float = 0.0, ignore_id: int = -1, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_sos: str = '<sos>', sym_eos: str = '<eos>', sym_sop: str = '<sop>', sym_na: str = '<na>', extract_feats_in_collect_stats: bool = True, ctc_asr_only: List[bool] = [False])
Bases: AbsESPnetModel
An end-to-end speech-to-text model using Connectionist Temporal Classification (CTC).
This model integrates several components: a frontend for feature extraction, an encoder for processing the extracted features, and a CTC layer for loss computation. It supports intermediate CTC losses and SpecAugment-based data augmentation.
blank_id
The index of the blank token in the token list.
- Type: int
sos
The index of the start-of-sequence token.
- Type: int
eos
The index of the end-of-sequence token.
- Type: int
sop
The index of the start-of-previous token.
- Type: int
na
The index of the not-available token.
- Type: int
vocab_size
The size of the vocabulary.
- Type: int
ignore_id
The ID to ignore during loss computation.
- Type: int
interctc_weight
Weight for intermediate CTC losses.
- Type: float
token_list
The list of tokens used in the model.
- Type: List[str]
ctc_asr_only
List indicating if CTC is only used for ASR.
- Type: List[bool]
frontend
The frontend for feature extraction.
- Type: Optional[AbsFrontend]
specaug
The SpecAugment data augmentation layer.
- Type: Optional[AbsSpecAug]
normalize
The normalization layer.
- Type: Optional[AbsNormalize]
encoder
The encoder for processing input features.
- Type: AbsEncoder
prompt_encoder
The prompt encoder for additional context.
- Type: AbsEncoder
embed
Embedding layer for the prompt tokens.
- Type: torch.nn.Embedding
pos_enc
Positional encoding layer for the embeddings.
- Type: PositionalEncoding
error_calculator
Calculator for CER/WER error metrics.
- Type: Optional[ErrorCalculator]
ctc
The CTC layer for loss computation.
- Type: CTC
extract_feats_in_collect_stats
Flag to extract features during statistics collection.
- Type: bool
is_encoder_whisper
Flag indicating if the encoder is a Whisper model.
- Type: bool
Parameters:
- vocab_size (int) – Size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens for the model.
- frontend (Optional[AbsFrontend]) – Frontend component for feature extraction.
- specaug (Optional[AbsSpecAug]) – SpecAugment component for data augmentation.
- normalize (Optional[AbsNormalize]) – Normalization component.
- encoder (AbsEncoder) – Encoder component for processing features.
- prompt_encoder (AbsEncoder) – Encoder for prompt tokens.
- ctc (CTC) – CTC layer for loss computation.
- interctc_weight (float, optional) – Weight for the intermediate CTC loss (default: 0.0); see the weighting sketch below.
- ignore_id (int, optional) – ID to ignore in loss computation (default: -1).
- report_cer (bool, optional) – Flag to report Character Error Rate (default: True).
- report_wer (bool, optional) – Flag to report Word Error Rate (default: True).
- sym_space (str, optional) – Symbol for space (default: "<space>").
- sym_blank (str, optional) – Symbol for blank (default: "<blank>").
- sym_sos (str, optional) – Symbol for start-of-sequence (default: "<sos>").
- sym_eos (str, optional) – Symbol for end-of-sequence (default: "<eos>").
- sym_sop (str, optional) – Symbol for start-of-previous (default: "<sop>").
- sym_na (str, optional) – Symbol for not available (default: "<na>").
- extract_feats_in_collect_stats (bool, optional) – Flag to extract features during statistics collection (default: True).
- ctc_asr_only (List[bool], optional) – List indicating if CTC is only for ASR (default: [False]).
Raises: AssertionError – If interctc_weight is not between 0.0 and 1.0.
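NOTE
The intermediate CTC loss is combined with the final CTC loss as a weighted sum controlled by interctc_weight. A minimal illustration of the weighting follows; the loss values are placeholders and the variable names are assumptions, not the verbatim implementation:

>>> import torch
>>> loss_ctc = torch.tensor(2.3)       # CTC loss on the final encoder output
>>> loss_interctc = torch.tensor(2.9)  # mean CTC loss over intermediate layers
>>> interctc_weight = 0.5
>>> (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
tensor(2.6000)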
Examples
Creating an instance of the ESPnetS2TCTCModel:

>>> model = ESPnetS2TCTCModel(
...     vocab_size=1000,
...     token_list=["<space>", "<blank>", "<sos>", "<eos>", "<sop>", "<na>"],
...     frontend=None, specaug=None, normalize=None,
...     encoder=my_encoder, prompt_encoder=my_prompt_encoder,
...     ctc=my_ctc, interctc_weight=0.5,
... )

Forward pass through the model:

>>> loss, stats, weight = model(
...     speech=my_speech_tensor, speech_lengths=my_speech_lengths,
...     text=my_text_tensor, text_lengths=my_text_lengths,
...     text_prev=my_text_prev_tensor, text_prev_lengths=my_text_prev_lengths,
...     text_ctc=my_text_ctc_tensor, text_ctc_lengths=my_text_ctc_lengths,
...     prefix=my_prefix_tensor, prefix_lengths=my_prefix_lengths,
... )
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, text_prev: Tensor, text_prev_lengths: Tensor, text_ctc: Tensor, text_ctc_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the input speech data.
This method extracts features from the given speech tensor and its corresponding lengths. It is primarily used for collecting feature statistics and preparing the input for further processing in the model.
- Parameters:
- speech (torch.Tensor) – The input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – A tensor containing the lengths of each speech input in the batch of shape (Batch,).
- text (torch.Tensor) – The input text tensor of shape (Batch, Length).
- text_lengths (torch.Tensor) – A tensor containing the lengths of each text input in the batch of shape (Batch,).
- text_prev (torch.Tensor) – The previous text tensor of shape (Batch, Length).
- text_prev_lengths (torch.Tensor) – A tensor containing the lengths of each previous text input in the batch of shape (Batch,).
- text_ctc (torch.Tensor) – The CTC text tensor of shape (Batch, Length).
- text_ctc_lengths (torch.Tensor) – A tensor containing the lengths of each CTC text input in the batch of shape (Batch,).
- **kwargs – Additional keyword arguments, if needed.
- Returns: A dictionary containing the extracted features and their corresponding lengths. The keys are:
- 'feats': The extracted features tensor.
- 'feats_lengths': The lengths of the extracted features tensor.
- Return type: Dict[str, torch.Tensor]
Examples
>>> model = ESPnetS2TCTCModel(...)
>>> speech_tensor = torch.randn(8, 16000) # 8 samples of 1 second audio
>>> speech_lengths = torch.tensor([16000] * 8)
>>> text_tensor = torch.randint(0, 100, (8, 20)) # random text tensor
>>> text_lengths = torch.tensor([20] * 8)
>>> features = model.collect_feats(speech_tensor, speech_lengths, text_tensor, text_lengths, ...)
>>> print(features['feats'].shape) # Check the shape of the extracted features
NOTE
This method internally calls the _extract_feats method to perform the feature extraction.
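For context, a typical use of collect_feats is accumulating global feature statistics (e.g., for mean/variance normalization) before training. A minimal sketch under the same assumptions as the example above (padded batch tensors already prepared); the masking logic is illustrative, not the library's implementation:

>>> feats_dict = model.collect_feats(
...     speech, speech_lengths, text, text_lengths,
...     text_prev, text_prev_lengths, text_ctc, text_ctc_lengths,
... )
>>> feats, feats_lengths = feats_dict["feats"], feats_dict["feats_lengths"]
>>> # Mask out padded frames before accumulating statistics
>>> mask = torch.arange(feats.size(1))[None, :] < feats_lengths[:, None]
>>> feat_sum = (feats * mask[:, :, None]).sum(dim=(0, 1))
>>> global_mean = feat_sum / mask.sum()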
encode(speech: Tensor, speech_lengths: Tensor, text_prev: Tensor, text_prev_lengths: Tensor, prefix: Tensor, prefix_lengths: Tensor)
Encode input speech.
This method processes the input speech tensor and previous text tensor to generate encoded representations that can be utilized for further processing, such as training or inference in a speech-to-text model. The method applies various transformations, including feature extraction, data augmentation, and normalization.
- Parameters:
- speech (torch.Tensor) – The input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – A tensor containing the lengths of each input speech sample in the batch, of shape (Batch,).
- text_prev (torch.Tensor) – The tensor representing the previous text tokens, of shape (Batch, Length).
- text_prev_lengths (torch.Tensor) – A tensor containing the lengths of each previous text sample in the batch, of shape (Batch,).
- prefix (torch.Tensor) – A tensor representing language and task tokens, of shape (Batch, Length=2).
- prefix_lengths (torch.Tensor) – A tensor containing the lengths of each prefix in the batch, of shape (Batch,).
- Returns: A tuple containing:
  - encoder_out (torch.Tensor): The encoded output from the encoder, of shape (Batch, Encoded_Length, Encoder_Output_Dim).
  - encoder_out_lens (torch.Tensor): A tensor containing the lengths of the encoded outputs for each sample in the batch, of shape (Batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples
>>> model = ESPnetS2TCTCModel(...)
>>> speech = torch.randn(4, 16000) # Example speech input
>>> speech_lengths = torch.tensor([16000, 16000, 16000, 16000])
>>> text_prev = torch.tensor([[1, 2, 3], [1, 2, -1], [1, 2, 3], [1, -1, -1]])
>>> text_prev_lengths = torch.tensor([3, 2, 3, 1])
>>> prefix = torch.tensor([[0, 1], [0, 1], [0, 1], [0, 1]]) # Example prefixes
>>> prefix_lengths = torch.tensor([2, 2, 2, 2])
>>> encoder_out, encoder_out_lens = model.encode(
... speech, speech_lengths, text_prev, text_prev_lengths, prefix, prefix_lengths
... )
>>> print(encoder_out.shape, encoder_out_lens.shape)
torch.Size([4, Encoded_Length, Encoder_Output_Dim]) torch.Size([4])
NOTE
The input tensors should be properly padded and have consistent dimensions across the batch to ensure successful processing. The method assumes that the input lengths are provided and correctly reflect the actual lengths of the speech and text inputs; see the padding sketch below.
- Raises: AssertionError – If the input lengths do not match the expected dimensions.
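As the note above mentions, text_prev is padded with ignore_id (-1 by default), matching the example. A minimal sketch of building such a padded batch with torch.nn.utils.rnn.pad_sequence (the token values are arbitrary placeholders):

>>> import torch
>>> from torch.nn.utils.rnn import pad_sequence
>>> seqs = [torch.tensor([1, 2, 3]), torch.tensor([1, 2]), torch.tensor([1])]
>>> text_prev = pad_sequence(seqs, batch_first=True, padding_value=-1)
>>> text_prev
tensor([[ 1,  2,  3],
        [ 1,  2, -1],
        [ 1, -1, -1]])
>>> text_prev_lengths = torch.tensor([len(s) for s in seqs])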
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, text_prev: Tensor, text_prev_lengths: Tensor, text_ctc: Tensor, text_ctc_lengths: Tensor, prefix: Tensor, prefix_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Perform the forward pass through the model, combining the frontend, encoder, and loss calculation.
This method processes the input speech and text data, encoding the speech features and calculating the CTC loss and related metrics. It handles the input dimensions and validates that they are consistent across the batch.
- Parameters:
- speech (torch.Tensor) – A tensor of shape (Batch, Length, …) representing the input speech signals.
- speech_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of each speech signal.
- text (torch.Tensor) – A tensor of shape (Batch, Length) containing the target text sequences.
- text_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of each target text sequence.
- text_prev (torch.Tensor) – A tensor of shape (Batch, Length) representing the previous text sequences for context.
- text_prev_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of the previous text sequences.
- text_ctc (torch.Tensor) – A tensor of shape (Batch, Length) representing the text sequences used for CTC loss calculation.
- text_ctc_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of the CTC target text sequences.
- prefix (torch.Tensor) – A tensor of shape (Batch, Length=2) containing the language and task tokens.
- prefix_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of the prefix tokens.
- kwargs (dict) – Additional keyword arguments, including "utt_id", which can be used for identification.
- Returns:
- A tensor representing the total loss calculated during the forward pass.
- A dictionary containing various statistics related to the loss and metrics.
- A tensor representing the batch size.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If any of the input tensor dimensions are inconsistent or invalid.
Examples
>>> model = ESPnetS2TCTCModel(...)
>>> loss, stats, batch_size = model.forward(
... speech=torch.randn(32, 16000),
... speech_lengths=torch.tensor([16000]*32),
... text=torch.randint(0, 100, (32, 20)),
... text_lengths=torch.tensor([20]*32),
... text_prev=torch.randint(0, 100, (32, 20)),
... text_prev_lengths=torch.tensor([20]*32),
... text_ctc=torch.randint(0, 100, (32, 20)),
... text_ctc_lengths=torch.tensor([20]*32),
... prefix=torch.tensor([[0, 1]]*32),
... prefix_lengths=torch.tensor([2]*32)
... )
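A hedged sketch of how the returned values are typically consumed in a training step (the optimizer setup is an assumption, and the available stats keys depend on configuration, so they are printed generically):

>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> optimizer.zero_grad()
>>> loss.backward()   # loss and stats from the forward call above
>>> optimizer.step()
>>> {k: float(v) for k, v in stats.items() if v is not None}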