espnet2.speechlm.module.transformer.MultiHeadAttention
class espnet2.speechlm.module.transformer.MultiHeadAttention(n_state: int, n_head: int, causal: bool = False)
Bases: Module
MultiHeadAttention is a module that implements multi-head attention as part of
the Transformer architecture. It supports both self-attention and cross-attention and uses PyTorch’s built-in scaled dot-product (flash) attention for efficiency.
n_head
The number of attention heads.
- Type: int
query
Linear layer for the query projection.
- Type: Linear
key
Linear layer for the key projection.
- Type: Linear
value
Linear layer for the value projection.
- Type: Linear
out
Linear layer for the output projection.
- Type: Linear
causal
Indicates whether the attention is causal.
- Type: bool
Parameters:
- n_state (int) – The dimension of the input state.
- n_head (int) – The number of attention heads.
- causal (bool, optional) – Whether to use causal attention. Defaults to False.
Raises: ValueError – If the number of heads does not divide the state dimension or if the PyTorch version is incompatible with flash attention.
######### Examples
>>> attention = MultiHeadAttention(n_state=64, n_head=8)
>>> x = torch.rand(10, 20, 64) # (batch_size, sequence_length, n_state)
>>> output = attention(x)
>>> print(output.shape) # (10, 20, 64)
>>> x_a = torch.rand(10, 20, 64)
>>> x_b = torch.rand(10, 30, 64)
>>> output = attention(x_a, xa=x_b)
>>> print(output.shape) # (10, 20, 64)
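Because the module returns a tensor with the same shape as x, it can be dropped into a standard pre-norm residual block. The wiring below is a hypothetical sketch for illustration, not ESPnet's own Transformer block:
>>> import torch
>>> import torch.nn as nn
>>> class TinyBlock(nn.Module):  # hypothetical wrapper, not part of ESPnet
...     def __init__(self, n_state=64, n_head=8):
...         super().__init__()
...         self.ln = nn.LayerNorm(n_state)
...         self.attn = MultiHeadAttention(n_state=n_state, n_head=n_head, causal=True)
...     def forward(self, x):
...         # causal self-attention with a residual connection
...         return x + self.attn(self.ln(x))
>>> TinyBlock()(torch.rand(10, 20, 64)).shape
torch.Size([10, 20, 64])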
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor, xa: Tensor | None = None, mask: Tensor | None = None, kv_cache: dict | None = None)
Computes the forward pass of the MultiHeadAttention module.
This method applies multi-head attention to the input tensor x, and can optionally take in a second input tensor xa for cross-attention, a mask for attention scores, and a key-value cache for efficiency during autoregressive decoding.
- Parameters:
- x (Tensor) – The input tensor of shape (batch_size, seq_length, n_state).
- xa (Optional[Tensor]) – An optional tensor for cross-attention of shape (batch_size, seq_length_xa, n_state); its sequence length may differ from that of x. Default is None.
- mask (Optional[Tensor]) – An optional mask tensor of shape (batch_size, n_head, query_length, key_length) to prevent attending to certain positions. Default is None.
- kv_cache (Optional[dict]) – A cache for key and value tensors, used to optimize cross-attention during autoregressive decoding. Default is None.
- Returns: The output tensor of shape (batch_size, seq_length, n_state).
- Return type: Tensor
- Raises: ValueError – If mask is provided while causal is set to True.
######### Examples
>>> mha = MultiHeadAttention(n_state=64, n_head=8, causal=True)
>>> x = torch.rand(10, 20, 64)  # (batch_size, seq_length, n_state)
>>> output = mha(x)
>>> output.shape
torch.Size([10, 20, 64])
>>> mha_cross = MultiHeadAttention(n_state=64, n_head=8)  # non-causal, so a mask is allowed
>>> xa = torch.rand(10, 30, 64)  # Cross-attention input
>>> mask = torch.ones(10, 8, 20, 30)  # (batch_size, n_head, query_length, key_length)
>>> output = mha_cross(x, xa=xa, mask=mask)
>>> output.shape
torch.Size([10, 20, 64])
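The restriction documented under Raises can be sketched as follows (only the exception type is checked; the printed message below is our own, not the module's error text):
>>> mha_causal = MultiHeadAttention(n_state=64, n_head=8, causal=True)
>>> try:
...     mha_causal(torch.rand(10, 20, 64), mask=torch.ones(10, 8, 20, 20))
... except ValueError:
...     print("mask is not allowed when causal=True")
mask is not allowed when causal=True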
qkv_attention(q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None)
Computes the query-key-value attention mechanism.
This method applies the multi-head attention mechanism by computing the scaled dot-product attention between the query (q), key (k), and value (v) tensors. Causal masking is applied when the module was constructed with causal=True.
- Parameters:
- q (Tensor) – The query tensor of shape (batch_size, seq_len, n_state).
- k (Tensor) – The key tensor of shape (batch_size, seq_len, n_state).
- v (Tensor) – The value tensor of shape (batch_size, seq_len, n_state).
- mask (Optional[Tensor]) – An optional tensor for masking of shape (batch_size, n_head, seq_len, seq_len). Defaults to None.
- Returns: The output tensor after applying the attention mechanism, with shape (batch_size, seq_len, n_state).
- Return type: Tensor
- Raises: ValueError – If causal attention is requested but a mask is provided.
######### Examples
>>> attention_layer = MultiHeadAttention(n_state=64, n_head=8)
>>> q = torch.rand(2, 10, 64) # (batch_size, seq_len, n_state)
>>> k = torch.rand(2, 10, 64)
>>> v = torch.rand(2, 10, 64)
>>> output = attention_layer.qkv_attention(q, k, v)
>>> output.shape
torch.Size([2, 10, 64])
NOTE
The method uses PyTorch’s built-in scaled dot-product attention and assumes that the input tensors are of appropriate shapes.
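For reference, the computation described above can be sketched with PyTorch's built-in scaled dot-product attention. This illustrates the general head-splitting mechanism, not ESPnet's exact implementation:
>>> import torch
>>> import torch.nn.functional as F
>>> B, T, n_state, n_head = 2, 10, 64, 8
>>> q, k, v = (torch.rand(B, T, n_state) for _ in range(3))
>>> def split(t):
...     # split the state dimension into heads: (B, T, n_state) -> (B, n_head, T, n_state // n_head)
...     return t.view(B, T, n_head, n_state // n_head).transpose(1, 2)
>>> wv = F.scaled_dot_product_attention(split(q), split(k), split(v), is_causal=False)
>>> wv.transpose(1, 2).reshape(B, T, n_state).shape  # merge heads back
torch.Size([2, 10, 64])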