espnet2.asr_transducer.beam_search_transducer.BeamSearchTransducer
class espnet2.asr_transducer.beam_search_transducer.BeamSearchTransducer(decoder: AbsDecoder, joint_network: JointNetwork, beam_size: int, lm: Module | None = None, lm_weight: float = 0.1, search_type: str = 'default', max_sym_exp: int = 3, u_max: int = 50, nstep: int = 2, expansion_gamma: float = 2.3, expansion_beta: int = 2, score_norm: bool = False, nbest: int = 1, streaming: bool = False)
Bases: object
Beam search implementation for Transducer models.
This class implements a beam search algorithm for transducer models in automatic speech recognition (ASR). It is designed to work with a decoder and a joint network module to produce N-best hypotheses from the encoder output.
decoder
Decoder module for generating sequences.
- Type: AbsDecoder
joint_network
Joint network module that combines encoder and decoder outputs.
- Type: JointNetwork
beam_size
Size of the beam for search.
- Type: int
lm
Language model module for soft fusion.
- Type: Optional[torch.nn.Module]
lm_weight
Weight for the language model in scoring.
- Type: float
search_type
Type of search algorithm used during inference.
- Type: str
max_sym_exp
Maximum symbol expansions at each time step (TSD).
- Type: int
u_max
Maximum expected target sequence length (ALSD).
- Type: int
nstep
Maximum expansion steps at each time step (mAES).
- Type: int
expansion_gamma
Allowed log-probability difference from the best candidate for prune-by-value pruning (mAES).
- Type: float
expansion_beta
Number of additional candidates for expanded hypotheses selection (mAES).
- Type: int
score_norm
Whether to normalize final scores by length.
- Type: bool
nbest
Number of final hypotheses.
- Type: int
streaming
Whether to perform chunk-by-chunk beam search.
- Type: bool
Parameters:
- decoder (AbsDecoder) – The decoder module.
- joint_network (JointNetwork) – The joint network module.
- beam_size (int) – The size of the beam for search.
- lm (Optional[torch.nn.Module]) – The language model for soft fusion.
- lm_weight (float) – Weight of the language model.
- search_type (str) – Type of search algorithm to use: 'default', 'tsd', 'alsd' or 'maes'.
- max_sym_exp (int) – Maximum symbol expansions.
- u_max (int) – Maximum expected target sequence length.
- nstep (int) – Maximum number of expansion steps.
- expansion_gamma (float) – Allowed log-probability difference for prune-by-value pruning (mAES).
- expansion_beta (int) – Number of additional candidates for expanded hypotheses selection (mAES).
- score_norm (bool) – Normalize final scores by length.
- nbest (int) – Number of final hypotheses.
- streaming (bool) – Perform chunk-by-chunk beam search.
Examples
>>> beam_search = BeamSearchTransducer(decoder, joint_network, beam_size=5)
>>> hypotheses = beam_search(enc_out, is_final=True)
- Raises: NotImplementedError – If the specified search type is not supported.
NOTE: Ensure that beam_size is less than or equal to the vocabulary size of the decoder.
Construct a BeamSearchTransducer object.
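The constructor's choice of search method can be pictured with the standalone sketch below. It assumes the accepted search_type values are 'default', 'tsd', 'alsd' and 'maes', each selecting one of the search methods documented on this page; SEARCH_METHODS and resolve_search are hypothetical names, and raising ValueError for an oversized beam is this sketch's own choice rather than necessarily what the library does.

```python
from typing import Dict

# Hypothetical mapping from search_type to the documented search method names.
SEARCH_METHODS: Dict[str, str] = {
    "default": "default_beam_search",
    "tsd": "time_sync_decoding",
    "alsd": "align_length_sync_decoding",
    "maes": "modified_adaptive_expansion_search",
}


def resolve_search(search_type: str, beam_size: int, vocab_size: int) -> str:
    """Return the name of the search method for search_type (illustrative only)."""
    # The beam cannot hold more distinct labels than the vocabulary provides.
    if beam_size > vocab_size:
        raise ValueError(
            f"beam_size ({beam_size}) should be <= vocabulary size ({vocab_size})."
        )
    if search_type not in SEARCH_METHODS:
        raise NotImplementedError(
            f"Specified search type ({search_type}) is not supported."
        )
    return SEARCH_METHODS[search_type]
```

For example, resolve_search("maes", beam_size=4, vocab_size=100) returns "modified_adaptive_expansion_search", while an unknown type such as "greedy" raises NotImplementedError.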
align_length_sync_decoding(enc_out: Tensor) → List[Hypothesis]
Alignment-length synchronous beam search implementation.
This method performs a beam search in which all hypotheses in the beam share the same alignment length, i.e. the number of encoder frames consumed plus the number of labels emitted. The search is based on the algorithm described in https://ieeexplore.ieee.org/document/9053040.
- Parameters: enc_out – Encoder output sequences. Shape is (T, D), where T is the number of time steps and D is the dimension of the encoder output.
- Returns: A list of N-best hypotheses generated by the beam search.
- Return type: List[Hypothesis]
Examples
>>> beam_search_transducer = BeamSearchTransducer(...)
>>> encoder_output = torch.randn(10, 256) # Example encoder output
>>> hypotheses = beam_search_transducer.align_length_sync_decoding(encoder_output)
>>> for hyp in hypotheses:
... print(hyp.yseq, hyp.score)
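To make the alignment-length bookkeeping concrete, here is a toy, framework-free sketch of alignment-length synchronous decoding. It assumes blank has ID 0, every label sequence starts with that blank, and a hypothetical joint(t, yseq) callable returns log-probabilities over the vocabulary for frame t; there is no decoder state or LM, and hypotheses reaching the same label sequence at the same alignment length are merged by log-add. It illustrates the technique and is not ESPnet's implementation.

```python
import math
from typing import Callable, Dict, List, Tuple

Hyp = Tuple[Tuple[int, ...], float]  # (label sequence starting with blank 0, log-score)
JointFn = Callable[[int, Tuple[int, ...]], List[float]]  # hypothetical toy joint


def log_add(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def push(cands: Dict[Tuple[int, ...], float], yseq: Tuple[int, ...], score: float) -> None:
    """Add a candidate, merging it with an identical label sequence if present."""
    cands[yseq] = log_add(cands[yseq], score) if yseq in cands else score


def alsd_sketch(joint: JointFn, T: int, beam: int, u_max: int) -> List[Hyp]:
    beam_hyps: List[Hyp] = [((0,), 0.0)]
    final: List[Hyp] = []
    # i is the alignment length: frames consumed (t) + labels emitted (u).
    for i in range(T + u_max):
        cands: Dict[Tuple[int, ...], float] = {}
        for yseq, score in beam_hyps:
            u = len(yseq) - 1  # labels emitted so far (leading blank excluded)
            t = i - u          # frame this hypothesis is currently aligned to
            if t > T - 1:
                continue
            logp = joint(t, yseq)
            # Blank: move on to frame t + 1 without emitting a label.
            if t == T - 1:
                final.append((yseq, score + logp[0]))
            else:
                push(cands, yseq, score + logp[0])
            # Non-blank label k: stay on frame t and emit k (bounded by u_max).
            if u < u_max:
                for k in range(1, len(logp)):
                    push(cands, yseq + (k,), score + logp[k])
        beam_hyps = sorted(cands.items(), key=lambda h: h[1], reverse=True)[:beam]
    return sorted(final or beam_hyps, key=lambda h: h[1], reverse=True)
```

With a constant joint giving blank 0.6 and two labels 0.3 and 0.1, T = 2 and a wide beam, the best result is the empty transcript (probability 0.36), followed by the single label 1, whose two alignments of 0.108 each are merged into 0.216.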
create_lm_batch_inputs(hyps_seq: List[List[int]]) → Tensor
Make batch of inputs with left padding for LM scoring.
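This function builds a batch of hypothesis sequences for LM scoring. The first element of each sequence (the blank label that starts every hypothesis) is replaced by the start-of-sequence token, and shorter sequences are padded with zeros on the left, between the start-of-sequence token and the labels, so that all sequences in the batch have the same length.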
- Parameters: hyps_seq – A list of hypothesis sequences, where each sequence is a list of integers representing label IDs.
- Returns: A tensor containing the padded batch of sequences, of shape (batch_size, max_length), where max_length is the length of the longest sequence in the input.
- Return type: torch.Tensor
Examples
>>> hyps_seq = [[0, 1, 2, 3], [0, 4, 5], [0, 6]]  # each sequence starts with blank (0)
>>> batch_inputs = beam_search.create_lm_batch_inputs(hyps_seq)
>>> print(batch_inputs)  # assuming self.sos == 9
tensor([[9, 1, 2, 3],
        [9, 0, 4, 5],
        [9, 0, 0, 6]])
NOTE: The start-of-sequence token is defined as self.sos, and zero is used for padding.
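A plain-Python sketch of one way to build such a left-padded batch follows. It assumes every hypothesis starts with the blank label 0, which is dropped and replaced by the SOS token, and it returns nested lists instead of a torch.Tensor; lm_batch_inputs_sketch is a hypothetical name.

```python
from typing import List


def lm_batch_inputs_sketch(hyps_seq: List[List[int]], sos: int) -> List[List[int]]:
    """Left-pad label sequences for LM scoring (illustrative only)."""
    max_len = max(len(h) for h in hyps_seq)
    # Replace the leading blank with SOS, then put zero padding between SOS
    # and the labels so that every row has length max_len.
    return [[sos] + [0] * (max_len - len(h)) + h[1:] for h in hyps_seq]


print(lm_batch_inputs_sketch([[0, 1, 2, 3], [0, 4, 5], [0, 6]], sos=9))
# [[9, 1, 2, 3], [9, 0, 4, 5], [9, 0, 0, 6]]
```

Putting the padding on the left keeps the most recent label of every hypothesis in the last column, which is the position an autoregressive LM scores next.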
default_beam_search(enc_out: Tensor) → List[Hypothesis]
Beam search implementation without prefix search.
This method performs a beam search over the output of the encoder without using prefix search. It evaluates the hypotheses at each time step, expanding the most promising ones according to the beam size and the scores computed from the joint network.
Modified from: https://arxiv.org/pdf/1211.3711.pdf
- Parameters: enc_out – Encoder output sequence. Shape (T, D).
- Returns: List of N-best hypotheses sorted by their scores.
- Return type: List[Hypothesis]
Examples
>>> enc_out = torch.randn(10, 256) # Example encoder output
>>> beam_search = BeamSearchTransducer(decoder, joint_network, beam_size=5)
>>> results = beam_search.default_beam_search(enc_out)
>>> print(results) # List of Hypothesis objects with their scores and sequences.
NOTE: The hypotheses are scored based on both the decoder output and the language model (if available), and are pruned according to the beam size at each time step.
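The pop-best, expand, and stop loop of this Graves-style search can be sketched without any framework. The sketch assumes blank ID 0, a vocabulary of at least two labels, and a hypothetical joint(t, yseq) callable returning log-probabilities; it has no LM or decoder state, truncates the kept set to beam at each frame, and adds a max_iter safety cap, all of which are simplifications of this sketch.

```python
import math
from typing import Callable, List, Tuple

Hyp = Tuple[Tuple[int, ...], float]  # (label sequence starting with blank 0, log-score)
JointFn = Callable[[int, Tuple[int, ...]], List[float]]  # hypothetical toy joint


def default_search_sketch(joint: JointFn, T: int, beam: int, max_iter: int = 1000) -> List[Hyp]:
    kept: List[Hyp] = [((0,), 0.0)]
    for t in range(T):
        hyps, kept = kept, []
        for _ in range(max_iter):  # safety cap for this sketch
            best = max(hyps, key=lambda h: h[1])
            hyps.remove(best)
            yseq, score = best
            logp = joint(t, yseq)
            # Blank ends this hypothesis' expansion at frame t.
            kept.append((yseq, score + logp[0]))
            # The top-`beam` non-blank labels stay on frame t for further expansion.
            labels = sorted(range(1, len(logp)), key=lambda k: logp[k], reverse=True)
            for k in labels[:beam]:
                hyps.append((yseq + (k,), score + logp[k]))
            # Stop once `beam` kept hypotheses beat every hypothesis still open.
            open_max = max((h[1] for h in hyps), default=-math.inf)
            better = [h for h in kept if h[1] > open_max]
            if len(better) >= beam or not hyps:
                kept = better
                break
        kept = sorted(kept, key=lambda h: h[1], reverse=True)[:beam]
    return kept
```

Because this sketch does not recombine hypotheses, the same label sequence can be kept twice through different paths; with a constant joint (blank 0.6, labels 0.3 and 0.1), T = 2 and beam = 2, it returns the empty transcript at 0.36 and the label 1 at 0.108.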
modified_adaptive_expansion_search(enc_out: Tensor) → List[ExtendedHypothesis]
Modified version of Adaptive Expansion Search (mAES).
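This method implements a modified version of the Adaptive Expansion Search algorithm for beam search decoding in transducer models. At each encoder frame, hypotheses are expanded for up to nstep steps. For each hypothesis, only the top beam_size + expansion_beta candidate labels are considered, and candidates whose score falls more than expansion_gamma below the best candidate are pruned (prune-by-value).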
Based on the original Adaptive Expansion Search (AES) as described in https://ieeexplore.ieee.org/document/9250505 and the Non-Stationary Context (NSC) approach from https://arxiv.org/abs/2201.05420.
- Parameters: enc_out – Encoder output sequence. (T, D_enc)
- Returns: N-best hypotheses sorted by score.
- Return type: List[ExtendedHypothesis]
Examples
>>> enc_out = torch.rand(10, 256) # Example encoder output
>>> beam_search = BeamSearchTransducer(...)
>>> nbest_hyps = beam_search.modified_adaptive_expansion_search(enc_out)
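The following is a deliberately simplified, framework-free sketch of the adaptive-expansion idea, not a transcription of ESPnet's method. It assumes blank ID 0, a hypothetical joint(t, yseq) callable returning log-probabilities, no LM or decoder state, and that after the last of the nstep expansion steps the remaining label expansions are closed with a blank. The parameters beam, beta, gamma and nstep play the roles of beam_size, expansion_beta, expansion_gamma and nstep.

```python
import math
from typing import Callable, Dict, List, Tuple

Hyp = Tuple[Tuple[int, ...], float]  # (label sequence starting with blank 0, log-score)
JointFn = Callable[[int, Tuple[int, ...]], List[float]]  # hypothetical toy joint


def log_add(a: float, b: float) -> float:
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def merge(pool: Dict[Tuple[int, ...], float], yseq: Tuple[int, ...], score: float) -> None:
    pool[yseq] = log_add(pool[yseq], score) if yseq in pool else score


def maes_sketch(joint: JointFn, T: int, beam: int, beta: int, gamma: float, nstep: int) -> List[Hyp]:
    beam_hyps: List[Hyp] = [((0,), 0.0)]
    k_max = beam + beta  # candidates considered per hypothesis
    for t in range(T):
        kept: Dict[Tuple[int, ...], float] = {}  # hypotheses that close frame t with blank
        hyps = beam_hyps
        for n in range(nstep):
            expansions: List[Hyp] = []
            for yseq, score in hyps:
                logp = joint(t, yseq)
                top = sorted(range(len(logp)), key=lambda k: logp[k], reverse=True)[:k_max]
                best = score + logp[top[0]]
                for k in top:
                    cand = score + logp[k]
                    if cand < best - gamma:
                        continue  # prune by value
                    if k == 0:
                        merge(kept, yseq, cand)
                    else:
                        expansions.append((yseq + (k,), cand))
            expansions = sorted(expansions, key=lambda h: h[1], reverse=True)[:beam]
            if not expansions:
                break
            if n == nstep - 1:
                # Last step: close the remaining expansions with a blank.
                for yseq, score in expansions:
                    merge(kept, yseq, score + joint(t, yseq)[0])
            else:
                hyps = expansions
        beam_hyps = sorted(kept.items(), key=lambda h: h[1], reverse=True)[:beam]
    return beam_hyps
```

A small gamma prunes aggressively: with blank at 0.6 and gamma = 0.1, no label survives the first step and only the empty transcript remains, whereas gamma = 2.3 lets label expansions through.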
recombine_hyps(hyps: List[Hypothesis]) → List[Hypothesis]
Recombine hypotheses with same label ID sequence.
This method merges hypotheses that share the same label ID sequence into a single hypothesis. The merged score is the log of the summed probabilities, log(exp(s1) + exp(s2)), computed with a log-add (log-sum-exp) operation to avoid numerical underflow.
- Parameters: hyps – A list of Hypothesis objects to recombine.
- Returns: A list of recombined Hypothesis objects, where each unique label ID sequence is represented by a single Hypothesis with its score adjusted accordingly.
- Return type: List[Hypothesis]
Examples
>>> hyps = [
... Hypothesis(score=1.0, yseq=[1, 2]),
... Hypothesis(score=0.5, yseq=[1, 2]),
... Hypothesis(score=2.0, yseq=[3, 4]),
... ]
>>> recombined = beam_search.recombine_hyps(hyps)
>>> for hyp in recombined:
...     print(hyp.yseq, round(float(hyp.score), 4))
[1, 2] 1.4741
[3, 4] 2.0
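The recombination rule can be reproduced with plain (yseq, log-score) pairs. This sketch assumes hypotheses are represented as tuples rather than Hypothesis objects; recombine_sketch is a hypothetical name, and the max-shifted log-add is one standard way to compute log(exp(a) + exp(b)) stably.

```python
import math
from typing import Dict, List, Tuple


def recombine_sketch(hyps: List[Tuple[List[int], float]]) -> List[Tuple[List[int], float]]:
    """Merge (yseq, log-score) pairs that share a label sequence (illustrative)."""
    merged: Dict[Tuple[int, ...], float] = {}
    for yseq, score in hyps:
        key = tuple(yseq)  # lists are unhashable, so key on a tuple
        if key in merged:
            # log(exp(a) + exp(b)), shifted by the max to avoid underflow
            a, b = merged[key], score
            m = max(a, b)
            merged[key] = m + math.log(math.exp(a - m) + math.exp(b - m))
        else:
            merged[key] = score
    # dicts keep insertion order, so first occurrences keep their position
    return [(list(key), score) for key, score in merged.items()]


for yseq, score in recombine_sketch([([1, 2], 1.0), ([1, 2], 0.5), ([3, 4], 2.0)]):
    print(yseq, round(score, 4))
# [1, 2] 1.4741
# [3, 4] 2.0
```

Note that the merged score is larger than either input score, since the probabilities of the two paths are added.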
reset_cache() → None
Reset cache for streaming decoding.
This method clears the score cache in the decoder and resets the search cache. It is particularly useful in scenarios where multiple decoding chunks are processed in a streaming manner, ensuring that previous state information does not interfere with subsequent decoding steps.
score_cache
A dictionary used by the decoder to cache scores.
- Type: dict
search_cache
A placeholder for caching hypotheses during the search process.
- Type: None
Examples
>>> beam_search = BeamSearchTransducer(...)
>>> beam_search.reset_cache() # Resets caches before new decoding.
NOTE: This method is automatically called at the end of a decoding pass if is_final is set to True in the __call__ method.
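The interplay between chunk-wise calls and reset_cache can be illustrated with a small stand-in class. It assumes the search cache holds the previous chunk's hypotheses and that a final call sorts, truncates to nbest and then clears the caches; StreamingSearchSketch and toy_search are hypothetical, and the score cache here lives on the wrapper rather than on a real decoder.

```python
from typing import Callable, List, Optional, Tuple

Hyp = Tuple[Tuple[int, ...], float]  # (label sequence, log-score)
SearchFn = Callable[[List[int], Optional[List[Hyp]]], List[Hyp]]


class StreamingSearchSketch:
    """Hypothetical wrapper showing when the caches are kept and cleared."""

    def __init__(self, search_fn: SearchFn, nbest: int = 1) -> None:
        self.search_fn = search_fn
        self.nbest = nbest
        self.score_cache: dict = {}
        self.search_cache: Optional[List[Hyp]] = None

    def reset_cache(self) -> None:
        self.score_cache = {}
        self.search_cache = None

    def __call__(self, chunk: List[int], is_final: bool = True) -> List[Hyp]:
        # Resume from the hypotheses cached after the previous chunk, if any.
        hyps = self.search_fn(chunk, self.search_cache)
        if is_final:
            self.reset_cache()  # the next call starts a fresh utterance
            return sorted(hyps, key=lambda h: h[1], reverse=True)[: self.nbest]
        self.search_cache = hyps  # carry hypotheses over to the next chunk
        return hyps


def toy_search(chunk: List[int], cache: Optional[List[Hyp]]) -> List[Hyp]:
    """Toy search: append the chunk's 'labels' and charge 1.0 per label."""
    start = cache if cache is not None else [((0,), 0.0)]
    return [(yseq + tuple(chunk), score - len(chunk)) for yseq, score in start]


search = StreamingSearchSketch(toy_search)
search([1, 2], is_final=False)
print(search([3], is_final=True))  # [((0, 1, 2, 3), -3.0)]
print(search.search_cache)         # None
```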
select_k_expansions(hyps: List[ExtendedHypothesis], topk_idx: Tensor, topk_logp: Tensor) → List[ExtendedHypothesis]
Return K hypotheses candidates for expansion from a list of hypotheses.
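K candidates are selected according to the extended hypotheses' probabilities and a prune-by-value method, where K is equal to beam_size + expansion_beta: a candidate is kept only if its score is within expansion_gamma of the best candidate for the same hypothesis.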
- Parameters:
- hyps – List of extended hypotheses to select from.
- topk_idx – Label indices of the top-K candidates for each hypothesis.
- topk_logp – Log-probabilities of the top-K candidates for each hypothesis.
- Returns: List of the best K expansion hypotheses candidates.
- Return type: List[ExtendedHypothesis]
Examples
>>> hyps = [ExtendedHypothesis(yseq=[0], score=1.0),
... ExtendedHypothesis(yseq=[1], score=0.8)]
>>> topk_idx = torch.tensor([[0, 1], [0, 1]])
>>> topk_logp = torch.tensor([[0.5, 0.3], [0.4, 0.2]])
>>> k_expansions = beam_search.select_k_expansions(hyps, topk_idx, topk_logp)
>>> len(k_expansions)  # one entry per input hypothesis
2
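The prune-by-value rule can be sketched over plain numbers. This sketch assumes each hypothesis contributes its top-K labels (K = beam_size + expansion_beta) and that a candidate survives if its extended score is within gamma of that hypothesis's best candidate; returning (label, score) pairs per hypothesis and the name select_k_sketch are this sketch's own choices.

```python
from typing import List, Sequence, Tuple


def select_k_sketch(
    hyp_scores: Sequence[float],
    topk_idx: Sequence[Sequence[int]],
    topk_logp: Sequence[Sequence[float]],
    gamma: float,
) -> List[List[Tuple[int, float]]]:
    """Prune-by-value over each hypothesis' top-K labels (illustrative)."""
    k_expansions = []
    for score, idx_row, logp_row in zip(hyp_scores, topk_idx, topk_logp):
        # Score of extending this hypothesis with each candidate label.
        cands = [(int(k), score + float(v)) for k, v in zip(idx_row, logp_row)]
        best = max(c[1] for c in cands)
        # Keep candidates within gamma of the best one, best first.
        kept = [c for c in cands if c[1] >= best - gamma]
        k_expansions.append(sorted(kept, key=lambda c: c[1], reverse=True))
    return k_expansions


result = select_k_sketch(
    [0.0, -1.0],
    [[0, 1, 2], [0, 1, 2]],
    [[-0.1, -1.0, -5.0], [-0.5, -0.7, -3.0]],
    gamma=2.3,
)
print([[k for k, _ in row] for row in result])  # [[0, 1], [0, 1]]
```

Label 2 is dropped for both hypotheses because its extended score falls more than 2.3 below the best candidate of that hypothesis.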
sort_nbest(hyps: List[Hypothesis]) → List[Hypothesis]
Sort hypotheses in place by score, or by length-normalized score when score_norm is True.
This method sorts a list of hypotheses based on their scores. If score_norm is set to True, it normalizes the scores by the length of the corresponding label sequences. The sorted list will contain only the top nbest hypotheses.
- Parameters: hyps – A list of Hypothesis instances to be sorted.
- Returns: A sorted list of hypotheses, containing only the top nbest hypotheses based on their scores.
- Return type: List[Hypothesis]
Examples
>>> hyp1 = Hypothesis(score=10.0, yseq=[1, 2])
>>> hyp2 = Hypothesis(score=15.0, yseq=[1, 3])
>>> hyp3 = Hypothesis(score=5.0, yseq=[2, 3])
>>> sorted_hyps = beam_search.sort_nbest([hyp1, hyp2, hyp3])
>>> sorted_hyps[0].score # Should return 15.0
15.0
NOTE: The sorting is done in-place, meaning the original list hyps will be modified.
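The effect of score_norm can be seen with two hypotheses of different lengths. This sketch uses (yseq, log-score) tuples instead of Hypothesis objects, and sort_nbest_sketch is a hypothetical name; it sorts in place and returns the top nbest entries, as described above.

```python
from typing import List, Tuple

Hyp = Tuple[List[int], float]  # (yseq, log-score)


def sort_nbest_sketch(hyps: List[Hyp], nbest: int, score_norm: bool = False) -> List[Hyp]:
    """Sort hyps in place (best first) and return the top nbest (illustrative)."""
    if score_norm:
        # Divide by sequence length so longer hypotheses are not penalised
        # merely for accumulating more negative log-probabilities.
        hyps.sort(key=lambda h: h[1] / len(h[0]), reverse=True)
    else:
        hyps.sort(key=lambda h: h[1], reverse=True)
    return hyps[:nbest]


short, long_ = ([0, 1], -2.0), ([0, 1, 2, 3], -3.0)
print(sort_nbest_sketch([short, long_], nbest=1))                   # [([0, 1], -2.0)]
print(sort_nbest_sketch([short, long_], nbest=1, score_norm=True))  # [([0, 1, 2, 3], -3.0)]
```

Without normalization the shorter hypothesis wins (-2.0 > -3.0); with normalization the longer one wins (-0.75 per label versus -1.0).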
time_sync_decoding(enc_out: Tensor) → List[Hypothesis]
Time synchronous beam search implementation.
This method implements a time-synchronous beam search: the search advances one encoder frame at a time, and within each frame every hypothesis can be expanded with further labels, up to a limit set by max_sym_exp, before the frame is closed with a blank. Hypotheses that close a frame with the same label sequence are merged. The N-best hypotheses are generated from the joint network's log-probabilities.
- Parameters: enc_out – Encoder output sequence. (T, D)
- Returns: N-best hypotheses, sorted by their scores.
- Return type: List[Hypothesis]
Examples
>>> decoder = ...  # an instance of a concrete AbsDecoder subclass
>>> joint_network = JointNetwork(...)
>>> beam_search = BeamSearchTransducer(
...     decoder, joint_network, beam_size=5, search_type="tsd"
... )
>>> enc_out = torch.randn(10, 256)  # (T, D); D must match the joint network's encoder input size
>>> nbest_hyps = beam_search.time_sync_decoding(enc_out)
>>> for hyp in nbest_hyps:
...     print(hyp.yseq, hyp.score)
NOTE: The method can utilize a language model if one is provided during the initialization of the BeamSearchTransducer.
- Raises: RuntimeError – If the input tensor dimensions do not match expected shapes.
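A toy, framework-free sketch of time-synchronous decoding follows. It assumes blank ID 0, a hypothetical joint(t, yseq) callable returning log-probabilities, and no LM or decoder state; it reads the max_sym_exp limit as allowing max_sym_exp - 1 non-blank labels per frame, and merges hypotheses that close the same frame with identical label sequences by log-add. These are assumptions of the sketch, not a statement of ESPnet's exact behaviour.

```python
import math
from typing import Callable, Dict, List, Tuple

Hyp = Tuple[Tuple[int, ...], float]  # (label sequence starting with blank 0, log-score)
JointFn = Callable[[int, Tuple[int, ...]], List[float]]  # hypothetical toy joint


def log_add(a: float, b: float) -> float:
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def tsd_sketch(joint: JointFn, T: int, beam: int, max_sym_exp: int) -> List[Hyp]:
    beam_hyps: List[Hyp] = [((0,), 0.0)]
    for t in range(T):
        closed: Dict[Tuple[int, ...], float] = {}  # hypotheses that end frame t with blank
        active = beam_hyps
        for v in range(max_sym_exp):
            expanded: List[Hyp] = []
            for yseq, score in active:
                logp = joint(t, yseq)
                s_blank = score + logp[0]
                # Blank closes the frame; identical label sequences are merged.
                closed[yseq] = log_add(closed[yseq], s_blank) if yseq in closed else s_blank
                if v < max_sym_exp - 1:
                    # Emit one more non-blank label within the same frame.
                    labels = sorted(range(1, len(logp)), key=lambda k: logp[k], reverse=True)
                    expanded.extend((yseq + (k,), score + logp[k]) for k in labels[:beam])
            active = sorted(expanded, key=lambda h: h[1], reverse=True)[:beam]
        beam_hyps = sorted(closed.items(), key=lambda h: h[1], reverse=True)[:beam]
    return beam_hyps
```

With a constant joint (blank 0.6, labels 0.3 and 0.1), beam = 2, max_sym_exp = 2 and T = 2, the result is the empty transcript at 0.36 followed by the label 1 at 0.216, the latter combining two alignments of 0.108.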