espnet2.asr.maskctc_model.MaskCTCInference
class espnet2.asr.maskctc_model.MaskCTCInference(asr_model: MaskCTCModel, n_iterations: int, threshold_probability: float)
Bases: Module
Mask-CTC-based non-autoregressive inference.
This class implements non-autoregressive inference for the Mask-CTC model. It uses CTC probabilities and a masked language model (MLM) to iteratively predict masked tokens in the output sequence: greedy CTC decoding produces an initial hypothesis, and a series of refinement passes then updates the predictions for the masked tokens.
ctc
The CTC module of the ASR model.
mlm
The masked language model (MLM) decoder.
mask_token
The token ID used for masking in the output sequence.
n_iterations
The number of iterations for iterative decoding.
threshold_probability
The CTC probability threshold used to decide which tokens are masked.
converter
A TokenIDConverter for converting token IDs to text.
- Parameters:
- asr_model (MaskCTCModel) – The Mask-CTC model used for inference.
- n_iterations (int) – The number of iterations for iterative decoding.
- threshold_probability (float) – The threshold probability for masking tokens during inference.
######### Examples
>>> import torch
>>> model = MaskCTCModel(...)
>>> inference = MaskCTCInference(model, n_iterations=5,
...                              threshold_probability=0.5)
>>> enc_out = torch.randn(1, 10, 256)  # Example encoder output; 256 is an assumed encoder dimension
>>> hypotheses = inference(enc_out)
>>> print(hypotheses[0].yseq)  # Output the predicted sequence
####### NOTE This implementation requires that the CTC output be in log probabilities.
- Raises: ValueError – If n_iterations or threshold_probability is not a positive value.
Initialize Mask-CTC inference
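The greedy CTC decoding step mentioned in the class description can be illustrated with a small, library-free sketch. It is not the ESPnet implementation: the function name `greedy_ctc_decode`, the choice of index 0 as the blank symbol, and the rule of keeping the highest frame probability within each run of repeated symbols are all assumptions made for illustration. The input is given as log probabilities, in line with the note above.

```python
import math
from itertools import groupby


def greedy_ctc_decode(log_probs, blank_id=0):
    """Collapse frame-level CTC log-probabilities into tokens with confidences.

    log_probs: list of frames, each a list of per-token log-probabilities.
    blank_id:  index of the CTC blank symbol (assumed to be 0 here).
    Returns (token_ids, token_probs) with blanks removed.
    """
    # Best token and its probability for every frame.
    frame_ids = [max(range(len(f)), key=f.__getitem__) for f in log_probs]
    frame_probs = [math.exp(f[i]) for f, i in zip(log_probs, frame_ids)]

    token_ids, token_probs = [], []
    pos = 0
    # Merge runs of the same symbol, keeping the highest frame probability.
    for token, run in groupby(frame_ids):
        run_len = len(list(run))
        best = max(frame_probs[pos:pos + run_len])
        pos += run_len
        if token != blank_id:
            token_ids.append(token)
            token_probs.append(best)
    return token_ids, token_probs


# Toy input: 5 frames over a 3-symbol vocabulary (index 0 is the assumed blank).
frames = [
    [math.log(0.1), math.log(0.8), math.log(0.1)],
    [math.log(0.2), math.log(0.7), math.log(0.1)],
    [math.log(0.9), math.log(0.05), math.log(0.05)],
    [math.log(0.3), math.log(0.3), math.log(0.4)],
    [math.log(0.6), math.log(0.2), math.log(0.2)],
]
ids, probs = greedy_ctc_decode(frames)
```

The per-token confidences returned here are the kind of values that a probability threshold can later be compared against.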
forward(enc_out: Tensor) → List[Hypothesis]
Perform Mask-CTC inference.
This method executes the Mask-CTC inference process using the given encoded outputs from the speech recognition model. It performs greedy decoding with the CTC outputs and iteratively refines the predictions for masked tokens using the masked language model (MLM) decoder.
- Parameters: enc_out (Tensor) – A tensor of shape (1, Length, …) containing the encoded outputs from the speech recognition model. The first dimension is artificially added for batch compatibility.
- Returns: A list of Hypothesis objects representing the predicted sequences from the inference process.
######### Examples
>>> enc_out = torch.randn(1, 100, 256) # Example encoded output
>>> inference_model = MaskCTCInference(asr_model, n_iterations=5,
... threshold_probability=0.5)
>>> hypotheses = inference_model(enc_out)
>>> print(hypotheses)
####### NOTE The inference process involves applying a threshold on the CTC probabilities to determine which tokens to mask and iteratively filling in these masked tokens using the MLM decoder.
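The thresholding and iterative filling described in this note can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the ESPnet code: `mask_low_confidence`, `iterative_fill`, and `toy_predict` are hypothetical names, `predict` stands in for the MLM decoder, and the schedule (commit the most confident predictions on each pass, then fill the remainder on the last pass) is an assumed strategy that the text above does not spell out.

```python
def mask_low_confidence(token_ids, token_probs, threshold, mask_id):
    """Replace tokens whose confidence is below the threshold with mask_id."""
    return [t if p >= threshold else mask_id
            for t, p in zip(token_ids, token_probs)]


def iterative_fill(tokens, mask_id, n_iterations, predict):
    """Fill masked positions over a bounded number of passes.

    predict(tokens) is a hypothetical stand-in for the MLM decoder: it
    returns one (token_id, score) pair per position.
    """
    tokens = list(tokens)
    masked = [i for i, t in enumerate(tokens) if t == mask_id]
    if not masked:
        return tokens
    # Never run more passes than there are masked positions.
    num_iter = min(n_iterations, len(masked)) if n_iterations > 0 else len(masked)
    per_pass = len(masked) // num_iter
    for _ in range(num_iter - 1):
        preds = predict(tokens)
        # Commit the most confident predictions among the masked positions.
        best = sorted(masked, key=lambda i: preds[i][1], reverse=True)[:per_pass]
        for i in best:
            tokens[i] = preds[i][0]
        masked = [i for i, t in enumerate(tokens) if t == mask_id]
    # Final pass: fill whatever is still masked.
    preds = predict(tokens)
    for i in masked:
        tokens[i] = preds[i][0]
    return tokens


MASK = 99  # assumed mask token ID for this toy example


def toy_predict(tokens):
    # Fixed predictions, ignoring context; a real decoder would condition
    # on the encoder output and the partially filled sequence.
    return [(7, 0.9), (8, 0.6), (9, 0.95), (10, 0.5)]


masked_seq = mask_low_confidence([7, 3, 9, 4], [0.99, 0.4, 0.3, 0.2], 0.5, MASK)
filled_seq = iterative_fill(masked_seq, MASK, 2, toy_predict)
```

With a threshold of 0.5, only the first token is kept; the remaining three positions are filled in two passes, the highest-scoring one first.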
ids2text(ids: List[int])
Convert a list of token IDs to a human-readable text string.
This method takes a list of token IDs, converts them to their corresponding tokens using the TokenIDConverter, and formats the output by replacing special tokens with more human-readable representations. Specifically, it replaces the "<mask>" token with an underscore ("_") and the "<space>" token with a space (" ").
- Parameters: ids (List[int]) – A list of token IDs to be converted to text.
- Returns: A human-readable string representation of the input token IDs.
- Return type: str
######### Examples
>>> inference = MaskCTCInference(...)
>>> token_ids = [1, 2, 3, 4, 5]
>>> text = inference.ids2text(token_ids)
>>> print(text)  # Illustrative only; the actual text depends on the model's token list
Token1 Token2 Token3 _ Token5
####### NOTE Ensure that the input list of IDs corresponds to the correct token mapping as defined in the TokenIDConverter.
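The conversion described for ids2text can be sketched without ESPnet. The helper `ids_to_text` and the toy `token_list` below are hypothetical, and joining tokens with no separator is an assumption; only the two replacements ("<mask>" to "_" and "<space>" to " ") come from the description above.

```python
def ids_to_text(ids, token_list):
    """Map token IDs to tokens, then make special tokens readable."""
    # Assumption: tokens are concatenated without a separator.
    text = "".join(token_list[i] for i in ids)
    return text.replace("<mask>", "_").replace("<space>", " ")


# Toy vocabulary; a real model supplies its own token list.
token_list = ["<blank>", "h", "i", "<space>", "<mask>"]
text = ids_to_text([1, 2, 3, 4, 2], token_list)
```

Here `text` is the string `hi _i`: the masked position shows up as an underscore, which makes partially decoded hypotheses easy to read in logs.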