espnet2.speechlm.core_lm.valle.ValleLM
class espnet2.speechlm.core_lm.valle.ValleLM(vocab_size: int, nq: int, share_emb: bool = True, att_unit: int = 256, head: int = 2, ar_layer: int = 4, nar_layer: int = 4, n_ctx: int = 3000)
Bases: AbsCoreLM
Implementation of the Vall-E model for speech language modeling.
This class implements the forward (training) and inference (generation) procedures of the Vall-E model described in the paper: https://arxiv.org/abs/2301.02111.
emb
The embedding layer for input tokens.
- Type: torch.nn.Embedding
lm_head
The linear layer for generating output logits.
- Type: torch.nn.Linear
ar_decoder
The auto-regressive decoder.
- Type: TransformerDecoder
nar_decoder
The non-auto-regressive decoder.
- Type: ValleNARDecoder
nq
Number of codes for each token/frame.
- Type: int
Parameters:
- vocab_size (int) – Dimension of vocabulary.
- nq (int) – Number of codes for each token/frame, typically determined by the speech codec.
- share_emb (bool) – If True, share the embedding and lm_head weight.
- att_unit (int) – Dimension of Transformer attention.
- head (int) – Number of heads in Transformer attention.
- ar_layer (int) – Number of layers in AR Transformer.
- nar_layer (int) – Number of layers in NAR Transformer.
- n_ctx (int) – Maximum context length of AR & NAR Transformer.
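For orientation, a constructor call that spells out every argument might look like the sketch below; the vocab_size and nq values are arbitrary illustrations, and the remaining values simply restate the documented defaults.

>>> model = ValleLM(vocab_size=1024, nq=8, share_emb=True, att_unit=256,
...                 head=2, ar_layer=4, nar_layer=4, n_ctx=3000)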
forward()
Computes the forward pass of the Vall-E model for training.
inference()
Performs inference using the Vall-E model.
#### Examples

Initialize the model:

>>> model = ValleLM(vocab_size=1000, nq=256)

Forward pass (dec_seq: (B, T, nq), dec_seq_lengths: (B,)):

>>> loss, stats, weight = model.forward(dec_seq, dec_seq_lengths)

Inference from a prefix and inference options:

>>> generated_tokens, generated_scores = model.inference(prefix, opts)
forward(dec_seq: Tensor, dec_seq_lengths: Tensor | None = None, enc_seq: Tensor | None = None, enc_seq_lengths: Tensor | None = None, prefix_len: Tensor | None = None) → Tuple[Tensor, Tensor, Dict]
Vall-E forward for training.
This method performs a forward pass through the Vall-E model, computing the loss and returning related statistics. It trains both the auto-regressive (AR) decoder, which predicts the first code level, and the non-auto-regressive (NAR) decoder, which predicts the remaining levels.
- Parameters:
- dec_seq (LongTensor) – Batch of decoder sequences (B, T, nq).
- dec_seq_lengths (LongTensor) – Lengths of batched decoder sequences (B,).
- enc_seq (LongTensor) – Batch of encoder sequences (B, T, nq), keeping the interface, may not be used.
- enc_seq_lengths (LongTensor) – Lengths of batched encoder sequences (B,), keeping the interface, may not be used.
- prefix_len (LongTensor) – Lengths of condition part in dec_seq (B,).
- Returns: A tuple containing:
  - loss (torch.Tensor): Computed loss value.
  - stats (Dict): A dictionary of statistics related to the model's performance.
  - weight (torch.Tensor): The weight used for loss calculation.
- Return type: Tuple[torch.Tensor, Dict, torch.Tensor]
- Raises:
  - AssertionError – If the dimensions of dec_seq are not as expected.
#### Examples
>>> model = ValleLM(vocab_size=100, nq=10)
>>> dec_seq = torch.randint(0, 100, (32, 20, 10))
>>> dec_seq_lengths = torch.randint(1, 21, (32,))
>>> loss, stats, weight = model.forward(dec_seq, dec_seq_lengths)
NOTE
This method assumes that dec_seq has three dimensions, and it is critical to maintain the shape (B, T, nq).
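Building on the note above, a minimal training-style sketch that respects the (B, T, nq) shape and supplies an explicit condition length is shown below; the batch size, lengths, and prefix_len values are illustrative assumptions only.

>>> import torch
>>> model = ValleLM(vocab_size=100, nq=10)
>>> B, T = 4, 20
>>> dec_seq = torch.randint(0, 100, (B, T, 10))   # (B, T, nq)
>>> dec_seq_lengths = torch.full((B,), T)         # every sequence uses its full length
>>> prefix_len = torch.full((B,), 5)              # first 5 frames are the condition part
>>> loss, stats, weight = model.forward(
...     dec_seq, dec_seq_lengths, prefix_len=prefix_len)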
inference(prefix: Tensor, opts: SpeechLMInferenceOptions, enc_seq: Tensor = None, suffix: Tensor = None)
Vall-E Inference.
This method performs inference with the Vall-E model, generating sequences from a provided prefix and, optionally, a suffix. The inference process has two stages: auto-regressive (AR) generation for the first code layer, followed by non-auto-regressive (NAR) generation for the remaining layers.
- Parameters:
- prefix (LongTensor) – Prefix part of dec_seq (B, T, nq).
- opts (SpeechLMInferenceOptions) – Inference options.
- enc_seq (LongTensor , optional) – Encoder token sequence (B, T, nq).
- suffix (LongTensor , optional) – Suffix part of dec_seq (B, T, nq), usually the target sequence for teacher-forcing.
- Returns: A tuple containing two lists. The first list contains the generated token sequences for each valid example in the batch, and the second list contains the corresponding scores.
- Return type: Tuple[List[LongTensor], List[LongTensor]]
#### Examples
>>> prefix = torch.tensor([[[1, 2, 3], [4, 5, 6]]]) # Example prefix
>>> opts = SpeechLMInferenceOptions(nbest=1, minlenratio=0.2,
... maxlenratio=1.0, start=0,
... eos=1, search_algo='greedy')
>>> gen_tokens, gen_scores = model.inference(prefix, opts)
NOTE
The method will log the process of generation, including the termination steps and any warnings if no valid examples are generated.
- Raises:
  - AssertionError – If the provided suffix is None when the search algorithm is set to "teacher_force".
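When search_algo is set to "teacher_force", a suffix must accompany the prefix; a minimal sketch under that assumption, reusing the model, prefix, and option fields from the example above (the suffix shape and values are illustrative):

>>> suffix = torch.randint(0, 100, (1, 15, 3))    # target part of dec_seq (B, T, nq)
>>> tf_opts = SpeechLMInferenceOptions(nbest=1, minlenratio=0.0,
...                                    maxlenratio=1.0, start=0, eos=1,
...                                    search_algo='teacher_force')
>>> gen_tokens, gen_scores = model.inference(prefix, tf_opts, suffix=suffix)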
prepare_input(dec_seq_emb, prefix_len, level)