espnet2.lm.espnet_model_multitask.ESPnetMultitaskLanguageModel
class espnet2.lm.espnet_model_multitask.ESPnetMultitaskLanguageModel(lm: AbsLM, vocab_size: int, token_list: Tuple[str, ...] | List[str], ignore_id: int = 0, lsm_weight: float = 0.0, length_normalized_loss: bool = False, sos_syms: List[str] = ['<generatetext>', '<generatespeech>'], eos_sym: str = '<sos/eos>')
Bases: AbsESPnetModel
ESPnetMultitaskLanguageModel is a multitask language model that wraps a provided language model instance (lm). It computes the negative log likelihood (NLL) of input text, and an attention loss with accuracy for language modeling tasks.
lm
An instance of a language model.
- Type: AbsLM
sos_ids
List of start-of-sequence token indices.
- Type: List[int]
eos_id
End-of-sequence token index.
- Type: int
ignore_id
Token index to ignore during loss computation.
- Type: int
token_list
List of tokens in the vocabulary.
- Type: List[str]
criterion_att
Loss function for attention loss.
- Type: LabelSmoothingLoss
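The sos_ids and eos_id attributes are token indices, while the constructor takes the symbols sos_syms and eos_sym. A minimal sketch of how such indices could be derived from token_list follows. It assumes a plain position lookup; the helper name resolve_special_ids and the toy vocabulary are hypothetical and not taken from the library.

```python
from typing import List, Sequence, Tuple


def resolve_special_ids(
    token_list: Sequence[str], sos_syms: Sequence[str], eos_sym: str
) -> Tuple[List[int], int]:
    """Map special symbols to their positions in the vocabulary (hypothetical helper)."""
    # list.index / tuple.index raise ValueError if a symbol is missing,
    # which surfaces a mismatched vocabulary early.
    sos_ids = [token_list.index(sym) for sym in sos_syms]
    eos_id = token_list.index(eos_sym)
    return sos_ids, eos_id


# Toy vocabulary, made up for illustration only.
token_list = [
    "<blank>", "<unk>", "hello", "world",
    "<generatetext>", "<generatespeech>", "<sos/eos>",
]
sos_ids, eos_id = resolve_special_ids(
    token_list, ["<generatetext>", "<generatespeech>"], "<sos/eos>"
)
print(sos_ids, eos_id)
```

With this toy vocabulary the two start symbols resolve to positions 4 and 5 and the end symbol to position 6.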
Parameters:
- lm (AbsLM) – An instance of the language model.
- vocab_size (int) – The size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens in the vocabulary.
- ignore_id (int, optional) – Token index to ignore during loss computation. Defaults to 0.
- lsm_weight (float, optional) – Weight for label smoothing. Defaults to 0.0.
- length_normalized_loss (bool, optional) – Whether to normalize the loss by length. Defaults to False.
- sos_syms (List[str], optional) – List of start-of-sequence symbols. Defaults to ['<generatetext>', '<generatespeech>'].
- eos_sym (str, optional) – End-of-sequence symbol. Defaults to '<sos/eos>'.
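To illustrate what lsm_weight controls, the sketch below builds a label-smoothed target distribution in plain Python. It assumes the common formulation in which the reference token keeps probability 1 - lsm_weight and the remainder is spread evenly over the other vocabulary entries. The helper name smoothed_targets is hypothetical, and the actual LabelSmoothingLoss may differ in detail.

```python
from typing import List


def smoothed_targets(target: int, vocab_size: int, lsm_weight: float) -> List[float]:
    """Return a smoothed one-hot distribution over the vocabulary (hypothetical helper)."""
    if vocab_size < 2:
        raise ValueError("vocab_size must be at least 2")
    # Probability mass given to every non-reference token.
    off_value = lsm_weight / (vocab_size - 1)
    dist = [off_value] * vocab_size
    # The reference token keeps the remaining mass.
    dist[target] = 1.0 - lsm_weight
    return dist


dist = smoothed_targets(target=2, vocab_size=5, lsm_weight=0.1)
print(dist)
```

With lsm_weight left at 0.0, the distribution in this sketch reduces to an ordinary one-hot target.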
nll(text: torch.Tensor, text_lengths: torch.Tensor, max_length: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: Computes the negative log likelihood (NLL) for the given input text.
batchify_nll(text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]: Computes NLL in batches to avoid out-of-memory (OOM) errors.
_calc_att_loss(text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: Calculates the attention loss and accuracy based on the input text.
forward(text: torch.Tensor, text_lengths: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: Performs a forward pass through the model, computing loss and accuracy.
collect_feats(text: torch.Tensor, text_lengths: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]: Collects features from the input text (currently returns an empty dict).
Examples:
>>> import torch
>>> model = ESPnetMultitaskLanguageModel(lm, vocab_size, token_list)
>>> text_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> text_lengths = torch.tensor([3, 3])
>>> nll, lengths = model.nll(text_tensor, text_lengths)
>>> print(nll.shape)  # shape is (Batch, Length)
NOTE
The nll method is specifically designed for calculating perplexity.
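Since nll exists to support perplexity, a minimal sketch of that final step may help. It works on plain Python lists rather than tensors and assumes the usual definition, the exponential of the total NLL divided by the total number of scored tokens. The helper name perplexity is hypothetical and not part of the class.

```python
import math
from typing import List


def perplexity(nll: List[List[float]], lengths: List[int]) -> float:
    """Corpus-level perplexity from per-token NLL values (hypothetical helper).

    Only the first lengths[i] entries of row i are counted, so padded
    positions never contribute.
    """
    total_nll = sum(sum(row[:n]) for row, n in zip(nll, lengths))
    total_tokens = sum(lengths)
    if total_tokens == 0:
        raise ValueError("no tokens to score")
    return math.exp(total_nll / total_tokens)


# Every real token has probability 1/2, so its NLL is log(2).
nll = [[math.log(2.0), math.log(2.0), 0.0], [math.log(2.0), 0.0, 0.0]]
lengths = [2, 1]
ppl = perplexity(nll, lengths)
print(ppl)
```

In this toy input each scored token has probability one half, so the perplexity comes out at 2 up to floating-point rounding.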
Initialize internal Module state, shared by both nn.Module and ScriptModule.
batchify_nll(text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100) -> Tuple[torch.Tensor, torch.Tensor]
Compute the negative log likelihood (NLL) from a transformer language model.
To avoid out-of-memory (OOM) errors, this function splits the input into batches of batch_size samples, calls the nll method on each batch, then combines and returns the results.
- Parameters:
- text – A tensor of shape (Batch, Length) representing the input sequences.
- text_lengths – A tensor of shape (Batch,) representing the lengths of each input sequence.
- batch_size – An integer specifying the number of samples in each batch when computing nll. You may change this value to avoid OOM errors or to increase throughput.
- Returns: A tuple (nll, x_lengths):
- nll: A tensor of shape (Total, Length) representing the negative log likelihood for each input sequence.
- x_lengths: A tensor of shape (Total,) representing the lengths of the input sequences after processing.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples:
>>> model = ESPnetMultitaskLanguageModel(...)
>>> text = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> text_lengths = torch.tensor([3, 3])
>>> nll, lengths = model.batchify_nll(text, text_lengths, batch_size=2)
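The split, score, and combine pattern described for batchify_nll can be sketched without tensors. The version below uses plain lists and a stand-in scorer. The names batchify and dummy_score are hypothetical, and the real method operates on torch tensors and calls nll internally.

```python
from typing import Callable, List, Tuple

Rows = List[List[float]]
Scorer = Callable[[List[List[int]], List[int]], Tuple[Rows, List[int]]]


def batchify(
    text: List[List[int]],
    text_lengths: List[int],
    score_fn: Scorer,
    batch_size: int = 100,
) -> Tuple[Rows, List[int]]:
    """Score the input in chunks of batch_size and concatenate the results."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    all_nll: Rows = []
    all_lengths: List[int] = []
    for start in range(0, len(text), batch_size):
        end = start + batch_size  # slicing clamps the final, shorter chunk
        nll, lengths = score_fn(text[start:end], text_lengths[start:end])
        all_nll.extend(nll)
        all_lengths.extend(lengths)
    return all_nll, all_lengths


def dummy_score(batch: List[List[int]], lengths: List[int]) -> Tuple[Rows, List[int]]:
    """Stand-in scorer: 1.0 for every real token, 0.0 for padding."""
    width = len(batch[0])
    rows = [[1.0 if j < n else 0.0 for j in range(width)] for n in lengths]
    return rows, list(lengths)


text = [[1, 2, 3], [4, 5, 0], [6, 0, 0]]
text_lengths = [3, 2, 1]
nll, lengths = batchify(text, text_lengths, dummy_score, batch_size=2)
print(len(nll), lengths)
```

A smaller batch_size lowers peak memory at the cost of more calls to the scorer; the combined output keeps the original sample order.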
collect_feats(text: torch.Tensor, text_lengths: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]
Collect features from the input text for further processing.
This method is currently a placeholder and does not perform any operations on the input text. It is intended to be implemented in the future to extract features from the provided input.
- Parameters:
- text – A tensor of shape (Batch, Length) representing the input text.
- text_lengths – A tensor of shape (Batch,) indicating the lengths of each input sequence.
- Returns: A dictionary containing extracted features as tensors. Currently, this method returns an empty dictionary.
Examples:
>>> model = ESPnetMultitaskLanguageModel(...)
>>> features = model.collect_feats(text_tensor, text_lengths_tensor)
>>> print(features) # Output: {}
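forward(text: torch.Tensor, text_lengths: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]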
Forward pass for the ESPnet multitask language model.
This method computes the attention loss and accuracy for the input text and its corresponding lengths. It also ensures that the results are compatible with DataParallel by using force_gatherable.
- Parameters:
- text (torch.Tensor) – Input tensor of shape (Batch, Length) containing the text data.
- text_lengths (torch.Tensor) – Tensor of shape (Batch,) indicating the lengths of each input sequence in the batch.
- **kwargs – Additional keyword arguments.
- Returns:
- loss (torch.Tensor): The computed loss value.
- stats (Dict[str, torch.Tensor]): A dictionary containing statistics such as accuracy.
- weight (torch.Tensor): The batch size for the input.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Examples:
>>> model = ESPnetMultitaskLanguageModel(...)
>>> text = torch.tensor([[1, 2, 3], [1, 2, 0]])
>>> text_lengths = torch.tensor([3, 2])
>>> loss, stats, weight = model.forward(text, text_lengths)
>>> print(loss)
>>> print(stats['acc'])
NOTE
The method uses an internal _calc_att_loss function to compute the loss and accuracy based on the input text.
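As a rough illustration of how an accuracy statistic can skip padded positions, the sketch below counts correct predictions only where the target differs from ignore_id. It assumes accuracy is masked the same way as the loss, which this page does not state explicitly. The helper name masked_accuracy and the toy predictions are hypothetical.

```python
from typing import List


def masked_accuracy(
    predictions: List[List[int]], targets: List[List[int]], ignore_id: int = 0
) -> float:
    """Token accuracy over positions whose target is not ignore_id (hypothetical helper)."""
    correct = 0
    counted = 0
    for pred_row, target_row in zip(predictions, targets):
        for pred, target in zip(pred_row, target_row):
            if target == ignore_id:
                continue  # padded position, excluded from the statistic
            counted += 1
            correct += int(pred == target)
    return correct / counted if counted else 0.0


targets = [[1, 2, 3], [1, 2, 0]]
predictions = [[1, 2, 4], [1, 5, 0]]
acc = masked_accuracy(predictions, targets, ignore_id=0)
print(acc)
```

Here five positions are scored and three match, so the padded final position of the second sequence does not inflate the result.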
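nll(text: torch.Tensor, text_lengths: torch.Tensor, max_length: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]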
Compute negative log likelihood (nll).
NOTE(yifan): We only use nll to calculate perplexity, so there is no condition in each sentence.
This function is typically called in batchify_nll.
- Parameters:
- text – A tensor of shape (Batch, Length) representing input text.
- text_lengths – A tensor of shape (Batch,) representing the lengths of each sequence in the batch.
- max_length – An optional integer specifying the maximum length of the sequences to consider. If None, the maximum length will be determined from text_lengths.
- Returns: A tuple containing:
- A tensor of shape (Batch, Length) representing the negative log likelihood for each sequence.
- A tensor of shape (Batch,) representing the adjusted lengths of the sequences after processing.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples:
>>> model = ESPnetMultitaskLanguageModel(...)
>>> text = torch.tensor([[1, 2, 3], [1, 2, 0]])
>>> text_lengths = torch.tensor([3, 2])
>>> nll, lengths = model.nll(text, text_lengths)
>>> print(nll.shape) # Output: torch.Size([2, 3])
>>> print(lengths) # Output: tensor([2, 1])
- Raises: NotImplementedError – If max_length is not None and the corresponding functionality has not been implemented.