espnet2.ssl.loss.hubert.HuBERTLoss
class espnet2.ssl.loss.hubert.HuBERTLoss(encoder_output_size: int, num_classes: int, final_dim: int, loss_type: str = 'cross_entropy', layers: List = [-1], loss_weights: List = [1.0])
Bases: AbsSSLLoss
HuBERT masked language model (MLM) loss.
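- Parameters:
- encoder_output_size (int): Dimension of the encoder output (the input to this loss module).
- num_classes (int): Vocabulary size, i.e. the number of target classes.
- final_dim (int): Output dimension of the final projection.
- loss_type (str): Reserved; currently unused (marked TODO in the source).
- layers (List): Indices of the encoder output layers used to compute the loss.
- loss_weights (List): Weight applied to each selected layer's loss, in the same order as layers.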
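The layers and loss_weights parameters pair up element by element. The sketch below shows one way such a pairing can be combined, assuming the per-layer losses are summed with their weights; the helper name `combine_layer_losses` and the weighted-sum rule are hypothetical illustrations, not ESPnet's code.

```python
from typing import List, Sequence


def combine_layer_losses(
    per_layer_losses: Sequence[float],
    layers: List[int],
    loss_weights: List[float],
) -> float:
    """Hypothetical helper: weighted sum of the losses of the selected layers.

    per_layer_losses holds one scalar loss per encoder layer. Entries in
    `layers` may be negative (e.g. -1 for the last layer, as in the default).
    """
    if len(layers) != len(loss_weights):
        raise ValueError("layers and loss_weights must have the same length")
    total = 0.0
    for layer, weight in zip(layers, loss_weights):
        # Assumption: each selected layer contributes weight * its loss.
        total += weight * per_layer_losses[layer]
    return total


# With the defaults layers=[-1], loss_weights=[1.0], only the last layer counts.
print(combine_layer_losses([2.0, 1.5, 1.0], [-1], [1.0]))
```

With two selected layers, e.g. `layers=[1, -1]` and `loss_weights=[0.5, 1.0]`, the same inputs give 0.5 * 1.5 + 1.0 * 1.0 = 1.75.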
forward(encoder_output: List, encoder_output_lengths: Tensor = None, text: Tensor = None, text_lengths: Tensor = None, mask_info: Dict = None) -> Tuple[Tensor, Dict]
Compute the HuBERT loss.
- Parameters:
- encoder_output (List): List of encoded sequences (B, T, D), one per encoder layer.
- encoder_output_lengths (Tensor): Lengths of the batched encoder sequences (B,).
- text (Tensor): Target labels (B, T).
- text_lengths (Tensor): Lengths of the target labels (B,).
- mask_info (Dict): Masked and unmasked frame indices.
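In the HuBERT paper, the masked-prediction term scores each frame by the cosine similarity between its projected encoder output and a learned embedding per target class, divided by a temperature (0.1), and applies cross-entropy over the masked frames only. The pure-Python sketch below illustrates that idea for a single utterance. It assumes mask_info reduces to a list of masked frame indices and that class embeddings are available as plain lists; the function names, argument layout, and temperature default are assumptions, and this is not ESPnet's implementation.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, guarded against zero-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norm, 1e-8)


def masked_prediction_loss(
    projected: List[List[float]],  # (T, final_dim): projected encoder frames
    class_emb: List[List[float]],  # (num_classes, final_dim): hypothetical label embeddings
    targets: List[int],            # (T,): target class index per frame
    masked_idx: List[int],         # hypothetical stand-in for the masked indices in mask_info
    temperature: float = 0.1,      # assumption: temperature from the HuBERT paper
) -> float:
    """Mean cross-entropy over masked frames, using cosine-similarity logits."""
    if not masked_idx:
        return 0.0
    total = 0.0
    for t in masked_idx:
        logits = [cosine(projected[t], e) / temperature for e in class_emb]
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        # Negative log-probability of the target class for frame t.
        total += log_z - logits[targets[t]]
    return total / len(masked_idx)


emb = [[1.0, 0.0], [0.0, 1.0]]
frames = [[1.0, 0.0], [0.0, 1.0]]
print(masked_prediction_loss(frames, emb, targets=[0, 1], masked_idx=[0]))
```

When frame 0 aligns with the embedding of its target class the loss is close to zero; labeling it with the other class instead yields a loss of about 10, since the logits differ by 1 / 0.1.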
