espnet2.speechlm.net_utils.logits_to_tokens
espnet2.speechlm.net_utils.logits_to_tokens(logits: Tensor, opts: SpeechLMInferenceOptions, allow_eos: bool = True, nq_level: int | None = None)
Converts logits to tokens based on the specified inference options and search algorithm.
This function takes a tensor of logits and applies a mask based on the inference options. It then selects tokens from the logits using the specified search algorithm, which can be either sampling or greedy search.
- Parameters:
- logits (torch.Tensor) – A 4D tensor of shape (batch_size, num_heads, sequence_length, vocab_size) representing the model’s output logits.
- opts (SpeechLMInferenceOptions) – An object containing inference options, including masks, top_k, eos token index, and search algorithm.
- allow_eos (bool, optional) – Whether to allow the End Of Sequence (EOS) token to be predicted. Defaults to True.
- nq_level (int, optional) – The specific level for token generation. If None, it uses the default level. Defaults to None.
- Returns: A tuple containing:
  - gen_token_idx (torch.Tensor): A tensor of generated token indices of shape (1, batch_size, num_heads, sequence_length).
  - gen_token_score (torch.Tensor): A tensor of scores corresponding to the generated tokens, of shape (1, batch_size, num_heads, sequence_length).
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises: NotImplementedError – If the specified search algorithm is not implemented.
Examples
>>> import torch
>>> from espnet2.speechlm.core_lm import SpeechLMInferenceOptions
>>> logits = torch.randn(2, 4, 10, 100) # Example logits
>>> opts = SpeechLMInferenceOptions(masks=torch.ones(1, 1, 10, 100),
... eos=99, top_k=5,
... search_algo='greedy_search')
>>> gen_token_idx, gen_token_score = logits_to_tokens(logits, opts)
>>> print(gen_token_idx.shape) # Output: (1, 2, 4, 10)
>>> print(gen_token_score.shape) # Output: (1, 2, 4, 10)
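The greedy branch described above can be sketched without the ESPnet dependency. The following is a minimal sketch, not the actual implementation: the mask convention (`True` marks positions to suppress), the log-softmax scoring, and the leading dimension of 1 on the outputs are assumptions chosen to match the shapes documented in the Returns section.

```python
import torch


def greedy_tokens(
    logits: torch.Tensor,
    mask: torch.Tensor,
    allow_eos: bool = True,
    eos: int = 0,
):
    """Sketch of masked greedy token selection.

    logits: (batch_size, num_heads, sequence_length, vocab_size)
    mask:   boolean tensor broadcastable to logits; True suppresses a token
            (an assumed convention, not necessarily ESPnet's).
    """
    # Suppress masked vocabulary entries before selection.
    logits = logits.masked_fill(mask, float("-inf"))
    # Optionally forbid the EOS token, mirroring allow_eos=False.
    if not allow_eos:
        logits = logits.clone()
        logits[..., eos] = float("-inf")
    # Score with log-probabilities, then pick the argmax per position.
    log_probs = logits.log_softmax(dim=-1)
    scores, idx = log_probs.max(dim=-1)
    # Prepend a dimension of size 1 to match the documented output shapes.
    return idx.unsqueeze(0), scores.unsqueeze(0)


logits = torch.randn(2, 4, 10, 100)
mask = torch.zeros(1, 1, 10, 100, dtype=torch.bool)
gen_idx, gen_score = greedy_tokens(logits, mask, allow_eos=False, eos=99)
print(gen_idx.shape)    # torch.Size([1, 2, 4, 10])
print(gen_score.shape)  # torch.Size([1, 2, 4, 10])
```

With `allow_eos=False`, index 99 can never win the argmax because its logit is set to negative infinity, which is how the real function prevents premature sequence termination.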