espnet2.asr.maskctc_model.MaskCTCModel
class espnet2.asr.maskctc_model.MaskCTCModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: MLMDecoder, ctc: CTC, joint_network: Module | None = None, ctc_weight: float = 0.5, interctc_weight: float = 0.0, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = True, report_wer: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', sym_mask: str = '<mask>', extract_feats_in_collect_stats: bool = True)
Bases: ESPnetASRModel
Hybrid CTC/Masked LM Encoder-Decoder model (Mask-CTC).
This model combines Connectionist Temporal Classification (CTC) and Masked Language Modeling (MLM) for automatic speech recognition. The encoder processes the input speech. A CTC branch is trained on the encoder output, and the MLM decoder is trained to predict masked tokens of the target sequence conditioned on the encoder output. The training loss is a weighted sum of the CTC and MLM losses.
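The MLM objective can be made concrete with a small sketch of how training pairs might be built. The decoder input is the reference transcription with some positions replaced by the mask token. The target keeps the true label only at those positions, and every other position is set to ignore_id so the loss skips it. The function name `mask_uniform_sketch`, sampling positions without replacement, and the toy token IDs are illustrative assumptions, not ESPnet's implementation.

```python
import random
from typing import List, Optional, Tuple


def mask_uniform_sketch(
    ys: List[List[int]],
    mask_token: int,
    ignore_id: int = -1,
    rng: Optional[random.Random] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Build (decoder input, decoder target) pairs for MLM training (hypothetical helper).

    For each non-empty sequence, the number of masked positions k is drawn
    uniformly from 1..len(y). Then k distinct positions are replaced by
    mask_token in the input. Only those positions keep their label in the
    target; all others become ignore_id.
    """
    rng = rng or random.Random()
    ys_in, ys_out = [], []
    for y in ys:
        k = rng.randint(1, len(y))                   # inclusive on both ends
        masked = set(rng.sample(range(len(y)), k))   # k distinct positions
        ys_in.append([mask_token if i in masked else t for i, t in enumerate(y)])
        ys_out.append([t if i in masked else ignore_id for i, t in enumerate(y)])
    return ys_in, ys_out


ys_in, ys_out = mask_uniform_sketch([[5, 6, 7, 8], [9, 10]], mask_token=42,
                                    rng=random.Random(0))
print(ys_in)
print(ys_out)
```

Because the number of masks is random, the decoder sees anything from a single masked token to a fully masked sequence during training.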
vocab_size
The size of the vocabulary including the mask token.
- Type: int
token_list
A list of tokens corresponding to the vocabulary.
- Type: List[str]
mask_token
The index of the mask token in the vocabulary.
- Type: int
criterion_mlm
The loss function used for MLM.
- Type: LabelSmoothingLoss
error_calculator
Computes CER/WER statistics for reporting; None if both report_cer and report_wer are False.
- Type: Optional[ErrorCalculator]
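The helper below illustrates how vocab_size, token_list and mask_token relate to each other. It assumes that the constructor appends sym_mask after all existing tokens, so that the mask token receives the last index and the stored vocab_size is one larger than the value passed in. This is consistent with the attribute descriptions above, but the helper name, the copy-instead-of-mutate behavior, and the duplicate check are hypothetical.

```python
from typing import List, Tuple


def add_mask_token(token_list: List[str], sym_mask: str = "<mask>") -> Tuple[List[str], int, int]:
    """Return (extended token list, new vocab size, mask token index).

    Assumption: the mask symbol is appended after all existing tokens,
    so its index equals the original vocabulary size.
    """
    # Illustrative guard only; not necessarily present in ESPnet.
    if sym_mask in token_list:
        raise ValueError(f"{sym_mask!r} is already in token_list")
    extended = list(token_list) + [sym_mask]   # copy; the caller's list is untouched
    vocab_size = len(extended)                 # now includes the mask token
    mask_token = vocab_size - 1                # the last index
    return extended, vocab_size, mask_token


tokens, vocab_size, mask_token = add_mask_token(["<blank>", "<unk>", "a", "b", "<sos/eos>"])
print(vocab_size, mask_token, tokens[mask_token])  # 6 5 <mask>
```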
Parameters:
- vocab_size (int) – Size of the vocabulary before the mask token is added; should equal len(token_list).
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens.
- frontend (Optional [AbsFrontend ]) – Frontend module for feature extraction.
- specaug (Optional [AbsSpecAug ]) – SpecAugment module for data augmentation.
- normalize (Optional [AbsNormalize ]) – Normalization layer.
- preencoder (Optional [AbsPreEncoder ]) – Pre-encoder module.
- encoder (AbsEncoder) – Encoder module.
- postencoder (Optional [AbsPostEncoder ]) – Post-encoder module.
- decoder (MLMDecoder) – Decoder module for MLM.
- ctc (CTC) – CTC module.
- joint_network (Optional[torch.nn.Module]) – Joint network module, if any (default: None).
- ctc_weight (float) – Weight of the CTC loss, in [0, 1]; the total loss is ctc_weight * loss_ctc + (1 - ctc_weight) * loss_mlm (default: 0.5).
- interctc_weight (float) – Weight of the intermediate CTC loss, in [0, 1), which is interpolated into the CTC loss (default: 0.0).
- ignore_id (int) – ID to ignore during loss calculation (default: -1).
- lsm_weight (float) – Label smoothing weight (default: 0.0).
- length_normalized_loss (bool) – If True, normalize loss by length (default: False).
- report_cer (bool) – If True, report Character Error Rate (default: True).
- report_wer (bool) – If True, report Word Error Rate (default: True).
- sym_space (str) – Token representing space (default: "<space>").
- sym_blank (str) – Token representing blank (default: "<blank>").
- sym_mask (str) – Token representing mask (default: "<mask>").
- extract_feats_in_collect_stats (bool) – If True, extract features during statistics collection (default: True).
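The two weights combine in two steps. If intermediate CTC is enabled, its loss is first interpolated into the final-layer CTC loss. The result is then interpolated with the MLM loss. The plain-Python sketch below assumes that ordering and that the intermediate losses from several layers are averaged. The function name is hypothetical, and the range checks mirror the documented value ranges rather than ESPnet's exact code.

```python
from typing import Optional, Sequence


def combine_losses(
    loss_ctc: float,
    loss_mlm: float,
    ctc_weight: float = 0.5,
    interctc_losses: Optional[Sequence[float]] = None,
    interctc_weight: float = 0.0,
) -> float:
    """Weighted CTC/MLM objective (hypothetical sketch).

    1. If intermediate CTC losses are given and interctc_weight > 0, their mean
       is interpolated into the final-layer CTC loss.
    2. The (possibly adjusted) CTC loss is interpolated with the MLM loss.
    Unlike a real model, this sketch does not skip computing a branch whose
    weight is zero; it only weights already-computed values.
    """
    if not 0.0 <= ctc_weight <= 1.0:
        raise ValueError("ctc_weight must lie in [0, 1]")
    if not 0.0 <= interctc_weight < 1.0:
        raise ValueError("interctc_weight must lie in [0, 1)")
    if interctc_weight > 0.0 and interctc_losses:
        inter = sum(interctc_losses) / len(interctc_losses)
        loss_ctc = (1 - interctc_weight) * loss_ctc + interctc_weight * inter
    return ctc_weight * loss_ctc + (1 - ctc_weight) * loss_mlm


print(combine_losses(10.0, 2.0))  # 0.5 * 10 + 0.5 * 2 = 6.0
print(combine_losses(10.0, 2.0, ctc_weight=0.3,
                     interctc_losses=[12.0, 14.0], interctc_weight=0.5))
```

Setting ctc_weight to 1.0 or 0.0 reduces the objective to pure CTC or pure MLM training, respectively.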
#### Examples
>>> token_list = ["<blank>", "<unk>", "<space>", "a", "b", "c", "<sos/eos>"]
>>> model = MaskCTCModel(
...     vocab_size=len(token_list),  # must match len(token_list)
...     token_list=token_list,       # do not include "<mask>"; the model adds it
...     frontend=None,
...     specaug=None,
...     normalize=None,
...     preencoder=None,
...     encoder=SomeEncoder(),       # placeholder: any AbsEncoder instance
...     postencoder=None,
...     decoder=SomeMLMDecoder(),    # placeholder: an MLMDecoder instance
...     ctc=SomeCTC(),               # placeholder: a CTC instance
...     ctc_weight=0.5,
...     interctc_weight=0.0,
...     ignore_id=-1,
...     lsm_weight=0.1,
...     length_normalized_loss=False,
...     report_cer=True,
...     report_wer=True,
...     sym_space="<space>",
...     sym_blank="<blank>",
...     sym_mask="<mask>",
...     extract_feats_in_collect_stats=True,
... )
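The example sets lsm_weight=0.1, which controls the label smoothing applied by criterion_mlm. The sketch below computes the label-smoothed loss for a single masked position in plain Python. It assumes a formulation in which the target keeps probability 1 - smoothing and the remaining mass is spread uniformly over the other V - 1 classes, scored as a KL divergence against the model distribution. The function name and toy probabilities are hypothetical. A real LabelSmoothingLoss also operates on batched tensors, skips ignore_id positions, and applies its own normalization.

```python
import math
from typing import Sequence


def label_smoothing_kl(log_probs: Sequence[float], target: int, smoothing: float) -> float:
    """KL(q || p) between a smoothed one-hot target q and the model distribution p.

    q puts (1 - smoothing) on the target class and smoothing / (V - 1) on every
    other class, where V = len(log_probs). With smoothing = 0 this reduces to
    the ordinary negative log-likelihood of the target.
    """
    V = len(log_probs)
    off = smoothing / (V - 1)
    loss = 0.0
    for k, lp in enumerate(log_probs):
        q = 1.0 - smoothing if k == target else off
        if q > 0.0:                       # treat 0 * log 0 as 0
            loss += q * (math.log(q) - lp)
    return loss


probs = [0.7, 0.1, 0.1, 0.1]
log_probs = [math.log(p) for p in probs]
print(label_smoothing_kl(log_probs, target=0, smoothing=0.0))  # equals -log(0.7)
print(label_smoothing_kl(log_probs, target=0, smoothing=0.1))
```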
NOTE
Mask-CTC targets non-autoregressive decoding. At inference time, low-confidence tokens in the greedy CTC output are masked and re-predicted by the MLM decoder, instead of being generated one at a time from left to right.
- Raises: AssertionError – If ctc_weight is not in [0, 1] or interctc_weight is not in [0, 1).
batchify_nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor, batch_size: int = 100)
Batchify the negative log-likelihood (NLL) computation.
In ESPnetASRModel, this method splits the inputs into chunks of at most batch_size utterances, calls nll on each chunk, and concatenates the results to avoid running out of GPU memory on large batches. MaskCTCModel does not support it, and calling this method raises NotImplementedError. The parameter and return descriptions below document the inherited interface.
- Parameters:
- encoder_out (torch.Tensor) – The output from the encoder, shaped (Batch, Length, Features).
- encoder_out_lens (torch.Tensor) – The lengths of each sequence in the encoder output, shaped (Batch,).
- ys_pad (torch.Tensor) – The target sequences, shaped (Batch, Length).
- ys_pad_lens (torch.Tensor) – The lengths of each target sequence, shaped (Batch,).
- batch_size (int, optional) – The number of utterances per chunk. Defaults to 100.
- Returns: Per-utterance negative log-likelihoods of shape (Batch,). This describes the interface contract; MaskCTCModel does not produce it.
- Return type: torch.Tensor
- Raises: NotImplementedError – Always; MaskCTCModel does not implement this method.
#### Examples
>>> import torch
>>> encoder_out = torch.randn(200, 50, 256)  # (Batch, Length, Features)
>>> encoder_out_lens = torch.randint(1, 50, (200,))
>>> ys_pad = torch.randint(0, 100, (200, 30))  # (Batch, Length)
>>> ys_pad_lens = torch.randint(1, 30, (200,))
>>> model.batchify_nll(encoder_out, encoder_out_lens,
...                    ys_pad, ys_pad_lens, batch_size=50)
Traceback (most recent call last):
    ...
NotImplementedError
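Although MaskCTCModel does not implement it, the chunking strategy behind the batchify_nll interface is simple. The batch is split into slices of at most batch_size utterances, each slice is scored, and the per-utterance results are concatenated. The generic helper below sketches that pattern in plain Python. The names `batchify` and `toy_nll` are hypothetical stand-ins for calls to nll.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def batchify(score_fn: Callable[[Sequence[T]], List[float]],
             items: Sequence[T], batch_size: int = 100) -> List[float]:
    """Apply score_fn to consecutive slices of items and concatenate the results.

    Only one slice is scored at a time, so peak memory depends roughly on
    batch_size rather than on len(items).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    results: List[float] = []
    for start in range(0, len(items), batch_size):
        chunk = items[start:start + batch_size]
        scores = score_fn(chunk)
        assert len(scores) == len(chunk), "one score per utterance expected"
        results.extend(scores)
    return results


# Toy stand-in for nll(): "score" each utterance by its length.
utterances = [[1] * n for n in range(1, 8)]
calls = []


def toy_nll(chunk):
    calls.append(len(chunk))
    return [float(len(u)) for u in chunk]


print(batchify(toy_nll, utterances, batch_size=3))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
print(calls)                                        # [3, 3, 1]
```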
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Process input through the model’s frontend, encoder, and decoder, and compute the associated loss.
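The speech is encoded once. The CTC loss is computed on the encoder output, and the MLM loss is computed by the decoder on a randomly masked copy of the target text. The total loss is ctc_weight * loss_ctc + (1 - ctc_weight) * loss_mlm, and a branch whose weight is zero is skipped. The method returns the total loss, a dictionary of loss and accuracy statistics, and the batch size.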
- Parameters:
- speech (torch.Tensor) – Input speech tensor of shape (Batch, Length, …).
- speech_lengths (torch.Tensor) – Lengths of the input speech tensor of shape (Batch,).
- text (torch.Tensor) – Padded target token IDs of shape (Batch, Length).
- text_lengths (torch.Tensor) – Lengths of the input text tensor of shape (Batch,).
- Returns:
- Total loss for the batch.
- A dictionary containing statistics such as CTC loss, MLM loss, and accuracies.
- Batch size, used as the weight when gathering results across data-parallel workers.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If text_lengths is not 1-D, or if the batch dimensions of speech, speech_lengths, text, and text_lengths differ.
#### Examples
>>> import torch
>>> model = MaskCTCModel(...)
>>> speech = torch.randn(16, 100, 40)  # 16 utterances, 100 frames, 40-dim features (frontend=None)
>>> speech_lengths = torch.tensor([100] * 16)
>>> text = torch.randint(1, 50, (16, 20))  # 16 utterances, 20 token IDs; 0 is assumed to be <blank>
>>> text_lengths = torch.tensor([20] * 16)
>>> loss, stats, batch_size = model(speech, speech_lengths, text, text_lengths)
NOTE
This method assumes that speech and text are already preprocessed, padded along the time or token axis, and share the same batch dimension.
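The padded text and text_lengths inputs can be built from variable-length token sequences as in the sketch below. It pads with -1, the default ignore_id, so padded positions are skipped by the losses. It then checks that all inputs share a batch size, in the spirit of the AssertionError documented above. The helper names are hypothetical, the width check is an extra sanity check added for illustration, and nested Python lists stand in for tensors.

```python
from typing import List, Tuple


def pad_text_batch(seqs: List[List[int]], pad_value: int = -1) -> Tuple[List[List[int]], List[int]]:
    """Right-pad token-ID sequences to a common length.

    Returns (text, text_lengths) with shapes (Batch, Length) and (Batch,).
    """
    lengths = [len(s) for s in seqs]
    max_len = max(lengths, default=0)
    text = [s + [pad_value] * (max_len - len(s)) for s in seqs]
    return text, lengths


def check_batch(speech_lengths: List[int], text: List[List[int]], text_lengths: List[int]) -> None:
    """Consistency checks on a batch (hypothetical helper)."""
    # Every input must describe the same number of utterances.
    assert len(speech_lengths) == len(text) == len(text_lengths), "batch sizes differ"
    # Extra sanity check for this sketch: no length exceeds the padded width.
    assert all(n <= len(row) for n, row in zip(text_lengths, text)), "length exceeds padding"


text, text_lengths = pad_text_batch([[3, 4, 5], [6], [7, 8]])
check_batch([100, 80, 90], text, text_lengths)
print(text)          # [[3, 4, 5], [6, -1, -1], [7, 8, -1]]
print(text_lengths)  # [3, 1, 2]
```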
nll(encoder_out: Tensor, encoder_out_lens: Tensor, ys_pad: Tensor, ys_pad_lens: Tensor) → Tensor
Compute the negative log-likelihood (NLL) of target sequences (not supported by MaskCTCModel).
ESPnetASRModel implements this method by scoring the targets with its attention decoder. MaskCTCModel overrides it, so calling it raises NotImplementedError. The parameter and return descriptions below document the inherited interface.
- Parameters:
- encoder_out (torch.Tensor) – The output from the encoder, typically of shape (Batch, Length, Features).
- encoder_out_lens (torch.Tensor) – The lengths of the encoder outputs, of shape (Batch,).
- ys_pad (torch.Tensor) – The padded target sequences, of shape (Batch, Length).
- ys_pad_lens (torch.Tensor) – The lengths of the target sequences, of shape (Batch,).
- Returns: Per-utterance negative log-likelihood of shape (Batch,). This describes the interface contract; MaskCTCModel does not produce it.
- Return type: torch.Tensor
- Raises: NotImplementedError – Always; MaskCTCModel does not implement this method.
#### Examples
>>> import torch
>>> model = MaskCTCModel(...)  # initialize your model
>>> encoder_output = torch.rand(32, 100, 256)  # (Batch, Length, Features)
>>> encoder_output_lengths = torch.randint(1, 100, (32,))
>>> target_sequences = torch.randint(0, model.vocab_size, (32, 50))
>>> target_lengths = torch.randint(1, 50, (32,))
>>> model.nll(encoder_output, encoder_output_lengths,
...           target_sequences, target_lengths)
Traceback (most recent call last):
    ...
NotImplementedError
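For reference, the quantity this interface is meant to return is the per-utterance negative log-likelihood. This is the sum of -log p over the non-padded target positions of each utterance. The plain-Python sketch below computes it from precomputed per-token log-probabilities. The function name and toy numbers are hypothetical, and the sketch does not reproduce how ESPnetASRModel obtains those log-probabilities from its decoder.

```python
import math
from typing import List, Sequence


def per_utterance_nll(
    token_log_probs: Sequence[Sequence[float]],
    ys_pad: Sequence[Sequence[int]],
    ignore_id: int = -1,
) -> List[float]:
    """Sum -log p over non-padded positions for each utterance.

    token_log_probs[b][t] is the model's log-probability of the reference
    token ys_pad[b][t]; positions where ys_pad == ignore_id are skipped.
    """
    out = []
    for lps, ys in zip(token_log_probs, ys_pad):
        out.append(-sum(lp for lp, y in zip(lps, ys) if y != ignore_id))
    return out


lp = [[math.log(0.5), math.log(0.25), 0.0],
      [math.log(0.5), 0.0, 0.0]]
ys = [[4, 7, -1],
      [9, -1, -1]]
print(per_utterance_nll(lp, ys))  # ~[2.079, 0.693], i.e. [log 8, log 2]
```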