espnet2.asr_transducer.decoder.modules.rwkv.attention.SelfAttention
class espnet2.asr_transducer.decoder.modules.rwkv.attention.SelfAttention(size: int, attention_size: int, context_size: int, block_id: int, num_blocks: int)
Bases: Module
SelfAttention module definition.
This module implements the SelfAttention mechanism used in RWKV architectures, which allows for effective time mixing in sequence processing tasks. It is based on the original implementation found in the RWKV-LM repository and has been modified to fit within the espnet2 framework.
time_shift
A zero-padding layer for temporal shifting of input sequences.
time_decay
A learnable parameter representing the channel-wise time decay.
time_first
A learnable parameter representing the channel-wise time first.
time_mix_key
A learnable parameter for mixing key inputs.
time_mix_value
A learnable parameter for mixing value inputs.
time_mix_receptance
A learnable parameter for mixing receptance inputs.
proj_key
A linear transformation applied to the key inputs.
proj_value
A linear transformation applied to the value inputs.
proj_receptance
A linear transformation applied to the receptance inputs.
proj_output
A linear transformation applied to the final output.
- Parameters:
- size – Input/Output size.
- attention_size – Attention hidden size.
- context_size – Context size for WKV kernel.
- block_id – Block index.
- num_blocks – Number of blocks in the architecture.
########### Examples
>>> self_attention = SelfAttention(size=512, attention_size=256,
... context_size=128, block_id=0,
... num_blocks=4)
>>> input_tensor = torch.randn(32, 10, 512) # (B, U, size)
>>> output, _ = self_attention(input_tensor)
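The roles of time_shift and the time_mix_* parameters listed above can be illustrated without PyTorch. The following is a minimal sketch in plain Python, assuming the RWKV-v4 convention that each projection input is a channel-wise blend of the current step and the previous step; the helper names `time_shift` and `time_mix` are hypothetical and are not part of the espnet2 API.

```python
from typing import List

Matrix = List[List[float]]  # one sequence, shape (U, size)


def time_shift(x: Matrix) -> Matrix:
    """Shift a sequence one step into the past, zero-padding the first step."""
    if not x:
        return []
    zeros = [0.0] * len(x[0])
    return [zeros] + [row[:] for row in x[:-1]]


def time_mix(x: Matrix, mix: List[float]) -> Matrix:
    """Blend each step with its predecessor, channel by channel.

    mix[d] = 1.0 keeps only the current step; mix[d] = 0.0 keeps only the
    previous one (assumed convention, following RWKV-v4).
    """
    shifted = time_shift(x)
    return [
        [m * cur + (1.0 - m) * prev for m, cur, prev in zip(mix, row, prev_row)]
        for row, prev_row in zip(x, shifted)
    ]


x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # U=3, size=2
mixed = time_mix(x, mix=[1.0, 0.5])
print(mixed)  # [[1.0, 5.0], [2.0, 15.0], [3.0, 25.0]]
```

In the module itself, one such blend would be computed per projection (key, value, receptance), each with its own learnable mixing vector.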
Construct a SelfAttention object.
WKVLinearAttention function forward pass.
This method computes the forward pass of the WKVLinearAttention function, which applies a weighted key-value mechanism. It uses the provided time decay and time first vectors along with the key and value tensors to produce an output tensor.
- Parameters:
- ctx – The context object for storing information for backward pass.
- time_decay – Channel-wise time decay vector. Shape: (D_att).
- time_first – Channel-wise time first vector. Shape: (D_att).
- key – Key tensor. Shape: (B, U, D_att), where B is batch size and U is sequence length.
- value – Value tensor. Shape: (B, U, D_att).
- Returns: Weighted Key-Value tensor. Shape: (B, U, D_att).
- Return type: out
- Raises: AssertionError – If the key length U exceeds the context size, or if B * D_att is not a multiple of min(D_att, 32).
########### Examples
>>> time_decay = torch.tensor([0.1, 0.2])
>>> time_first = torch.tensor([0.5, 0.6])
>>> key = torch.randn(4, 10, 2) # Example with batch size 4, length 10
>>> value = torch.randn(4, 10, 2)
>>> output = WKVLinearAttention.apply(time_decay, time_first, key, value)
>>> print(output.shape)
torch.Size([4, 10, 2])
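The example above needs the compiled WKV kernel. To see what the weighted key-value mechanism computes, here is a plain-Python reference for a single channel, written directly from the RWKV-v4 definition. It is a sketch under assumptions: `wkv_naive` is a hypothetical name, `decay` is the effective per-step log-decay (non-positive), and how the stored time_decay parameter maps to it (the RWKV-v4 reference uses -exp(time_decay)) is not confirmed by this page.

```python
import math
from typing import List


def wkv_naive(
    decay: float, first: float, key: List[float], value: List[float]
) -> List[float]:
    """Full-sequence WKV for one channel, computed directly from the definition.

    decay: per-step log-decay (<= 0) applied to past tokens (assumed form).
    first: bonus added to the current token's key ("time first").
    """
    out = []
    for t in range(len(key)):
        # The current token gets the `first` bonus instead of a decay.
        num = math.exp(first + key[t]) * value[t]
        den = math.exp(first + key[t])
        # Past tokens are weighted by exp(key) and decayed by their distance.
        for i in range(t):
            weight = math.exp((t - 1 - i) * decay + key[i])
            num += weight * value[i]
            den += weight
        out.append(num / den)
    return out


# With zero keys, zero decay and zero bonus, WKV reduces to a running mean.
running_mean = wkv_naive(0.0, 0.0, [0.0, 0.0, 0.0], [1.0, 2.0, 3.0])
print(running_mean)  # [1.0, 1.5, 2.0]
```

This direct form costs O(U^2) per channel and can overflow for large keys, which is why practical implementations use a fused kernel or a numerically stabilized recurrence.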
reset_parameters(size: int, attention_size: int, block_id: int, num_blocks: int) → None
Reset module parameters.
This method initializes the parameters of the SelfAttention module based on the given size and attention configuration. It calculates decay speeds and initializes time-related parameters that control the attention mechanism within the module.
- Parameters:
- size – Block size, representing the input/output dimension.
- attention_size – Attention hidden size (D_att), i.e. the output dimension of the key, value and receptance projections.
- block_id – Block index, indicating the position of this block in a larger architecture.
- num_blocks – Total number of blocks in the architecture.
########### Examples
>>> attention = SelfAttention(size=128, attention_size=64,
... context_size=512, block_id=0,
... num_blocks=4)
>>> attention.reset_parameters(size=128, attention_size=64,
... block_id=0, num_blocks=4)
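The page does not spell out how the decay speeds are derived. The sketch below assumes the scheme follows the RWKV-v4 reference code on which the module is based: decay speeds rise from -5 to 3 across channels, with a curve that depends on the block position. The constants and the helper names `decay_speeds` and `time_mix_init` come from that assumption and may differ from the espnet2 implementation.

```python
from typing import List


def decay_speeds(attention_size: int, block_id: int, num_blocks: int) -> List[float]:
    """Channel-wise initial time decay (assumed RWKV-v4 scheme)."""
    # 0.0 for the first block, 1.0 for the last; guard the single-block case.
    ratio_0_to_1 = block_id / max(num_blocks - 1, 1)
    exponent = 0.7 + 1.3 * ratio_0_to_1
    last = max(attention_size - 1, 1)
    return [-5.0 + 8.0 * (h / last) ** exponent for h in range(attention_size)]


def time_mix_init(size: int, block_id: int, num_blocks: int) -> List[float]:
    """Channel-wise initial mixing factor for the key input (assumed scheme)."""
    # 1.0 for the first block, approaching 0.0 for the last.
    ratio_1_to_almost0 = 1.0 - block_id / num_blocks
    return [(i / size) ** ratio_1_to_almost0 for i in range(size)]


speeds = decay_speeds(attention_size=64, block_id=0, num_blocks=4)
print(speeds[0], speeds[-1])  # -5.0 3.0
print(time_mix_init(size=4, block_id=0, num_blocks=4))  # [0.0, 0.25, 0.5, 0.75]
```

Under this scheme, low-index channels start with slow-moving values and high-index channels with fast ones, so each block begins with a spread of time scales.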
wkv_linear_attention(time_decay: Tensor, time_first: Tensor, key: Tensor, value: Tensor, state: Tuple[Tensor, Tensor, Tensor]) → Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]
Attention (time mixing) modules for RWKV block.
Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
wkv_kernel
A global variable that holds the WKV CUDA kernel.
- Parameters:
- time_decay – Channel-wise time decay vector. (D_att)
- time_first – Channel-wise time first vector. (D_att)
- key – Key tensor. (B, U, D_att)
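- value – Value tensor. (B, U, D_att)
- state – Tuple of three state tensors carried over from the previous call.
- Returns: Weighted Key-Value tensor (B, U, D_att) and the updated state tuple.
- Return type: Tuple[Tensor, Tuple[Tensor, Tensor, Tensor]]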
- Raises: AssertionError – If the key length U exceeds the context size, or if B * D_att is not a multiple of min(D_att, 32).
########### Examples
>>> time_decay = torch.rand(64) # (D_att), must match the last dimension of key
>>> time_first = torch.rand(64) # (D_att)
>>> key = torch.rand(32, 10, 64) # (B, U, D_att)
>>> value = torch.rand(32, 10, 64) # (B, U, D_att)
>>> output = WKVLinearAttention.apply(time_decay, time_first, key, value)