espnet2.asr.decoder.whisper_decoder.ExpandedTokenEmbedding
class espnet2.asr.decoder.whisper_decoder.ExpandedTokenEmbedding(ori_emebedding, additional_size)
Bases: Module
ExpandedTokenEmbedding is a PyTorch module that wraps an existing embedding layer and appends additional_size new token embeddings to it. It is intended for cases where the vocabulary must grow beyond the original embedding table while the original embeddings are preserved.
ori_emb
The original embedding layer.
- Type: torch.nn.Embedding
add_emb
The additional embedding layer for new tokens.
- Type: torch.nn.Embedding
num_embeddings
Total number of embeddings after expansion.
- Type: int
Parameters:
- ori_emebedding (torch.nn.Embedding) – The original embedding layer to extend.
- additional_size (int) – The number of additional embeddings to add.
Returns: The expanded embedding module; its weight property yields the concatenated weights of the original and additional embeddings.
Return type: ExpandedTokenEmbedding
######### Examples
>>> original_embedding = torch.nn.Embedding(10, 5) # 10 tokens, 5 dimensions
>>> expanded_embedding = ExpandedTokenEmbedding(original_embedding, 5)
>>> expanded_embedding.num_embeddings
15 # Original 10 plus 5 additional tokens
####### NOTE The additional embeddings are drawn from a normal distribution whose mean and standard deviation match those of the original embedding weights.
- Raises: ValueError – If the original embedding is not of type torch.nn.Embedding.
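To make the mechanism concrete, here is a minimal, self-contained sketch of how such a wrapper can be built in plain PyTorch. It is not ESPnet's source: the class name ToyExpandedEmbedding is hypothetical, and the use of torch.std_mean for the statistics, torch.cat for the table, and torch.nn.functional.embedding for the lookup are assumptions consistent with the description above.

```python
import torch


class ToyExpandedEmbedding(torch.nn.Module):
    """Hypothetical sketch of an expanded embedding (not ESPnet's implementation)."""

    def __init__(self, ori_embedding: torch.nn.Embedding, additional_size: int):
        super().__init__()
        if not isinstance(ori_embedding, torch.nn.Embedding):
            raise ValueError("ori_embedding must be a torch.nn.Embedding")
        self.ori_emb = ori_embedding
        # Mean and (unbiased) standard deviation over every entry of the original table.
        std, mean = torch.std_mean(ori_embedding.weight.detach())
        self.add_emb = torch.nn.Embedding(additional_size, ori_embedding.embedding_dim)
        # New rows are drawn from N(mean, std^2) so they match the original scale.
        torch.nn.init.normal_(self.add_emb.weight, mean.item(), std.item())
        self.num_embeddings = ori_embedding.num_embeddings + additional_size

    @property
    def weight(self) -> torch.Tensor:
        # Original rows first, additional rows appended after them.
        return torch.cat([self.ori_emb.weight, self.add_emb.weight], dim=0)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Plain lookup into the concatenated table.
        return torch.nn.functional.embedding(token_ids, self.weight)


torch.manual_seed(0)
ori = torch.nn.Embedding(10, 5)
emb = ToyExpandedEmbedding(ori, 5)
print(emb.num_embeddings)  # 15
print(emb.weight.shape)    # torch.Size([15, 5])
```

In this sketch the first 10 rows of emb.weight are exactly ori.weight, so the original vectors are left untouched by the expansion.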
forward(input)
Embed token IDs using the expanded vocabulary.
This method looks up each token ID in the concatenated table returned by the weight property. IDs 0 to ori_emb.num_embeddings - 1 select rows of the original embedding, and IDs ori_emb.num_embeddings to num_embeddings - 1 select rows of the additional embedding.
- Parameters:
- input (torch.Tensor) – Token IDs, an int64 tensor of any shape, for example (batch, maxlen_out). Every ID must be smaller than num_embeddings.
- Returns: The embedding vectors of the input tokens, a tensor of shape (*input.shape, embedding_dim).
- Return type: torch.Tensor
######### Examples
>>> ori_embedding = torch.nn.Embedding(10, 5)
>>> expanded_embedding = ExpandedTokenEmbedding(ori_embedding, 5)
>>> tokens = torch.tensor([[1, 2, 12], [3, 14, 0]]) # IDs 10-14 use the additional embeddings
>>> expanded_embedding(tokens).shape
torch.Size([2, 3, 5])
####### NOTE Token IDs equal to or larger than num_embeddings are out of range and cause a runtime error.
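The lookup performed by forward(input) can be reproduced with plain PyTorch to see how token IDs map onto the two tables. This sketch assumes the expanded table is the row-wise concatenation described under the weight property (original rows first) and that the lookup is an ordinary torch.nn.functional.embedding call; the tables ori and add are stand-ins, not ESPnet objects.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ori = torch.nn.Embedding(10, 5)  # stand-in original table: IDs 0-9
add = torch.nn.Embedding(5, 5)   # stand-in additional table: IDs 10-14
weight = torch.cat([ori.weight, add.weight], dim=0)  # shape (15, 5)

tokens = torch.tensor([[1, 2, 12], [3, 14, 0]])
out = F.embedding(tokens, weight)  # shape (2, 3, 5)

# ID 12 resolves to row 12 - 10 = 2 of the additional table.
same_extra = torch.equal(out[0, 2], add.weight[2])
# ID 3 still resolves to row 3 of the original table.
same_orig = torch.equal(out[1, 0], ori.weight[3])
print(out.shape, same_extra, same_orig)  # torch.Size([2, 3, 5]) True True
```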
property weight
Returns the concatenated weights of the original and additional embeddings.
The original rows come first and the additional rows are appended after them, so row i of the returned tensor is the embedding of token ID i in the expanded vocabulary.
weight
A tensor of shape (num_embeddings, embedding_dim) containing the concatenated weights of the original and additional embeddings.
- Type: torch.Tensor
Returns: The combined weights of the original and additional embeddings.
Return type: torch.Tensor
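Because weight is computed by concatenation rather than stored, it is an ordinary tensor, not a torch.nn.Parameter, and gradients from a lookup flow back into whichever underlying table owns each row. A small sketch with stand-in tables, assuming the concatenation is done with torch.cat, shows both effects:

```python
import torch

torch.manual_seed(0)
ori = torch.nn.Embedding(10, 5)  # stand-in original table
add = torch.nn.Embedding(5, 5)   # stand-in additional table

weight = torch.cat([ori.weight, add.weight], dim=0)  # a new tensor on every call
loss = weight[[3, 12]].sum()  # use one original row and one additional row
loss.backward()

print(isinstance(weight, torch.nn.Parameter))  # False
print(ori.weight.grad[3].sum().item())  # 5.0: gradient reached original row 3
print(add.weight.grad[2].sum().item())  # 5.0: gradient reached additional row 2
```

A consequence of this design is that editing the returned tensor in place does not change the stored embeddings; updates have to go through ori_emb or add_emb.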
######### Examples
>>> ori_embedding = torch.nn.Embedding(10, 5) # original embedding
>>> expanded_embedding = ExpandedTokenEmbedding(ori_embedding, 5)
>>> combined_weights = expanded_embedding.weight
>>> combined_weights.shape
torch.Size([15, 5]) # 10 original + 5 additional embeddings
####### NOTE The additional embedding weights are drawn from a normal distribution whose mean and standard deviation are those of the original embedding weights.
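Written out, the initialization and concatenation amount to the following, assuming the statistics are taken over every entry of the original weight matrix with PyTorch's default unbiased standard deviation (an assumption; the note only says "mean and standard deviation"). Here W is the original weight, A the additional weight, V = ori_emb.num_embeddings, D = embedding_dim and K = additional_size:

```latex
\mu = \frac{1}{VD}\sum_{i=1}^{V}\sum_{j=1}^{D} W_{ij}, \qquad
\sigma = \sqrt{\frac{1}{VD-1}\sum_{i=1}^{V}\sum_{j=1}^{D}\left(W_{ij}-\mu\right)^{2}}

A_{kj} \sim \mathcal{N}\!\left(\mu, \sigma^{2}\right), \qquad k = 1,\dots,K,\; j = 1,\dots,D

\mathrm{weight} = \begin{bmatrix} W \\ A \end{bmatrix} \in \mathbb{R}^{(V+K)\times D}
```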