espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel
class espnet2.asr_transducer.espnet_transducer_model.ESPnetASRTransducerModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, encoder: Encoder, decoder: AbsDecoder, joint_network: JointNetwork, transducer_weight: float = 1.0, use_k2_pruned_loss: bool = False, k2_pruned_loss_args: Dict = {}, warmup_steps: int = 25000, validation_nstep: int = 2, fastemit_lambda: float = 0.0, auxiliary_ctc_weight: float = 0.0, auxiliary_ctc_dropout_rate: float = 0.0, auxiliary_lm_loss_weight: float = 0.0, auxiliary_lm_loss_smoothing: float = 0.05, ignore_id: int = -1, sym_space: str = '<space>', sym_blank: str = '<blank>', report_cer: bool = False, report_wer: bool = False, extract_feats_in_collect_stats: bool = True)
Bases: AbsESPnetModel
ESPnet2ASRTransducerModel module definition.
- Parameters:
- vocab_size – Size of the complete vocabulary (with SOS/EOS and blank included).
- token_list – List of tokens in the vocabulary (minus reserved tokens).
- frontend – Frontend module.
- specaug – SpecAugment module.
- normalize – Normalization module.
- encoder – Encoder module.
- decoder – Decoder module.
- joint_network – Joint Network module.
- transducer_weight – Weight of the Transducer loss.
- use_k2_pruned_loss – Whether to use the k2 pruned Transducer loss.
- k2_pruned_loss_args – Arguments of the k2 pruned Transducer loss.
- warmup_steps – Number of warmup steps, used for pruned loss scaling.
- validation_nstep – Maximum number of symbol expansions at each time step when reporting CER and/or WER using mAES.
- fastemit_lambda – FastEmit lambda value.
- auxiliary_ctc_weight – Weight of the auxiliary CTC loss.
- auxiliary_ctc_dropout_rate – Dropout rate for the auxiliary CTC loss inputs.
- auxiliary_lm_loss_weight – Weight of the auxiliary LM loss.
- auxiliary_lm_loss_smoothing – Label smoothing rate for the auxiliary LM loss.
- ignore_id – Initial padding ID.
- sym_space – Space symbol.
- sym_blank – Blank symbol.
- report_cer – Whether to report the Character Error Rate during validation.
- report_wer – Whether to report the Word Error Rate during validation.
- extract_feats_in_collect_stats – Whether to extract features during statistics collection.
Construct an ESPnetASRTransducerModel object.
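The report_cer and report_wer options refer to the usual edit-distance-based error rates. The following is a minimal, dependency-free sketch of those metrics, not the ESPnet implementation: the helper names (edit_distance, tokens_to_text, cer, wer) are hypothetical, and the handling of sym_space and sym_blank (map the space symbol to a blank character, drop blank tokens) is an assumption based on the parameter descriptions above.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences (row-by-row DP)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(
                min(
                    prev[j] + 1,         # deletion
                    curr[j - 1] + 1,     # insertion
                    prev[j - 1] + cost,  # substitution or match
                )
            )
        prev = curr
    return prev[-1]


def tokens_to_text(
    tokens: List[str], sym_space: str = "<space>", sym_blank: str = "<blank>"
) -> str:
    """Join character tokens (assumed handling: space symbol -> ' ', blanks dropped)."""
    return "".join(" " if t == sym_space else t for t in tokens if t != sym_blank)


def cer(ref_text: str, hyp_text: str) -> float:
    """Character error rate, computed with spaces removed."""
    ref_chars = list(ref_text.replace(" ", ""))
    hyp_chars = list(hyp_text.replace(" ", ""))
    return edit_distance(ref_chars, hyp_chars) / len(ref_chars)


def wer(ref_text: str, hyp_text: str) -> float:
    """Word error rate, computed on whitespace-separated words."""
    ref_words = ref_text.split()
    hyp_words = hyp_text.split()
    return edit_distance(ref_words, hyp_words) / len(ref_words)


ref = tokens_to_text(["a", "<space>", "c", "a", "t"])
hyp = tokens_to_text(["a", "<blank>", "<space>", "c", "u", "t"])
print(ref, "|", hyp, "| CER:", cer(ref, hyp), "| WER:", wer(ref, hyp))
```

One substituted character out of four gives a CER of 0.25, while one wrong word out of two gives a WER of 0.5, which is why the two metrics are usually reported together.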
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features sequences and features lengths sequences.
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- text – Label ID sequences. (B, L)
- text_lengths – Label ID sequences lengths. (B,)
- kwargs – Contains "utts_id".
- Returns: "feats": Features sequences. (B, T, D_feats) "feats_lengths": Features sequences lengths. (B,)
- Return type: Dict[str, Tensor]
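The shape annotations (B, S), (B, L) and (B,) describe padded batches and their true lengths. The sketch below illustrates that layout with plain Python lists standing in for tensors, so it runs without PyTorch. The pad_batch helper is hypothetical, and the padding values (0.0 for speech, -1 for text, matching the default ignore_id) are assumptions rather than something this page specifies.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T", int, float)


def pad_batch(seqs: Sequence[Sequence[T]], pad_value: T) -> Tuple[List[List[T]], List[int]]:
    """Right-pad variable-length sequences to a common length.

    Returns the padded batch of shape (B, max_len) and the true lengths (B,).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    return padded, lengths


# Two utterances: raw samples (B=2, S=5 after padding) and label IDs (B=2, L=3).
speech, speech_lengths = pad_batch(
    [[0.1, -0.2, 0.3, 0.0, 0.5], [0.4, 0.1, -0.3]], pad_value=0.0
)
text, text_lengths = pad_batch([[5, 9, 2], [7, 4]], pad_value=-1)  # assumed ignore_id

print(speech_lengths, text_lengths)
print(text)
```

In real use these lists would be converted to tensors (for example with torch.tensor) before being passed as speech, speech_lengths, text and text_lengths.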
encode(speech: Tensor, speech_lengths: Tensor) → Tuple[Tensor, Tensor]
Encode speech sequences.
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- Returns: encoder_out: Encoder outputs. (B, T, D_enc) encoder_out_lens: Encoder outputs lengths. (B,)
- Return type: Tuple[Tensor, Tensor]
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Forward architecture and compute loss(es).
- Parameters:
- speech – Speech sequences. (B, S)
- speech_lengths – Speech sequences lengths. (B,)
- text – Label ID sequences. (B, L)
- text_lengths – Label ID sequences lengths. (B,)
- kwargs – Contains "utts_id".
- Returns: loss: Main loss value. stats: Task statistics. weight: Task weights.
- Return type: Tuple[Tensor, Dict[str, Tensor], Tensor]
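The constructor exposes three loss weights (transducer_weight, auxiliary_ctc_weight, auxiliary_lm_loss_weight), but this page does not state how forward combines them. A plausible reading, shown here purely as an assumption, is a weighted sum in which an auxiliary term with weight 0.0 contributes nothing. The combine_losses helper is hypothetical and uses plain floats in place of loss tensors.

```python
def combine_losses(
    loss_trans: float,
    loss_ctc: float = 0.0,
    loss_lm: float = 0.0,
    transducer_weight: float = 1.0,
    auxiliary_ctc_weight: float = 0.0,
    auxiliary_lm_loss_weight: float = 0.0,
) -> float:
    """Weighted sum of the Transducer loss and the optional auxiliary losses.

    Assumption: each term is scaled by its own weight and the results are added.
    """
    return (
        transducer_weight * loss_trans
        + auxiliary_ctc_weight * loss_ctc
        + auxiliary_lm_loss_weight * loss_lm
    )


# With the default weights only the Transducer loss counts.
print(combine_losses(2.0, loss_ctc=4.0, loss_lm=1.0))

# Enabling both auxiliary losses with weight 0.5 each.
print(
    combine_losses(
        2.0,
        loss_ctc=4.0,
        loss_lm=1.0,
        auxiliary_ctc_weight=0.5,
        auxiliary_lm_loss_weight=0.5,
    )
)
```

Under this reading, the returned loss is the combined value, while the individual terms would be natural candidates for the stats dictionary.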
