espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel
class espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, encoder: Encoder, decoder: AbsDecoder, joint_network: JointNetwork, transducer_weight: float = 1.0, use_k2_pruned_loss: bool = False, k2_pruned_loss_args: Dict = {}, warmup_steps: int = 25000, validation_nstep: int = 2, fastemit_lambda: float = 0.0, auxiliary_ctc_weight: float = 0.0, auxiliary_ctc_dropout_rate: float = 0.0, auxiliary_lm_loss_weight: float = 0.0, auxiliary_lm_loss_smoothing: float = 0.05, ignore_id: int = -1, sym_space: str = '<space>', sym_blank: str = '<blank>', report_cer: bool = False, report_wer: bool = False, extract_feats_in_collect_stats: bool = True)
Bases: AbsESPnetModel
ESPnet2 ASR Transducer model.
This class defines the ESPnet2 ASR Transducer model, an end-to-end automatic speech recognition (ASR) model based on the transducer architecture. It includes components for encoding speech, decoding sequences, and computing losses. The model supports auxiliary losses and can optionally use the k2 pruned Transducer loss for faster, more memory-efficient loss computation.
vocab_size
Size of complete vocabulary (including SOS/EOS and blank).
- Type: int
token_list
List of tokens in vocabulary (excluding reserved tokens).
- Type: List[str]
frontend
Frontend module for feature extraction.
- Type: AbsFrontend
specaug
SpecAugment module for data augmentation.
- Type: AbsSpecAug
normalize
Normalization module for feature processing.
- Type: AbsNormalize
encoder
Encoder module to process input features.
- Type: Encoder
decoder
Decoder module to generate output sequences.
- Type: AbsDecoder
joint_network
Joint Network module combining encoder and decoder outputs.
- Type: JointNetwork
transducer_weight
Weight of the Transducer loss.
- Type: float
use_k2_pruned_loss
Flag to indicate if K2 pruned loss should be used.
- Type: bool
k2_pruned_loss_args
Arguments for K2 pruned loss computation.
- Type: Dict
warmup_steps
Number of warmup steps for loss scaling.
- Type: int
validation_nstep
Max number of symbol expansions for CER/WER reporting.
- Type: int
fastemit_lambda
Lambda value for FastEmit mechanism.
- Type: float
auxiliary_ctc_weight
Weight for auxiliary CTC loss.
- Type: float
auxiliary_ctc_dropout_rate
Dropout rate for CTC inputs.
- Type: float
auxiliary_lm_loss_weight
Weight for auxiliary LM loss.
- Type: float
auxiliary_lm_loss_smoothing
Smoothing rate for LM loss.
- Type: float
ignore_id
ID used for padding.
- Type: int
sym_space
Space symbol representation.
- Type: str
sym_blank
Blank symbol representation.
- Type: str
report_cer
Flag to report Character Error Rate during validation.
- Type: bool
report_wer
Flag to report Word Error Rate during validation.
- Type: bool
extract_feats_in_collect_stats
Flag to control feature extraction during stats collection.
- Type: bool
Parameters:
- vocab_size – Size of complete vocabulary (including SOS/EOS and blank).
- token_list – List of tokens in vocabulary (excluding reserved tokens).
- frontend – Frontend module for feature extraction.
- specaug – SpecAugment module for data augmentation.
- normalize – Normalization module for feature processing.
- encoder – Encoder module to process input features.
- decoder – Decoder module to generate output sequences.
- joint_network – Joint Network module combining encoder and decoder outputs.
- transducer_weight – Weight of the Transducer loss.
- use_k2_pruned_loss – Whether to use K2 pruned Transducer loss.
- k2_pruned_loss_args – Arguments for K2 pruned loss computation.
- warmup_steps – Number of warmup steps for loss scaling.
- validation_nstep – Max number of symbol expansions for CER/WER reporting.
- fastemit_lambda – Lambda value for FastEmit mechanism.
- auxiliary_ctc_weight – Weight for auxiliary CTC loss.
- auxiliary_ctc_dropout_rate – Dropout rate for CTC inputs.
- auxiliary_lm_loss_weight – Weight for auxiliary LM loss.
- auxiliary_lm_loss_smoothing – Smoothing rate for LM loss.
- ignore_id – ID used for padding.
- sym_space – Space symbol representation.
- sym_blank – Blank symbol representation.
- report_cer – Flag to report Character Error Rate during validation.
- report_wer – Flag to report Word Error Rate during validation.
- extract_feats_in_collect_stats – Flag to control feature extraction during stats collection.
Examples
>>> model = ESPnetASRTransducerModel(
... vocab_size=100,
... token_list=["a", "b", "c"],
... frontend=None,
... specaug=None,
... normalize=None,
... encoder=Encoder(...),
... decoder=AbsDecoder(...),  # in practice, a concrete AbsDecoder subclass
... joint_network=JointNetwork(...),
... transducer_weight=1.0,
... use_k2_pruned_loss=False,
... k2_pruned_loss_args={},
... warmup_steps=25000,
... validation_nstep=2,
... fastemit_lambda=0.0,
... auxiliary_ctc_weight=0.0,
... auxiliary_ctc_dropout_rate=0.0,
... auxiliary_lm_loss_weight=0.0,
... auxiliary_lm_loss_smoothing=0.05,
... ignore_id=-1,
... sym_space='<space>',
... sym_blank='<blank>',
... report_cer=False,
... report_wer=False,
... extract_feats_in_collect_stats=True,
... )
NOTE
Ensure that all required modules (e.g., AbsFrontend, Encoder, etc.) are properly implemented and compatible with the expected input shapes and data types.
Construct an ESPnetASRTransducerModel object.
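The final training loss combines the Transducer loss with the optional auxiliary losses according to the weights documented above. As an illustrative sketch (variable names here are hypothetical; the actual computation happens inside forward()):
>>> # Hypothetical variable names; illustrates how the documented weights combine.
>>> loss = (
...     transducer_weight * loss_transducer
...     + auxiliary_ctc_weight * loss_ctc
...     + auxiliary_lm_loss_weight * loss_lm
... )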
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect feature sequences and feature length sequences.
This method processes the input speech data to extract features and their corresponding lengths. It can either extract actual features using the model’s frontend or return dummy statistics if configured to do so.
- Parameters:
- speech – Speech sequences. Shape: (B, S) where B is the batch size and S is the sequence length.
- speech_lengths – Speech sequences lengths. Shape: (B,).
- text – Label ID sequences. Shape: (B, L).
- text_lengths – Label ID sequences lengths. Shape: (B,).
- kwargs – Additional keyword arguments, can contain “utts_id”.
- Returns: A dictionary containing:
  - "feats": Feature sequences. Shape: (B, T, D_feats).
  - "feats_lengths": Feature sequence lengths. Shape: (B,).
- Return type: Dict[str, torch.Tensor]
Examples
>>> model = ESPnetASRTransducerModel(vocab_size=100, ...)
>>> speech = torch.randn(32, 16000) # Batch of 32, 1 second audio
>>> speech_lengths = torch.tensor([16000] * 32)
>>> text = torch.randint(0, 100, (32, 50)) # Random label IDs
>>> text_lengths = torch.tensor([50] * 32)
>>> output = model.collect_feats(speech, speech_lengths, text, text_lengths)
>>> print(output["feats"].shape) # Expected shape: (32, T, D_feats)
NOTE
If extract_feats_in_collect_stats is set to False, this method will generate dummy statistics for feats and feats_lengths, which may not reflect the actual features extracted from the input speech.
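Conceptually, the two branches can be sketched as follows (a simplified illustration, not the exact implementation; calling the frontend directly is an assumption made for this sketch):
>>> if model.extract_feats_in_collect_stats:
...     feats, feats_lengths = model.frontend(speech, speech_lengths)  # real feature extraction
... else:
...     feats, feats_lengths = speech, speech_lengths  # dummy stats: raw inputs passed through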
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Encode speech sequences.
This method applies the frontend for feature extraction (if provided), SpecAugment during training (if provided), and normalization (if provided), then runs the encoder module to produce the encoder output sequences consumed by the decoder and joint network.
- Parameters:
  - speech – Speech sequences. Shape: (B, S).
  - speech_lengths – Speech sequences lengths. Shape: (B,).
- Returns:
  - encoder_out: Encoder output sequences. Shape: (B, T, D_enc).
  - encoder_out_lens: Encoder output sequences lengths. Shape: (B,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
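Examples
A usage sketch, assuming model is an already constructed ESPnetASRTransducerModel:
>>> speech = torch.randn(2, 16000)  # batch of 2, 1 second of 16 kHz audio
>>> speech_lengths = torch.tensor([16000, 16000])
>>> encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
>>> print(encoder_out.shape)  # expected shape: (2, T, D_enc)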
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Forward architecture and compute loss(es).
This method takes the input speech and text sequences, processes them through the model’s architecture, and computes the loss value along with relevant statistics. It is the core function that integrates the encoder, decoder, and joint network to produce the final loss output.
- Parameters:
- speech – Speech sequences with shape (B, S) where B is the batch size and S is the sequence length.
- speech_lengths – Lengths of the speech sequences with shape (B,).
- text – Label ID sequences with shape (B, L) where L is the maximum length of the labels.
- text_lengths – Lengths of the label ID sequences with shape (B,).
- kwargs – Additional keyword arguments that may contain “utts_id”.
- Returns:
- loss: Main loss value (scalar tensor).
- stats: A dictionary of task statistics, such as individual loss components and error rates.
- weight: Batch size for scaling purposes.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the dimensions of input tensors do not match the expected shapes.
Examples
>>> model = ESPnetASRTransducerModel(...)
>>> speech = torch.randn(2, 100) # Example speech input
>>> speech_lengths = torch.tensor([100, 90]) # Lengths of inputs
>>> text = torch.randint(0, model.vocab_size, (2, 50)) # Example labels
>>> text_lengths = torch.tensor([50, 50]) # Lengths of labels
>>> loss, stats, weight = model.forward(speech, speech_lengths, text, text_lengths)
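In a typical training step, the returned loss is backpropagated directly, while weight (the batch size) is used by the trainer to aggregate statistics. A minimal sketch (the optimizer choice and setup are assumptions, not part of this class):
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> optimizer.zero_grad()
>>> loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
>>> loss.backward()  # backpropagates the combined Transducer (+ auxiliary) loss
>>> optimizer.step()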