espnet2.speechlm.core_lm.ar_multiscale.MultiScaleLM
class espnet2.speechlm.core_lm.ar_multiscale.MultiScaleLM(vocab_size: int, nq: int, share_emb: bool = True, g_att_unit: int = 256, g_head: int = 2, g_layer: int = 4, l_att_unit: int = 256, l_head: int = 2, l_layer: int = 4, n_ctx: int = 3000, first_layer_weight: float = 1.0)
Bases: AbsCoreLM
MultiScaleLM is an implementation of the UniAudio architecture for speech
language modeling. It combines a global Transformer that models dependencies across frames with a local Transformer that models the nq codes within each frame.
emb
Embedding layer for input vocabulary.
- Type: torch.nn.Embedding
lm_head
Linear layer mapping from hidden states to vocabulary logits.
- Type: torch.nn.Linear
g_decoders
Global Transformer decoder for processing long-range dependencies.
- Type: TransformerDecoder
l_decoders
Local Transformer decoder for processing fine-grained details.
- Type: TransformerDecoder
placeholder
A learnable parameter to facilitate local attention computations.
- Type: torch.nn.Parameter
nq
Number of codes for each token/frame.
- Type: int
first_layer_weight
Scaling factor for gradients in the first layer.
- Type: float
Parameters:
- vocab_size (int) – Dimension of vocabulary.
- nq (int) – Number of codes for each token/frame, usually for speech codec.
- share_emb (bool) – If true, share the embedding and lm_head weight.
- g_att_unit (int) – Dimension of global Transformer attention.
- g_head (int) – Number of heads in global Transformer attention.
- g_layer (int) – Number of layers in global Transformer.
- l_att_unit (int) – Dimension of local Transformer attention.
- l_head (int) – Number of heads in local Transformer attention.
- l_layer (int) – Number of layers in local Transformer.
- n_ctx (int) – Maximum context length of global Transformer.
- first_layer_weight (float) – A factor to scale the gradient for the first-layer codes.
####### Examples
>>> # Initialize the MultiScaleLM model
>>> model = MultiScaleLM(vocab_size=5000, nq=10, share_emb=True)
>>> # Forward pass
>>> dec_seq = torch.randint(0, 5000, (32, 50, 10))  # example decoder input
>>> dec_seq_lengths = torch.tensor([50] * 32)       # lengths of the decoder sequences
>>> loss, stats, weight = model.forward(dec_seq, dec_seq_lengths)
>>> # Inference
>>> prefix = torch.randint(0, 5000, (32, 10, 10))   # example prefix input
>>> opts = SpeechLMInferenceOptions()               # define inference options
>>> gen_tokens, gen_scores = model.inference(prefix, opts)
NOTE
This implementation is based on the architecture described in the paper “UniAudio” (https://arxiv.org/abs/2310.00704).
- Raises: ValueError – If global and local attention sizes are not equal during initialization.
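For intuition, the multi-scale decomposition can be illustrated with a minimal, self-contained sketch. This is an assumption-laden illustration, not the actual ESPnet implementation: nn.TransformerEncoderLayer stands in for the causal global and local TransformerDecoder modules, and causal masks are omitted for brevity. The global decoder models dependencies across the T frames, while the local decoder predicts the nq codes of each frame from the frame's global state and its preceding codes.
import torch
# Hypothetical, simplified stand-ins for the global/local decoders (no causal masks).
B, T, nq, vocab, d = 2, 50, 4, 100, 64
emb = torch.nn.Embedding(vocab, d)
global_dec = torch.nn.TransformerEncoderLayer(d, nhead=2, batch_first=True)
local_dec = torch.nn.TransformerEncoderLayer(d, nhead=2, batch_first=True)
lm_head = torch.nn.Linear(d, vocab)
dec_seq = torch.randint(0, vocab, (B, T, nq))                 # (B, T, nq) token codes
x = emb(dec_seq)                                              # (B, T, nq, d)
frame = x.sum(dim=2)                                          # merge the nq codes of each frame
g_out = global_dec(frame)                                     # (B, T, d): dependencies across frames
# Local input: the frame's global state followed by its codes shifted by one position.
l_in = torch.cat([g_out.unsqueeze(2), x[:, :, :-1]], dim=2)   # (B, T, nq, d)
l_out = local_dec(l_in.flatten(0, 1))                         # (B*T, nq, d)
logits = lm_head(l_out).view(B, T, nq, vocab)                 # per-code vocabulary logits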
Initialize MultiScaleLM
- Parameters:
- vocab_size (int) – Dimension of vocabulary.
- nq (int) – Number of codes for each token / frame, usually for speech codec.
- share_emb (bool) – If true, share the embedding and lm_head weight.
- g_att_unit (int) – Dimension of global Transformer attention.
- g_head (int) – Number of heads in global Transformer attention.
- g_layer (int) – Number of layers in global Transformer.
- l_att_unit (int) – Dimension of local Transformer attention.
- l_head (int) – Number of heads in local Transformer attention.
- l_layer (int) – Number of layers in local Transformer.
- n_ctx (int) – Maximum context length of global Transformer.
- first_layer_weight (float) – A factor to scale the gradient for the first-layer codes.
forward(dec_seq: Tensor, dec_seq_lengths: Tensor | None = None, enc_seq: Tensor | None = None, enc_seq_lengths: Tensor | None = None, prefix_len: Tensor | None = None) → Tuple[Tensor, Dict, Tensor]
Auto-Regressive MultiScale forward for training.
- Parameters:
- dec_seq (LongTensor) – Batch of decoder sequences (B, T, nq).
- dec_seq_lengths (LongTensor) – Lengths of batched decoder sequences (B,).
- enc_seq (LongTensor) – Batch of encoder sequences (B, T, nq); kept for interface compatibility and may not be used.
- enc_seq_lengths (LongTensor) – Lengths of batched encoder sequences (B,); kept for interface compatibility and may not be used.
- prefix_len (LongTensor) – Lengths of condition part in dec_seq (B,).
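A minimal usage sketch (tensor shapes are illustrative; prefix_len marks the condition part of dec_seq, as documented above):
>>> import torch
>>> model = MultiScaleLM(vocab_size=5000, nq=10)
>>> dec_seq = torch.randint(0, 5000, (4, 50, 10))   # (B, T, nq)
>>> dec_seq_lengths = torch.full((4,), 50)          # valid length of each sequence
>>> prefix_len = torch.full((4,), 20)               # length of the condition part
>>> loss, stats, weight = model.forward(dec_seq, dec_seq_lengths, prefix_len=prefix_len)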
inference(prefix: Tensor, opts: SpeechLMInferenceOptions, enc_seq: Tensor = None, suffix: Tensor = None)
Auto-Regressive MultiScale Inference.
This method performs inference using the MultiScaleLM model. It generates tokens based on a given prefix and optional suffix. The process involves global and local attention mechanisms to create a coherent output sequence.
- Parameters:
- prefix (LongTensor) – Prefix part of dec_seq (B, T_dec, nq).
- opts (SpeechLMInferenceOptions) – Inference options, including parameters such as nbest, minlenratio, maxlenratio, search_algo, start, and eos.
- enc_seq (LongTensor , optional) – Encoder token sequence (B, T_enc, nq). This is optional and may not be used in the inference process.
- suffix (LongTensor , optional) – Suffix part of dec_seq (B, T_dec, nq), usually the target sequence for teacher-forcing. This is optional.
- Returns: A tuple containing:
  - gen_tokens (List[torch.Tensor]): Generated token sequences for each batch.
  - gen_scores (List[torch.Tensor]): Corresponding scores for the generated tokens.
- Return type: Tuple[List[torch.Tensor], List[torch.Tensor]]
####### Examples
>>> model = MultiScaleLM(vocab_size=1000, nq=10)
>>> prefix = torch.randint(0, 1000, (2, 5, 10)) # Example prefix
>>> opts = SpeechLMInferenceOptions(nbest=3, minlenratio=0.5,
... maxlenratio=1.5, search_algo='greedy',
... start=0, eos=1)
>>> gen_tokens, gen_scores = model.inference(prefix, opts)
NOTE
This method uses caching mechanisms for efficiency, allowing the model to keep track of previously computed states during the generation process.