espnet2.asr.discrete_asr_espnet_model.ESPnetDiscreteASRModel
class espnet2.asr.discrete_asr_espnet_model.ESPnetDiscreteASRModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: AbsDecoder, ctc: CTC | None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, src_vocab_size: int = 0, src_token_list: Tuple[str, ...] | List[str] = [], ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', patch_size: int = 1, extract_feats_in_collect_stats: bool = True, share_decoder_input_output_embed: bool = False, share_encoder_decoder_input_embed: bool = False)
Bases: ESPnetMTModel
ESPnetDiscreteASRModel is an encoder-decoder model for automatic speech
recognition (ASR) that operates on discrete tokens. It integrates a frontend, encoder, and decoder, and can optionally include CTC for joint training.
vocab_size
The size of the vocabulary used for decoding.
- Type: int
token_list
A list of tokens corresponding to the vocabulary.
- Type: List[str]
frontend
An optional frontend for feature extraction.
- Type: AbsFrontend
specaug
An optional data augmentation technique.
- Type: AbsSpecAug
preencoder
An optional preencoder for raw input data.
- Type: AbsPreEncoder
encoder
The encoder component of the model.
- Type: AbsEncoder
postencoder
An optional postencoder for additional processing.
- Type: AbsPostEncoder
decoder
The decoder component of the model.
- Type: AbsDecoder
ctc
An optional CTC module for training.
- Type: CTC
ctc_weight
Weight for the CTC loss in the combined loss function.
- Type: float
interctc_weight
Weight for the intermediate CTC loss.
- Type: float
ignore_id
Token ID to ignore in the loss calculation.
- Type: int
length_normalized_loss
If True, normalize the loss by length.
- Type: bool
report_bleu
If True, report BLEU score during training.
- Type: bool
sym_space
Symbol representing space in the token list.
- Type: str
sym_blank
Symbol representing blank in the token list.
- Type: str
blank_id
The ID of the blank token.
- Type: int
error_calculator
Calculates error metrics like CER and WER.
- Type: ASRErrorCalculator
Parameters:
- vocab_size (int) – Size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens.
- frontend (Optional[AbsFrontend]) – Frontend component.
- specaug (Optional[AbsSpecAug]) – SpecAugment component.
- preencoder (Optional[AbsPreEncoder]) – Preencoder component.
- encoder (AbsEncoder) – Encoder component.
- postencoder (Optional[AbsPostEncoder]) – Postencoder component.
- decoder (AbsDecoder) – Decoder component.
- ctc (Optional[CTC]) – CTC component.
- ctc_weight (float) – Weight for CTC loss (default 0.5).
- interctc_weight (float) – Weight for intermediate CTC loss (default 0.0).
- src_vocab_size (int) – Source vocabulary size (default 0).
- src_token_list (Union[Tuple[str, ...], List[str]]) – Source token list (default []).
- ignore_id (int) – ID to ignore in loss calculation (default -1).
- lsm_weight (float) – Label smoothing weight (default 0.0).
- length_normalized_loss (bool) – Normalize loss by length (default False).
- report_bleu (bool) – Report BLEU score (default True).
- sym_space (str) – Symbol for space (default “<space>”).
- sym_blank (str) – Symbol for blank (default “<blank>”).
- patch_size (int) – Patch size for model (default 1).
- extract_feats_in_collect_stats (bool) – Extract features during statistics collection (default True).
- share_decoder_input_output_embed (bool) – Share decoder input/output embedding (default False).
- share_encoder_decoder_input_embed (bool) – Share encoder/decoder input embedding (default False).
Returns: A tuple containing the total loss, a dictionary of statistics, and the batch size.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Raises: AssertionError – If the input lengths are inconsistent or invalid.
######### Examples
model = ESPnetDiscreteASRModel(
    vocab_size=5000,
    token_list=["<blank>", "<space>", "hello", "world"],
    frontend=None,
    specaug=None,
    preencoder=None,
    encoder=my_encoder,
    postencoder=None,
    decoder=my_decoder,
    ctc=my_ctc,
)
loss, stats, batch_size = model(
    text=torch.randint(0, 5000, (32, 10)),
    text_lengths=torch.randint(1, 11, (32,)),
    src_text=torch.randint(0, 5000, (32, 10)),
    src_text_lengths=torch.randint(1, 11, (32,)),
)
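The ctc_weight and interctc_weight parameters control how the attention and CTC losses are blended. Below is a minimal sketch of the usual ESPnet-style combination; combine_losses is an illustrative helper, not part of this class, and the real forward pass also tracks statistics such as loss values and error rates:

import torch

def combine_losses(loss_att, loss_ctc, loss_interctc,
                   ctc_weight=0.5, interctc_weight=0.0):
    # Blend the final-layer CTC loss with the intermediate-layer CTC loss.
    if interctc_weight > 0.0:
        loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * loss_interctc
    # Interpolate between pure attention and pure CTC training.
    if ctc_weight == 0.0:
        return loss_att
    if ctc_weight == 1.0:
        return loss_ctc
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att

loss = combine_losses(torch.tensor(2.3), torch.tensor(4.1), torch.tensor(4.5),
                      ctc_weight=0.3, interctc_weight=0.5)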
####### NOTE
This model supports optional components for various stages of processing, allowing for flexibility in architecture design.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
encode(src_text: Tensor, src_text_lengths: Tensor) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by mt_inference.py.
This method processes the input source text through a series of layers including the frontend, preencoder (if applicable), and the main encoder to produce the encoded output and its lengths.
- Parameters:
- src_text – A tensor of shape (Batch, Length, …), representing the input source text sequences.
- src_text_lengths – A tensor of shape (Batch,), containing the lengths of the source text sequences.
- Returns: A tuple containing:
  - encoder_out: A tensor of shape (Batch, Length2, Dim2), representing the encoded output from the encoder.
  - encoder_out_lens: A tensor of shape (Batch,), containing the lengths of the encoded output sequences.
- Return type: Tuple[torch.Tensor, torch.Tensor]
####### NOTE
- This method assumes that the input text has already been processed to a suitable format for encoding.
- The method can perform data augmentation if the model is in training mode and a spec augmentation instance is provided.
######### Examples
>>> model = ESPnetDiscreteASRModel(...)
>>> src_text = torch.randint(0, 5000, (2, 10))  # Example discrete token IDs
>>> src_text_lengths = torch.tensor([10, 8])  # Lengths of the inputs
>>> encoder_out, encoder_out_lens = model.encode(src_text, src_text_lengths)
>>> print(encoder_out.shape) # Expected shape: (2, Length2, Dim2)
>>> print(encoder_out_lens) # Lengths of the encoded sequences
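For orientation, the processing chain inside encode can be pictured with the following hedged sketch; encode_sketch is illustrative only, and the actual method adds shape assertions and may differ in detail:

def encode_sketch(model, src_text, src_text_lengths):
    # Each stage runs only if the corresponding component was supplied
    # at construction time.
    feats, feats_lengths = src_text, src_text_lengths
    if model.frontend is not None:
        feats, feats_lengths = model.frontend(feats, feats_lengths)
    if model.specaug is not None and model.training:
        # Data augmentation is applied only in training mode.
        feats, feats_lengths = model.specaug(feats, feats_lengths)
    if model.preencoder is not None:
        feats, feats_lengths = model.preencoder(feats, feats_lengths)
    encoder_out, encoder_out_lens, _ = model.encoder(feats, feats_lengths)
    if model.postencoder is not None:
        encoder_out, encoder_out_lens = model.postencoder(encoder_out, encoder_out_lens)
    return encoder_out, encoder_out_lens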
forward(text: Tensor, text_lengths: Tensor, src_text: Tensor, src_text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calculate loss.
This method performs the forward pass through the entire model, processing the input text through the frontend, encoder, and decoder while also calculating the associated loss.
- Parameters:
- text – Tensor of shape (Batch, Length) representing the target sequences.
- text_lengths – Tensor of shape (Batch,) containing the lengths of each target sequence.
- src_text – Tensor of shape (Batch, Length) representing the source sequences.
- src_text_lengths – Tensor of shape (Batch,) containing the lengths of each source sequence.
- kwargs – Additional keyword arguments, where “utt_id” may be included.
- Returns: A tuple containing:
  - loss: A tensor representing the computed loss.
  - stats: A dictionary of various statistics computed during the forward pass.
  - weight: A tensor representing the batch size.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the dimensions of input tensors do not match.
######### Examples
>>> model = ESPnetDiscreteASRModel(...)
>>> text = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> text_lengths = torch.tensor([3, 3])
>>> src_text = torch.tensor([[1, 2], [3, 4]])
>>> src_text_lengths = torch.tensor([2, 2])
>>> loss, stats, weight = model.forward(text, text_lengths, src_text,
... src_text_lengths)
####### NOTE
Ensure that the input tensors are properly padded and that the lengths are accurately specified.
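One way to build such inputs is with torch.nn.utils.rnn.pad_sequence. In this sketch the token IDs are made up, targets are padded with the default ignore_id of -1, and the source padding value of 0 is an assumption:

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical variable-length token ID sequences.
targets = [torch.tensor([5, 42, 7]), torch.tensor([9, 3])]
sources = [torch.tensor([11, 8, 2, 4]), torch.tensor([6, 1, 5])]

text = pad_sequence(targets, batch_first=True, padding_value=-1)  # pad with ignore_id
text_lengths = torch.tensor([t.numel() for t in targets])
src_text = pad_sequence(sources, batch_first=True, padding_value=0)
src_text_lengths = torch.tensor([s.numel() for s in sources])

# Assuming model is an ESPnetDiscreteASRModel instance as above:
loss, stats, weight = model(text, text_lengths, src_text, src_text_lengths)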