espnet2.asr.pit_espnet_model.ESPnetASRModel
class espnet2.asr.pit_espnet_model.ESPnetASRModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: AbsDecoder | None, ctc: CTC, joint_network: Module | None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_sos: str = '<sos/eos>', sym_eos: str = '<sos/eos>', extract_feats_in_collect_stats: bool = True, lang_token_id: int = -1, num_inf: int = 1, num_ref: int = 1)
Bases: espnet2.asr.espnet_model.ESPnetASRModel
ESPnetASRModel is a hybrid CTC-attention encoder-decoder model for automatic speech recognition (ASR). This variant extends the single-speaker model with permutation invariant training (PIT) of the CTC loss, so it can be trained on multi-speaker speech where the assignment between the model's outputs and the reference transcriptions is not known in advance.
num_inf
The number of inferences (outputs) from the model.
- Type: int
num_ref
The number of references (ground truth sequences).
- Type: int
pit_ctc
A wrapper for calculating the Permutation Invariant Training (PIT) loss with CTC.
- Type: PITLossWrapper
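As a rough illustration of what a PIT loss computes, the sketch below takes a matrix of pairwise losses for one utterance (row i: model output i, column j: reference j) and returns the lowest average loss over all one-to-one assignments, together with that assignment. `pit_min_loss` is a hypothetical pure-Python helper, not the PITLossWrapper interface; averaging over pairs and the brute-force search over permutations are assumptions of this sketch.

```python
from itertools import permutations
from typing import Sequence, Tuple


def pit_min_loss(loss_matrix: Sequence[Sequence[float]]) -> Tuple[float, Tuple[int, ...]]:
    """Return the lowest mean loss over all output-to-reference assignments.

    loss_matrix[i][j] is the loss of pairing model output i with reference j
    (for example, a per-pair CTC loss). Assumes a square matrix, i.e.
    num_inf == num_ref. Hypothetical helper for illustration only.
    """
    n = len(loss_matrix)
    best_loss = float("inf")
    best_perm: Tuple[int, ...] = tuple(range(n))
    # Try every assignment: output i is paired with reference perm[i].
    for perm in permutations(range(n)):
        mean_loss = sum(loss_matrix[i][perm[i]] for i in range(n)) / n
        if mean_loss < best_loss:
            best_loss, best_perm = mean_loss, perm
    return best_loss, best_perm
```

For example, `pit_min_loss([[5.0, 1.0], [2.0, 6.0]])` returns `(1.5, (1, 0))`: output 0 is paired with reference 1 and output 1 with reference 0. Enumerating permutations costs O(n!), which stays cheap for the small speaker counts typical of multi-speaker ASR.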
Parameters:
- vocab_size (int) – The size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens in the vocabulary.
- frontend (Optional[AbsFrontend]) – Frontend processing component.
- specaug (Optional[AbsSpecAug]) – SpecAugment component for data augmentation.
- normalize (Optional[AbsNormalize]) – Normalization layer.
- preencoder (Optional[AbsPreEncoder]) – Pre-encoder component.
- encoder (AbsEncoder) – The encoder component of the model.
- postencoder (Optional[AbsPostEncoder]) – Post-encoder component.
- decoder (Optional[AbsDecoder]) – The decoder component of the model.
- ctc (CTC) – The CTC component for loss calculation.
- joint_network (Optional[torch.nn.Module]) – Joint network component.
- ctc_weight (float) – Weight for the CTC loss in the combined loss. Defaults to 0.5.
- interctc_weight (float) – Weight for the inter-CTC loss. Defaults to 0.0.
- ignore_id (int) – The token ID to ignore during loss calculation. Defaults to -1.
- lsm_weight (float) – Label smoothing weight. Defaults to 0.0.
- length_normalized_loss (bool) – Whether to use length-normalized loss. Defaults to False.
- report_cer (bool) – Whether to report Character Error Rate (CER). Defaults to True.
- report_wer (bool) – Whether to report Word Error Rate (WER). Defaults to True.
- sym_space (str) – Symbol for space token. Defaults to “<space>”.
- sym_blank (str) – Symbol for blank token in CTC. Defaults to “<blank>”.
- sym_sos (str) – Symbol for start of sequence. Defaults to “<sos/eos>”.
- sym_eos (str) – Symbol for end of sequence. Defaults to “<sos/eos>”.
- extract_feats_in_collect_stats (bool) – Whether to extract features during the statistics-collection (collect-stats) stage. Defaults to True.
- lang_token_id (int) – Language token ID. Defaults to -1.
- num_inf (int) – Number of inferences (outputs) from the model. Defaults to 1.
- num_ref (int) – Number of references (ground truth sequences). Defaults to 1.
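The ctc_weight parameter interpolates between the CTC and attention-decoder losses. The sketch below shows the usual hybrid CTC-attention interpolation on plain floats; `hybrid_loss` is a hypothetical helper, and the special case for ctc_weight = 1.0 (pure CTC, decoder loss ignored) reflects the base ESPnet model's behavior, assumed here to carry over to this subclass.

```python
def hybrid_loss(loss_ctc: float, loss_att: float, ctc_weight: float = 0.5) -> float:
    """Interpolate CTC and attention losses (hypothetical illustrative helper)."""
    if ctc_weight == 1.0:
        # Pure CTC: the attention-decoder loss does not contribute.
        return loss_ctc
    # Weighted sum: ctc_weight for CTC, the remainder for the attention decoder.
    return ctc_weight * loss_ctc + (1.0 - ctc_weight) * loss_att
```

With the default ctc_weight=0.5, `hybrid_loss(2.0, 4.0)` gives 3.0.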
Raises:
- AssertionError – If ctc_weight is not in the range (0.0, 1.0] or if interctc_weight is not equal to 0.0.
- AssertionError – If num_inf is not equal to num_ref.
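The constraints listed under Raises can be restated as plain checks. `check_pit_asr_args` is a hypothetical helper that mirrors the documented conditions; it is not code taken from the class itself.

```python
def check_pit_asr_args(ctc_weight: float, interctc_weight: float, num_inf: int, num_ref: int) -> None:
    """Mirror the documented constructor assertions (hypothetical helper)."""
    # CTC must contribute to the loss: 0.0 is rejected, 1.0 (pure CTC) is allowed.
    assert 0.0 < ctc_weight <= 1.0, ctc_weight
    # Intermediate CTC must stay disabled for this model.
    assert interctc_weight == 0.0, interctc_weight
    # PIT pairs each model output with exactly one reference.
    assert num_inf == num_ref, (num_inf, num_ref)
```

In particular, ctc_weight=0.0 (a pure attention model) is rejected, because the PIT assignment is derived from the CTC branch.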
Examples:
>>> # Create an instance; my_encoder, my_decoder and my_ctc are placeholders
>>> # for components built beforehand. Unlisted arguments keep their defaults.
>>> token_list = ["<blank>", "<space>", "<sos/eos>"] + list("abcdefghijklmnopqrstuvwxyz")
>>> model = ESPnetASRModel(
...     vocab_size=len(token_list), token_list=token_list,  # sizes must agree
...     frontend=None, specaug=None, normalize=None, preencoder=None,
...     encoder=my_encoder, postencoder=None, decoder=my_decoder,
...     ctc=my_ctc, joint_network=None, ctc_weight=0.5, num_inf=1, num_ref=1,
... )
>>> # Forward pass through the model (the my_* tensors are placeholders too)
>>> loss, stats, weight = model.forward(
...     speech=my_speech_tensor, speech_lengths=my_speech_lengths_tensor,
...     text=my_text_tensor, text_lengths=my_text_lengths_tensor,
...     utt_id="example_id",
... )
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Forward pass of the ESPnetASRModel: runs the input speech through the frontend, encoder, and decoder, and computes the loss against the provided references. Multiple references (one per speaker) are supported; the CTC loss is computed with permutation invariant training, so each model output is scored against the reference assignment that yields the lowest loss.
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – Lengths of the input speech tensor of shape (Batch,).
- text (torch.Tensor) – Target text tensor of shape (Batch, Length).
- text_lengths (torch.Tensor) – Lengths of the target text tensor of shape (Batch,).
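- **kwargs – Additional keyword arguments. Expected to include “utt_id”. The first speaker's reference is passed as text and text_lengths; references for further speakers are passed as “text_spk2”, “text_spk2_lengths”, “text_spk3”, “text_spk3_lengths”, and so on, up to num_ref.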
- Returns:
- loss: Computed loss value for the batch.
- stats: Dictionary containing various statistics from the model, such as loss values and accuracy metrics.
- weight: The batch size, returned as the weight used when aggregating the loss and statistics across batches.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If text_lengths is not one-dimensional or the batch sizes of the input tensors do not match.
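The naming convention for extra references can be sketched as follows: text carries the first speaker, and speaker n (for n = 2, …, num_ref) is looked up under text_spk{n} and text_spk{n}_lengths. `collect_references` is a hypothetical helper that only illustrates this lookup on plain Python values; returning None for a missing length entry is an assumption of the sketch.

```python
from typing import Any, List, Optional, Tuple


def collect_references(
    text: Any, text_lengths: Any, num_ref: int, **kwargs: Any
) -> Tuple[List[Any], List[Optional[Any]]]:
    """Gather per-speaker references by keyword name (hypothetical helper).

    ``text`` holds speaker 1; speakers 2..num_ref come from ``text_spk{n}``.
    Unrelated keys such as ``utt_id`` are ignored.
    """
    # A missing text_spk{n} raises KeyError, since that reference is required.
    refs = [text] + [kwargs[f"text_spk{n}"] for n in range(2, num_ref + 1)]
    # Missing length entries come back as None (assumption of this sketch).
    ref_lengths = [text_lengths] + [
        kwargs.get(f"text_spk{n}_lengths") for n in range(2, num_ref + 1)
    ]
    return refs, ref_lengths
```

With num_ref=1 only text and text_lengths are used, which matches the single-reference call in the class-level example.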
Examples:
>>> import torch
>>> model = ESPnetASRModel(...)  # built with num_inf=2, num_ref=2 (two speakers)
>>> speech = torch.randn(4, 100, 80)  # Batch of 4, Length 100, 80 features
>>> speech_lengths = torch.tensor([100, 90, 80, 70])
>>> text = torch.randint(1, 30, (4, 20))  # Batch of 4, Length 20; ID 0 is <blank>
>>> text_lengths = torch.tensor([20, 18, 15, 12])
>>> for i, n in enumerate(text_lengths.tolist()):
...     text[i, n:] = -1  # pad with ignore_id
>>> loss, stats, weight = model.forward(speech, speech_lengths, text, text_lengths,
...                                     utt_id='utt1', text_spk2=text,
...                                     text_spk2_lengths=text_lengths)
NOTE
Ensure that the input tensors are properly padded and have the correct dimensions: padded positions in text (and in any text_spk* references) should be filled with ignore_id (-1 by default), and text and text_lengths must have the same batch size as speech and speech_lengths.
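A minimal way to build text and text_lengths from variable-length token-ID lists is to right-pad each sequence and record the true lengths. `pad_token_batch` is a hypothetical helper; its default padding value of -1 matches the default ignore_id, and passing its outputs to torch.tensor would yield the (Batch, Length) and (Batch,) tensors that forward expects.

```python
from typing import List, Sequence, Tuple


def pad_token_batch(
    seqs: Sequence[Sequence[int]], pad_value: int = -1
) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token-ID sequences to a common length (hypothetical helper).

    Returns the padded rows (the role of ``text``) and the true lengths
    (the role of ``text_lengths``).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    # Fill everything past each sequence's true length with pad_value.
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths
```

For instance, `pad_token_batch([[3, 4, 5], [6]])` returns `([[3, 4, 5], [6, -1, -1]], [3, 1])`.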