espnet2.mt.frontend.embedding.PatchEmbedding
class espnet2.mt.frontend.embedding.PatchEmbedding(input_size: int = 400, embed_dim: int = 400, token_per_frame: int = 1, pos_enc_class=<class 'espnet.nets.pytorch_backend.transformer.embedding.PositionalEncoding'>, positional_dropout_rate: float = 0.1)
Bases: AbsFrontend
Embedding Frontend for text-based inputs.
This class implements an embedding layer that processes input tokens in patches, grouping a specified number of tokens per frame. It uses the given positional encoding class and applies layer normalization after embedding the input.
embed_dim
Dimension of the embedding.
- Type: int
token_per_frame
Number of tokens per frame in the input.
- Type: int
emb
The embedding layer.
- Type: torch.nn.Embedding
pos
The positional encoding layer.
- Type: PositionalEncoding
ln
Layer normalization layer.
- Type: torch.nn.LayerNorm
Parameters:
- input_size (int) – Number of input tokens (i.e. the vocabulary size). Defaults to 400.
- embed_dim (int) – Embedding size. Defaults to 400.
- token_per_frame (int) – Number of tokens per frame in the input. Defaults to 1.
- pos_enc_class – Class for positional encoding, either PositionalEncoding or ScaledPositionalEncoding. Defaults to PositionalEncoding.
- positional_dropout_rate (float) – Dropout rate after adding positional encoding. Defaults to 0.1.
Raises: AssertionError – If input dimensions or lengths are invalid.
######### Examples
>>> patch_embedding = PatchEmbedding(input_size=500, embed_dim=256)
>>> input_tensor = torch.randint(0, 500, (32, 16)) # Batch of 32
>>> input_lengths = torch.full((32,), 16) # All sequences of length 16
>>> output, output_lengths = patch_embedding(input_tensor, input_lengths)
>>> output.shape # (32, 16 // token_per_frame, 256); token_per_frame defaults to 1
torch.Size([32, 16, 256])
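The shape arithmetic behind token_per_frame can be illustrated with a small, self-contained sketch. The per-token embedding width of embed_dim // token_per_frame and the reshape into frames are assumptions about the internal layout, shown here only to make the shapes concrete:
>>> import torch
>>> B, T = 8, 16
>>> emb = torch.nn.Embedding(500, 256 // 4)  # assumption: embed_dim // token_per_frame per token
>>> tokens = torch.randint(0, 500, (B, T))
>>> x = emb(tokens)                          # (8, 16, 64)
>>> x = x.view(B, T // 4, 256)               # group 4 consecutive tokens into one 256-dim frame
>>> x.shape
torch.Size([8, 4, 256])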
Initialize.
- Parameters:
- input_size – Number of input tokens.
- embed_dim – Embedding size.
- token_per_frame – Number of tokens per frame in the input.
- pos_enc_class – PositionalEncoding or ScaledPositionalEncoding.
- positional_dropout_rate – Dropout rate after adding positional encoding.
forward(input: Tensor, input_lengths: Tensor) → Tuple[Tensor, Tensor]
Apply patch embedding to the input token sequence.
The input token IDs are grouped into patches of token_per_frame consecutive tokens via a sliding-window mechanism, passed through the embedding layer, combined with positional encoding, and normalized with layer normalization.
Parameters:
- input (torch.Tensor) – Input token IDs of shape (B, T).
- input_lengths (torch.Tensor) – Valid length of each sequence, of shape (B,).
- Returns: The embedded features of shape (B, T // token_per_frame, embed_dim) and the updated lengths (input_lengths // token_per_frame).
- Return type: Tuple[torch.Tensor, torch.Tensor]
Raises: AssertionError – If the input tensor’s dimensions or lengths are invalid, e.g. not divisible by token_per_frame.
######### Examples
>>> import torch
>>> model = PatchEmbedding(input_size=500, embed_dim=256, token_per_frame=4)
>>> input_tensor = torch.randint(0, 500, (8, 16)) # (B, T)
>>> input_lengths = torch.tensor([16] * 8) # Lengths for each batch
>>> output, output_lengths = model(input_tensor, input_lengths)
>>> print(output.shape) # Output shape should be (8, 4, 256)
>>> print(output_lengths) # Output lengths should be (8,)
NOTE
Ensure that the input tensor’s second dimension is divisible by token_per_frame, and that every entry of input_lengths is divisible by token_per_frame as well.
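The divisibility constraint and the resulting length update can be illustrated as follows (the exact internal checks may differ; this only mirrors what the note and the example imply):
>>> import torch
>>> token_per_frame = 4
>>> input_tensor = torch.randint(0, 500, (8, 16))
>>> input_lengths = torch.tensor([16] * 8)
>>> assert input_tensor.size(1) % token_per_frame == 0     # T divisible by token_per_frame
>>> assert torch.all(input_lengths % token_per_frame == 0)  # every length divisible as well
>>> input_lengths // token_per_frame                        # one frame per token_per_frame tokens
tensor([4, 4, 4, 4, 4, 4, 4, 4])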
output_size() → int
Return output length of feature dimension D, i.e. the embedding dim.
This method provides the dimensionality of the output feature vector produced by the embedding layer. The output size is equal to the embedding dimension defined during the initialization of the PatchEmbedding class.
- Returns: The size of the output feature dimension, which is equal to the embedding dimension (embed_dim).
- Return type: int
######### Examples
>>> patch_embedding = PatchEmbedding(embed_dim=512)
>>> output_dim = patch_embedding.output_size()
>>> print(output_dim)
512
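As a quick usage check, output_size() should match the last dimension of the features returned by forward() (illustrative, reusing the toy inputs from the forward() example):
>>> import torch
>>> from espnet2.mt.frontend.embedding import PatchEmbedding
>>> fe = PatchEmbedding(input_size=500, embed_dim=256, token_per_frame=4)
>>> tokens = torch.randint(0, 500, (8, 16))
>>> lengths = torch.tensor([16] * 8)
>>> feats, feat_lengths = fe(tokens, lengths)
>>> feats.size(-1) == fe.output_size()
True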