espnet2.speechlm.net_utils.logits_to_tokens
espnet2.speechlm.net_utils.logits_to_tokens(logits: Tensor, opts: SpeechLMInferenceOptions, allow_eos: bool = True, nq_level: int | None = None)
Converts logits to tokens based on the specified inference options and search algorithm.
This function takes a tensor of logits and applies a mask based on the inference options. It then selects tokens from the logits using the specified search algorithm, which can be either sampling or greedy search.
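The masking and selection steps described above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not ESPnet's implementation. `ToyOptions` and `toy_logits_to_tokens` are hypothetical names. The mask is assumed to be a boolean tensor in which True marks a forbidden token. A top-k cut, mirroring the `top_k` option, is assumed to happen before either greedy selection or sampling.

```python
import torch
from dataclasses import dataclass


@dataclass
class ToyOptions:
    """Hypothetical stand-in for SpeechLMInferenceOptions (fields assumed)."""

    masks: torch.Tensor  # bool, broadcastable to logits; True = forbidden token
    top_k: int
    search_algo: str  # "greedy_search" or "sampling"


def toy_logits_to_tokens(logits: torch.Tensor, opts: ToyOptions):
    """Mask the logits, keep the top-k candidates, then pick one token each."""
    # (1) Forbidden tokens get -inf so they can never be selected.
    logits = logits.masked_fill(opts.masks, float("-inf"))
    # (2) Keep the k best candidates along the vocabulary (last) axis.
    topk_values, topk_indices = torch.topk(logits, opts.top_k, dim=-1)
    lead_shape = tuple(topk_values.shape[:-1])
    if opts.search_algo == "greedy_search":
        # torch.topk sorts in descending order, so position 0 is the argmax.
        inner = torch.zeros(lead_shape + (1,), dtype=torch.long, device=logits.device)
    elif opts.search_algo == "sampling":
        probs = torch.softmax(topk_values, dim=-1)
        # torch.multinomial accepts 1D/2D input, so flatten the leading axes.
        inner = torch.multinomial(probs.reshape(-1, opts.top_k), num_samples=1)
        inner = inner.reshape(lead_shape + (1,))
    else:
        raise NotImplementedError(f"unsupported search_algo: {opts.search_algo}")
    # (3) Map positions within the top-k back to vocabulary ids and scores.
    gen_token_idx = torch.gather(topk_indices, -1, inner).squeeze(-1)
    gen_token_score = torch.gather(topk_values, -1, inner).squeeze(-1)
    return gen_token_idx, gen_token_score
```

Greedy search always takes the highest-scoring unmasked candidate. Sampling draws from the softmax over the k best logits, so `top_k` limits how many candidates can be sampled. This sketch returns tensors shaped like `logits` without the vocabulary axis. It does not add a leading dimension.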
- Parameters:
  - logits (torch.Tensor) – A 4D tensor of shape (batch_size, num_heads, sequence_length, vocab_size) containing the model’s output logits.
  - opts (SpeechLMInferenceOptions) – Inference options, including the masks, top_k, the EOS token index, and the search algorithm.
  - allow_eos (bool, optional) – Whether the End Of Sequence (EOS) token may be predicted. Defaults to True.
  - nq_level (int, optional) – The specific level at which tokens are generated. If None, the default level is used. Defaults to None.

- Returns: A tuple containing:
  - gen_token_idx (torch.Tensor): The generated token indices, of shape (1, batch_size, num_heads, sequence_length).
  - gen_token_score (torch.Tensor): The scores of the generated tokens, of shape (1, batch_size, num_heads, sequence_length).
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises: NotImplementedError – If the specified search algorithm is not implemented.
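The `allow_eos` and `nq_level` parameters both change which tokens the mask permits. One plausible way to derive that mask is sketched below, but ESPnet's actual logic may differ. `select_mask` is a hypothetical helper that relies on two assumptions. First, `masks` holds one boolean row per level, with True meaning the token is forbidden. Second, disallowing EOS means masking the `eos` column.

```python
from typing import Optional

import torch


def select_mask(
    masks: torch.Tensor,
    eos: int,
    allow_eos: bool = True,
    nq_level: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical helper: choose per-level mask rows and optionally forbid EOS.

    Assumes `masks` is a bool tensor of shape (num_levels, vocab_size) where
    True marks a token that may not be generated at that level.
    """
    if nq_level is not None:
        # Slice (not index) so a length-1 level axis remains for broadcasting.
        mask = masks[nq_level : nq_level + 1]
    else:
        mask = masks
    # Clone so the caller's mask is never modified in place.
    mask = mask.clone()
    if not allow_eos:
        mask[:, eos] = True  # assumed: forbidding EOS = masking its column
    return mask
```

Cloning before editing matters because the same options object is typically reused across decoding steps, and an in-place write would silently forbid EOS on later steps as well.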
 
Examples
>>> import torch
>>> from espnet2.speechlm.core_lm import SpeechLMInferenceOptions
>>> from espnet2.speechlm.net_utils import logits_to_tokens
>>> logits = torch.randn(2, 4, 10, 100)  # Example logits
>>> opts = SpeechLMInferenceOptions(masks=torch.ones(1, 1, 10, 100),
...                                  eos=99, top_k=5,
...                                  search_algo='greedy_search')
>>> gen_token_idx, gen_token_score = logits_to_tokens(logits, opts)
>>> print(gen_token_idx.shape)  # Output: (1, 2, 4, 10)
>>> print(gen_token_score.shape)  # Output: (1, 2, 4, 10)