espnet2.tts.gst.style_encoder.MultiHeadedAttention
class espnet2.tts.gst.style_encoder.MultiHeadedAttention(q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0)
Bases: MultiHeadedAttention
Multi-headed attention module with different input dimensions.
This module extends the base multi-headed attention mechanism to accept separate input dimensions for queries, keys, and values: each input is linearly projected into a common n_feat-dimensional space before attention is computed. In ESPnet it is used by the GST (global style token) style encoder, where the reference embedding (query) and the style tokens (keys/values) live in spaces of different sizes; a construction sketch follows the parameter list below.
- Parameters:
- q_dim (int) – Dimension of the input queries.
- k_dim (int) – Dimension of the input keys.
- v_dim (int) – Dimension of the input values.
- n_head (int) – Number of attention heads.
- n_feat (int) – Dimension of the projected features; must be divisible by n_head.
- dropout_rate (float, optional) – Dropout rate for attention weights (default: 0.0).
- Raises: AssertionError – If n_feat is not divisible by n_head.
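Internally, this can be realized by reusing the base class and swapping in input projections of the requested sizes. A minimal sketch, assuming the base class is espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention (details of the actual ESPnet code may differ):

import torch
from espnet.nets.pytorch_backend.transformer.attention import (
    MultiHeadedAttention as BaseMultiHeadedAttention,
)

class MultiHeadedAttention(BaseMultiHeadedAttention):
    def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
        # The base initializer builds n_feat -> n_feat projections and
        # asserts that n_feat is divisible by n_head.
        super().__init__(n_head, n_feat, dropout_rate)
        # Replace the input projections so queries, keys, and values may
        # come from spaces of different dimensionality; all three are
        # mapped into the shared n_feat-dimensional attention space.
        self.linear_q = torch.nn.Linear(q_dim, n_feat)
        self.linear_k = torch.nn.Linear(k_dim, n_feat)
        self.linear_v = torch.nn.Linear(v_dim, n_feat)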
Examples
>>> import torch
>>> attention_layer = MultiHeadedAttention(q_dim=64, k_dim=64, v_dim=64,
...                                        n_head=8, n_feat=64)
>>> query = torch.rand(10, 20, 64)  # (batch_size, query_length, q_dim)
>>> key = torch.rand(10, 15, 64)  # (batch_size, key_length, k_dim)
>>> value = torch.rand(10, 15, 64)  # (batch_size, key_length, v_dim)
>>> output = attention_layer(query, key, value, None)  # mask=None disables masking
>>> print(output.shape)  # torch.Size([10, 20, 64])
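The example above uses matching dimensions, but the point of this subclass is that they may differ. A sketch with mismatched sizes (the dimensions here are illustrative, not taken from any recipe):

>>> attention_layer = MultiHeadedAttention(q_dim=32, k_dim=80, v_dim=80,
...                                        n_head=4, n_feat=128)
>>> query = torch.rand(10, 20, 32)  # (batch_size, query_length, q_dim)
>>> key = torch.rand(10, 15, 80)  # (batch_size, key_length, k_dim)
>>> value = torch.rand(10, 15, 80)  # (batch_size, key_length, v_dim)
>>> output = attention_layer(query, key, value, None)
>>> print(output.shape)  # torch.Size([10, 20, 128]), i.e. (batch_size, query_length, n_feat)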
NOTE
The value dimension d_v is assumed to equal the key dimension d_k (both are n_feat // n_head per head).
Initialize multi-headed attention module.