espnet2.speechlm.espnet_model.ESPnetSpeechLMModel
class espnet2.speechlm.espnet_model.ESPnetSpeechLMModel(corelm: AbsCoreLM, extract_feats_in_collect_stats: bool = False)
Bases: AbsESPnetModel
ESPnetSpeechLMModel is a speech language model that wraps a core language model (CoreLM), which performs the actual generation and evaluation of sequences. It subclasses AbsESPnetModel and exposes a forward method that processes input sequences.
Attributes:
- corelm (AbsCoreLM) – The core language model that performs the actual sequence processing.
- extract_feats_in_collect_stats (bool) – Whether to extract features while collecting statistics.
Parameters:
- corelm (AbsCoreLM) – An instance of the core language model to be used.
- extract_feats_in_collect_stats (bool, optional) – Whether to extract features in the collect-statistics phase. Defaults to False.
Returns: A tuple of the loss, statistics, and weight computed by the forward pass.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
Raises: NotImplementedError – If collect_feats is called, because it is not implemented in this class.
######### Examples
>>> import torch
>>> # corelm: an AbsCoreLM instance configured beforehand (not shown)
>>> model = ESPnetSpeechLMModel(corelm)
>>> dec_seq = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> dec_seq_lengths = torch.tensor([3, 3])
>>> loss, stats, weight = model(dec_seq, dec_seq_lengths)
####### NOTE
This model is designed for speech language modeling tasks within the ESPnet framework.
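The wrapper pattern described above (a thin model that hands all sequence processing to a core LM) can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not ESPnet's implementation: `ToyCoreLM` and `SpeechLMWrapperSketch` are hypothetical names, the argument order used when calling the core LM is an assumption, and the toy core LM simply computes a masked next-token cross-entropy so the delegation is runnable. It assumes `dec_seq` holds integer token IDs of shape (B, T_dec).

```python
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyCoreLM(nn.Module):
    """Hypothetical stand-in for an AbsCoreLM: embedding + linear next-token predictor."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.out = nn.Linear(hidden, vocab_size)

    def forward(
        self,
        dec_seq: torch.Tensor,
        dec_seq_lengths: torch.Tensor,
        enc_seq: Optional[torch.Tensor] = None,
        enc_seq_lengths: Optional[torch.Tensor] = None,
        prefix_len: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # Predict token t+1 from token t (this toy ignores enc_seq and prefix_len).
        logits = self.out(self.emb(dec_seq[:, :-1]))  # (B, T_dec-1, V)
        targets = dec_seq[:, 1:]  # (B, T_dec-1)
        # Mask out target positions beyond each sequence's valid length.
        positions = torch.arange(targets.size(1), device=dec_seq.device)
        mask = positions.unsqueeze(0) < (dec_seq_lengths - 1).unsqueeze(1)
        token_loss = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
        loss = (token_loss * mask).sum() / mask.sum()
        weight = torch.tensor(dec_seq.size(0), dtype=torch.float)  # batch size (assumed)
        stats = {"loss": loss.detach()}
        return loss, stats, weight


class SpeechLMWrapperSketch(nn.Module):
    """Hypothetical wrapper that delegates all computation to a core LM."""

    def __init__(self, corelm: nn.Module, extract_feats_in_collect_stats: bool = False):
        super().__init__()
        self.corelm = corelm
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self, dec_seq: torch.Tensor, dec_seq_lengths: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        # Optional encoder-side inputs and prefix lengths are read from kwargs.
        return self.corelm(
            dec_seq,
            dec_seq_lengths,
            kwargs.get("enc_seq"),
            kwargs.get("enc_seq_lengths"),
            kwargs.get("prefix_len"),
        )

    def collect_feats(self, **kwargs):
        # Placeholder, mirroring the behaviour documented for collect_feats.
        raise NotImplementedError


torch.manual_seed(0)
model = SpeechLMWrapperSketch(ToyCoreLM())
dec_seq = torch.tensor([[1, 2, 3], [4, 5, 6]])
dec_seq_lengths = torch.tensor([3, 3])
loss, stats, weight = model(dec_seq, dec_seq_lengths)  # calls forward via nn.Module.__call__
print(loss.item() > 0, float(weight))  # True 2.0
```

Keeping the wrapper free of model-specific logic means a different core LM can be swapped in without touching the training entry point.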
collect_feats(**kwargs)
Placeholder for feature collection. This method is not implemented in this class.
- Parameters: **kwargs – Additional keyword arguments reserved for a future implementation.
- Raises: NotImplementedError – Always; this method is a placeholder to be implemented in subclasses.
######### Examples
Calling this method on an ESPnetSpeechLMModel instance raises NotImplementedError:
>>> model = ESPnetSpeechLMModel(corelm)
>>> model.collect_feats()  # raises NotImplementedError
####### NOTE
Derived classes are expected to override this method to provide the actual feature collection.
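A derived class would supply its own `collect_feats`, as the note says. The sketch below is hypothetical: it assumes `collect_feats` receives the batch as keyword arguments and returns a dictionary of named tensors, and the `feats`/`feats_lengths` keys are illustrative choices. `BaseSketch` and `TokenFeatsModel` are not ESPnet classes.

```python
from typing import Dict

import torch
import torch.nn as nn


class BaseSketch(nn.Module):
    """Hypothetical base mirroring the placeholder behaviour described above."""

    def collect_feats(self, **kwargs) -> Dict[str, torch.Tensor]:
        raise NotImplementedError


class TokenFeatsModel(BaseSketch):
    """Hypothetical subclass that returns the decoder-side tokens as 'features'."""

    def collect_feats(self, **kwargs) -> Dict[str, torch.Tensor]:
        dec_seq = kwargs["dec_seq"]
        dec_seq_lengths = kwargs["dec_seq_lengths"]
        # Return named tensors; the key names here are an assumption.
        return {"feats": dec_seq, "feats_lengths": dec_seq_lengths}


feats = TokenFeatsModel().collect_feats(
    dec_seq=torch.tensor([[1, 2, 3], [4, 5, 0]]),
    dec_seq_lengths=torch.tensor([3, 2]),
)
print(feats["feats"].shape, feats["feats_lengths"].tolist())  # torch.Size([2, 3]) [3, 2]
```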
forward(dec_seq: Tensor, dec_seq_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Runs the forward pass of ESPnetSpeechLMModel, delegating to the core language model to compute the loss, statistics, and weight for the given input sequences.
- Parameters:
dec_seq (torch.Tensor) – The decoder-side token sequence of shape (B, T_dec).
dec_seq_lengths (torch.Tensor) – The valid length of each sequence in dec_seq, of shape (B,).
**kwargs – Additional keyword arguments, which may include:
  - enc_seq (torch.Tensor, optional): The encoder-side sequence of shape (B, T_enc).
  - enc_seq_lengths (torch.Tensor, optional): The valid length of each sequence in enc_seq, of shape (B,).
  - prefix_len (torch.Tensor, optional): The length of the prefix of each sequence.
- Returns: A tuple containing:
  - loss (torch.Tensor): The computed loss value.
  - stats (Dict[str, torch.Tensor]): A dictionary of statistics for reporting.
  - weight (torch.Tensor): The weight used when aggregating the loss and statistics (for example, across batches or devices).
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: NotImplementedError – If the underlying corelm does not implement the required computation.
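Because `dec_seq` is a padded batch, `dec_seq_lengths` is what tells the model which positions hold real tokens. A minimal way to build both from variable-length token lists is sketched below; it assumes 0 as the padding ID, which is an assumption, since the real padding value depends on the token list in use.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three utterances of different lengths, as 1-D tensors of token IDs.
token_lists = [torch.tensor([5, 7, 9, 2]), torch.tensor([3, 4]), torch.tensor([8, 1, 6])]

# dec_seq: (B, T_dec), right-padded with the assumed padding ID 0.
dec_seq = pad_sequence(token_lists, batch_first=True, padding_value=0)
# dec_seq_lengths: (B,), the number of real tokens in each row.
dec_seq_lengths = torch.tensor([len(t) for t in token_lists])

# Boolean mask of valid positions, derived from the lengths.
mask = torch.arange(dec_seq.size(1)).unsqueeze(0) < dec_seq_lengths.unsqueeze(1)

print(dec_seq.tolist())  # [[5, 7, 9, 2], [3, 4, 0, 0], [8, 1, 6, 0]]
print(dec_seq_lengths.tolist())  # [4, 2, 3]
print(mask.sum().item())  # 9
```

The resulting pair can then be passed as `model(dec_seq, dec_seq_lengths)`.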
######### Examples
>>> import torch
>>> # corelm: an AbsCoreLM instance configured beforehand (not shown)
>>> model = ESPnetSpeechLMModel(corelm)
>>> dec_seq = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> dec_seq_lengths = torch.tensor([3, 3])
>>> loss, stats, weight = model(dec_seq, dec_seq_lengths)
####### NOTE
Ensure that corelm is properly initialized and configured before invoking this method.
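The third return value, `weight`, matters when results are aggregated over several batches or devices. The sketch below shows a weighted average of per-batch losses and statistics. It assumes `weight` is the batch size, which should be checked against the concrete `corelm`; the numbers are made up, and this is a simplified illustration rather than ESPnet's trainer code.

```python
import torch

# Hypothetical (loss, stats, weight) triples from three batches of different sizes.
outputs = [
    (torch.tensor(2.0), {"acc": torch.tensor(0.50)}, torch.tensor(4.0)),
    (torch.tensor(1.0), {"acc": torch.tensor(0.75)}, torch.tensor(2.0)),
    (torch.tensor(4.0), {"acc": torch.tensor(0.25)}, torch.tensor(2.0)),
]

total_weight = sum(w for _, _, w in outputs)
# Weighted mean: batches with a larger weight contribute proportionally more.
avg_loss = sum(loss * w for loss, _, w in outputs) / total_weight
avg_acc = sum(stats["acc"] * w for _, stats, w in outputs) / total_weight

print(avg_loss.item(), avg_acc.item())  # 2.25 0.5
```

An unweighted mean of the same losses would be 7/3 (about 2.33), over-counting the two smaller batches.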