espnet2.hubert.espnet_model.HubertPretrainModel
class espnet2.hubert.espnet_model.HubertPretrainModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, specaug: AbsSpecAug | None, normalize: AbsNormalize | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_cer: bool = False, report_wer: bool = False, sym_space: str = '<space>', sym_blank: str = '<blank>', pred_masked_weight: float = 1.0, pred_nomask_weight: float = 0.0, loss_weights: float = 0.0, **kwargs)
Bases: AbsESPnetModel
HubertPretrainModel is a model class for pre-training HuBERT (Hidden Unit
BERT) using self-supervised learning techniques.
This model takes speech input together with the associated target-unit sequence (frame-level pseudo-labels passed in as text), runs them through the frontend and encoder, and computes the HuBERT prediction loss. Training covers both masked and unmasked frames so the model learns robust speech representations.
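The pred_masked_weight, pred_nomask_weight, and loss_weights attributes below control how the individual loss terms are combined. A minimal sketch of that combination, assuming a fairseq-style HuBERT criterion (the exact bookkeeping lives in HubertPretrainLoss; all tensor values here are placeholders):

import torch

loss_m = torch.tensor(2.3)            # placeholder: loss over masked frames
loss_u = torch.tensor(1.7)            # placeholder: loss over unmasked frames
extra_losses = [torch.tensor(0.1)]    # placeholder: auxiliary terms (e.g., feature penalty)

pred_masked_weight, pred_nomask_weight, loss_weights = 1.0, 0.0, 0.0
loss = (
    pred_masked_weight * loss_m         # masked predictions (default weight 1.0)
    + pred_nomask_weight * loss_u       # unmasked predictions (default weight 0.0)
    + loss_weights * sum(extra_losses)  # auxiliary losses (default weight 0.0)
)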
sos
Start of sequence token ID.
- Type: int
eos
End of sequence token ID.
- Type: int
vocab_size
Size of the vocabulary.
- Type: int
ignore_id
Token ID to ignore during loss computation.
- Type: int
token_list
List of tokens corresponding to the vocabulary.
- Type: list
frontend
Frontend processing module.
- Type: AbsFrontend
specaug
SpecAugment module for data augmentation.
- Type: AbsSpecAug
normalize
Normalization module.
- Type: AbsNormalize
preencoder
Pre-encoder module for raw input data.
- Type: AbsPreEncoder
encoder
Main encoder module.
- Type: AbsEncoder
criterion_hubert
Loss computation module for HuBERT.
- Type: HubertPretrainLoss
pred_masked_weight
Weight for masked predictions in loss.
- Type: float
pred_nomask_weight
Weight for unmasked predictions in loss.
- Type: float
loss_weights
Additional loss weights for training.
- Type: float
error_calculator
Optional error calculator for evaluation metrics.
- Type: ErrorCalculator
Parameters:
- vocab_size (int) – Size of the vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of tokens.
- frontend (Optional[AbsFrontend]) – Frontend processing module.
- specaug (Optional[AbsSpecAug]) – SpecAugment module for data augmentation.
- normalize (Optional[AbsNormalize]) – Normalization module.
- preencoder (Optional[AbsPreEncoder]) – Pre-encoder module for raw input data.
- encoder (AbsEncoder) – Main encoder module.
- ignore_id (int) – Token ID to ignore during loss computation (default: -1).
- lsm_weight (float) – Label smoothing weight (default: 0.0).
- length_normalized_loss (bool) – Whether to normalize loss by length (default: False).
- report_cer (bool) – Whether to report character error rate (default: False).
- report_wer (bool) – Whether to report word error rate (default: False).
- sym_space (str) – Token representing space (default: "<space>").
- sym_blank (str) – Token representing blank (default: "<blank>").
- pred_masked_weight (float) – Weight for masked predictions in loss (default: 1.0).
- pred_nomask_weight (float) – Weight for unmasked predictions in loss (default: 0.0).
- loss_weights (float) – Additional loss weights for training (default: 0.0).
- **kwargs – Additional keyword arguments.
Returns: A tuple containing:
- loss (torch.Tensor): Computed loss value.
- stats (Dict[str, torch.Tensor]): Statistics of the model, including accuracy metrics.
- weight (torch.Tensor): Weight tensor for DataParallel.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
########### Examples
model = HubertPretrainModel(vocab_size=5000, token_list=["<blank>", "<space>", ...])
speech_tensor = torch.randn(32, 16000)  # Example batch of speech data
speech_lengths = torch.tensor([16000] * 32)  # Example lengths
text_tensor = torch.randint(0, 5000, (32, 100))  # Example text data
text_lengths = torch.tensor([100] * 32)  # Example lengths
loss, stats, weight = model(speech_tensor, speech_lengths, text_tensor, text_lengths)
NOTE
This model is built upon the ESPnet framework and requires appropriate backend components such as frontend, encoder, and normalization layers to function correctly.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
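For orientation, a minimal construction sketch following the signature above; the component choices are illustrative placeholders (a real setup passes a concrete AbsEncoder instance, and each optional module may be None):

encoder = ...  # assumed: a pre-built AbsEncoder instance (constructor arguments omitted)
model = HubertPretrainModel(
    vocab_size=5000,
    token_list=token_list,  # assumed: list of 5000 token strings
    frontend=None,          # Optional[AbsFrontend]
    specaug=None,           # Optional[AbsSpecAug]
    normalize=None,         # Optional[AbsNormalize]
    preencoder=None,        # Optional[AbsPreEncoder]
    encoder=encoder,
    pred_masked_weight=1.0,
    pred_nomask_weight=0.0,
)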
collect_feats(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Extract features from speech input.
This method takes the speech input and its corresponding lengths, extracts the features using the model’s frontend, and returns the features along with their lengths in a dictionary.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the speech signals.
- speech_lengths – A tensor of shape (Batch,) containing the lengths of the speech signals.
- text – A tensor of shape (Batch, Length) representing the text input (not used in this method).
- text_lengths – A tensor of shape (Batch,) containing the lengths of the text input (not used in this method).
- kwargs – Additional keyword arguments.
- Returns: A dictionary containing:
  - 'feats': The extracted features as a tensor.
  - 'feats_lengths': The lengths of the extracted features as a tensor.
- Return type: Dict[str, torch.Tensor]
########### Examples
>>> model = HubertPretrainModel(...)
>>> speech = torch.randn(2, 16000) # Example speech tensor
>>> speech_lengths = torch.tensor([16000, 12000]) # Lengths
>>> text = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Example text
>>> text_lengths = torch.tensor([3, 3]) # Lengths
>>> features = model.collect_feats(speech, speech_lengths, text, text_lengths)
>>> print(features['feats'].shape) # Output shape of features
>>> print(features['feats_lengths']) # Output lengths of features
compute_correct(logits)
Computes the number of correct predictions from logits.
This method counts a prediction as correct when the positive candidate (index 0 along the last dimension of the logits) attains the maximum logit. Rows in which index 0 is simultaneously the argmax and the argmin (i.e., all logits in the row are equal) are excluded from the correct count.
- Parameters: logits (torch.Tensor) – A tensor containing the logits for which to compute correct predictions. The tensor must have at least two dimensions.
- Returns: A tuple containing:
  - corr (int): The number of correct predictions.
  - count (int): The total number of predictions evaluated.
- Return type: Tuple[int, int]
########### Examples
>>> import torch
>>> model = HubertPretrainModel(...)
>>> logits = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.5, 0.5]])
>>> correct, total = model.compute_correct(logits)
>>> print(correct, total)
1 3  # row [0.9, 0.1] is correct (max logit at index 0); the tied row [0.5, 0.5] is excluded
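For reference, a self-contained sketch of this counting logic, assuming (as in the fairseq-style HuBERT loss) that the positive candidate occupies index 0 along the last dimension:

import torch

def compute_correct_sketch(logits: torch.Tensor):
    # A row is "correct" when the positive candidate (index 0) has the
    # highest logit; rows where index 0 is both argmax and argmin
    # (i.e., all logits equal) are excluded from the correct count.
    if logits.numel() == 0:
        return 0, 0
    is_max = logits.argmax(-1) == 0
    is_min = logits.argmin(-1) == 0
    both = is_max & is_min
    corr = int(is_max.long().sum()) - int(both.long().sum())
    count = is_max.numel()
    return corr, count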
encode(speech: Tensor, speech_lengths: Tensor, y_pad: Tensor, y_pad_length: Tensor) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by asr_inference.py
- Parameters:
- speech – (Batch, Length, …)
- speech_lengths – (Batch, )
- y_pad – (Batch, Length, …)
- y_pad_length – (Batch, )
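########### Examples
A usage sketch (shapes follow the parameter descriptions above; model and the input tensors are assumed to be prepared as in the forward() example below, and the two return values follow the Tuple[Tensor, Tensor] annotation):
>>> encoder_out, encoder_out_lens = model.encode(
...     speech, speech_lengths, y_pad, y_pad_length)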
forward(speech: Tensor, speech_lengths: Tensor, text: Tensor, text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Calc loss
This method processes the input speech and text data through the frontend and encoder components, computes the loss, and returns the relevant statistics.
- Parameters:
- speech – A tensor of shape (Batch, Length, …) representing the input speech data.
- speech_lengths – A tensor of shape (Batch,) representing the lengths of each speech sample in the batch.
- text – A tensor of shape (Batch, Length) representing the input text data.
- text_lengths – A tensor of shape (Batch,) representing the lengths of each text sample in the batch.
- kwargs – Additional keyword arguments; "utt_id" is among the expected inputs.
- Returns: A tuple containing:
  - loss: A tensor representing the computed loss.
  - stats: A dictionary of statistics, including accuracy.
  - weight: A tensor representing the weight for data-parallel processing.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises: AssertionError – If the dimensions of the input tensors do not match.
########### Examples
>>> model = HubertPretrainModel(...)
>>> speech = torch.randn(4, 16000) # Example speech tensor
>>> speech_lengths = torch.tensor([16000, 16000, 16000, 16000])
>>> text = torch.randint(0, 100, (4, 20)) # Example text tensor
>>> text_lengths = torch.tensor([20, 20, 20, 20])
>>> loss, stats, weight = model.forward(speech, speech_lengths, text,
... text_lengths)