espnet2.asr.partially_AR_model.PartiallyARInference
class espnet2.asr.partially_AR_model.PartiallyARInference(ctc: CTC, decoder: AbsDecoder, threshold_probability: float, sos: int | None = None, eos: int | None = None, mask_token: int | None = None, token_list: List[int] | None = None, scorers: Dict[str, ScorerInterface] | None = None, weights: Dict[str, float] | None = None, beam_size: int = 10, max_seq_len: int = 5, max_mask_parallel: int = -1)
Bases: Module
Mask-CTC-based partially autoregressive inference.
This class implements the partially autoregressive inference using a combination of CTC (Connectionist Temporal Classification) and a beam search mechanism tailored for handling masked tokens in the decoding process. It is particularly useful for scenarios where the input data may have uncertain or missing information.
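The core masking step behind this class can be sketched as follows. This is a hypothetical illustration, not ESPnet's actual implementation: tokens whose CTC posterior probability falls below `threshold_probability` are replaced with `mask_token`, and the partially autoregressive decoder later fills in those masked positions.

```python
# Hypothetical sketch of the Mask-CTC confidence-based masking step.
# `mask_low_confidence` is an illustrative helper, not part of the ESPnet API.

from typing import List


def mask_low_confidence(
    tokens: List[int],
    probs: List[float],
    threshold_probability: float,
    mask_token: int,
) -> List[int]:
    """Replace tokens whose CTC probability is below the threshold with the mask token."""
    return [
        tok if p >= threshold_probability else mask_token
        for tok, p in zip(tokens, probs)
    ]


# With threshold 0.5 and mask token ID 4, low-confidence positions become 4.
masked = mask_low_confidence([7, 3, 9, 2], [0.9, 0.2, 0.6, 0.4], 0.5, 4)
print(masked)  # [7, 4, 9, 4]
```

The decoder then only needs to predict the masked positions, which is what makes the inference *partially* autoregressive.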
ctc
The CTC module for generating probabilities.
- Type: CTC
decoder
The decoder module used for sequence generation.
- Type: AbsDecoder
threshold_probability
The threshold for determining whether to mask a token based on its CTC probability.
- Type: float
sos
The start-of-sequence token ID.
- Type: int
eos
The end-of-sequence token ID.
- Type: int
mask_token
The token ID used for masking.
- Type: int
converter
Converter for token IDs.
- Type: TokenIDConverter
beam_search
The beam search mechanism used for generating hypotheses.
max_mask_parallel
Maximum number of masks to process simultaneously.
- Type: int
primer
A list of tokens used to prime the hypotheses.
- Type: List[int]
Parameters:
- ctc (CTC) – The CTC module for decoding.
- decoder (AbsDecoder) – The decoder for generating sequences.
- threshold_probability (float) – The probability threshold for masking tokens.
- sos (int, optional) – The ID for the start-of-sequence token. Defaults to None.
- eos (int, optional) – The ID for the end-of-sequence token. Defaults to None.
- mask_token (int, optional) – The ID for the mask token. Defaults to None.
- token_list (List[int], optional) – A list of token IDs. Defaults to None.
- scorers (Dict[str, ScorerInterface], optional) – Scorers for evaluating hypotheses. Defaults to None.
- weights (Dict[str, float], optional) – Weights for the scoring functions. Defaults to None.
- beam_size (int, optional) – The size of the beam for search. Defaults to 10.
- max_seq_len (int, optional) – The maximum length of the generated sequence. Defaults to 5.
- max_mask_parallel (int, optional) – The maximum number of masks to process in parallel. Defaults to -1 (unlimited).
Returns: A list of hypotheses generated from the inference process.
Return type: List[Hypothesis]
Raises: AssertionError – If any scorer is not an instance of MaskParallelScorerInterface.
######### Examples
>>> ctc = CTC(...)
>>> decoder = AbsDecoder(...)
>>> inference = PartiallyARInference(ctc, decoder, 0.5, sos=1, eos=2)
>>> enc_out = torch.randn(1, 10, 20) # Example encoder output
>>> hypotheses = inference(enc_out)
>>> for hypo in hypotheses:
... print(hypo.yseq)
####### NOTE
This implementation assumes that the CTC and decoder are properly configured for the specific task and data.
Initialize Mask-CTC inference
forward(enc_out: Tensor, *args, **kwargs) → List[Hypothesis]
Perform Mask-CTC-based partially autoregressive inference.
This method runs the partially autoregressive inference mechanism using a Mask-CTC approach: it uses the CTC module and beam search to generate hypotheses from the encoder output while handling masked tokens.
ctc
The CTC model used for generating token probabilities.
- Type: CTC
decoder
The decoder responsible for generating output sequences.
- Type: AbsDecoder
mask_token
The token ID representing the mask.
- Type: int
threshold_probability
The threshold for determining confident tokens.
- Type: float
sos
The start-of-sequence token ID.
- Type: int
eos
The end-of-sequence token ID.
- Type: int
max_seq_len
The maximum length of generated sequences.
- Type: int
max_mask_parallel
The maximum number of masks to process in parallel.
- Type: int
primer
A list of initial tokens to prepend to hypotheses.
- Type: List[int]
Parameters:
- ctc (CTC) – The CTC model.
- decoder (AbsDecoder) – The decoder.
- threshold_probability (float) – Probability threshold for masking.
- sos (int, optional) – Start-of-sequence token ID. Defaults to None.
- eos (int, optional) – End-of-sequence token ID. Defaults to None.
- mask_token (int, optional) – Mask token ID. Defaults to None.
- token_list (List[int], optional) – List of token IDs. Defaults to None.
- scorers (Dict[str, ScorerInterface], optional) – Scorers for beam search. Defaults to None.
- weights (Dict[str, float], optional) – Weights for different components. Defaults to None.
- beam_size (int, optional) – Size of the beam for beam search. Defaults to 10.
- max_seq_len (int, optional) – Maximum sequence length for output. Defaults to 5.
- max_mask_parallel (int, optional) – Max number of masks processed in parallel. Defaults to -1 (unlimited).
Returns: A list of hypotheses generated from the input encoding.
Return type: List[Hypothesis]
Raises: AssertionError – If any scorer is not an instance of MaskParallelScorerInterface.
######### Examples
>>> ctc_model = CTC(...)
>>> decoder_model = AbsDecoder(...)
>>> inference = PartiallyARInference(ctc=ctc_model, decoder=decoder_model,
... threshold_probability=0.5)
>>> enc_out = torch.randn(1, 100, 512) # Example encoder output
>>> hypotheses = inference(enc_out)
>>> for hypo in hypotheses:
... print(hypo.yseq) # Print the generated sequences
####### NOTE
This implementation is based on the research from https://arxiv.org/abs/2309.14922 and may have specific requirements related to the input data and model configuration.
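The effect of `max_mask_parallel` can be illustrated with a minimal sketch. The helper below is hypothetical and not part of the ESPnet API: it shows how masked positions could be grouped into batches of at most `max_mask_parallel`, with -1 meaning all masks are processed at once.

```python
# Hypothetical sketch of batching masked positions under `max_mask_parallel`.
# `chunk_masks` is an illustrative helper, not ESPnet's actual code.

from typing import List


def chunk_masks(mask_positions: List[int], max_mask_parallel: int) -> List[List[int]]:
    """Split masked positions into groups of at most `max_mask_parallel` (-1 = no limit)."""
    if max_mask_parallel == -1:
        # Unlimited: process every masked position in a single batch.
        return [mask_positions] if mask_positions else []
    return [
        mask_positions[i : i + max_mask_parallel]
        for i in range(0, len(mask_positions), max_mask_parallel)
    ]


# Five masked positions with a limit of 2 yield three decoding batches.
print(chunk_masks([0, 2, 5, 7, 8], 2))  # [[0, 2], [5, 7], [8]]
```

Lowering `max_mask_parallel` trades decoding speed for a smaller peak memory footprint, since fewer masked positions are expanded by the beam search at a time.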
set_hyp_primer(primer: List[int])
Set the hypothesis primer for the beam search.
This method allows users to define a sequence of tokens that will be used as a prefix during the beam search process. The primer can help guide the decoding process towards more relevant hypotheses by providing a starting point.
- Parameters:
- primer (List[int]) – A list of token IDs that will be used as the initial tokens in the beam search.
######### Examples
>>> inference = PartiallyARInference(ctc, decoder, threshold_probability)
>>> inference.set_hyp_primer([2, 3, 5])
>>> print(inference.primer)
[2, 3, 5]
####### NOTE
The provided primer should be consistent with the token list used in the model. Ensure that the token IDs in the primer are valid before invoking this method.