espnet2.asr_transducer.decoder.modules.rwkv.attention.WKVLinearAttention
class espnet2.asr_transducer.decoder.modules.rwkv.attention.WKVLinearAttention(*args, **kwargs)
Bases: torch.autograd.Function
WKVLinearAttention function definition.
This class implements the WKV linear attention operator of the RWKV model. Instead of building a full attention matrix, WKV summarizes past time steps in a recurrent state, so its cost grows linearly with sequence length. The forward and backward methods call custom CUDA kernels for performance.
The implementation is based on the RWKV architecture and draws on earlier implementations in the RWKV-LM GitHub repository and the Hugging Face Transformers library.
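For orientation, the quantity computed for each channel can be written in the RWKV-4 form below. This is a sketch under assumptions: it uses RWKV-4 notation, where w = exp(time_decay) > 0 is the per-step decay rate, u is time_first, and k_t, v_t are the key and value at step t = 1, ..., U. The exact parameterization used inside the ESPnet kernel is not stated in this documentation.

```latex
\mathrm{wkv}_t =
  \frac{\sum_{i=1}^{t-1} e^{-(t-1-i)\,w + k_i}\, v_i + e^{u + k_t}\, v_t}
       {\sum_{i=1}^{t-1} e^{-(t-1-i)\,w + k_i} + e^{u + k_t}}

a_1 = b_1 = 0, \qquad
a_{t+1} = e^{-w} a_t + e^{k_t} v_t, \qquad
b_{t+1} = e^{-w} b_t + e^{k_t}

\mathrm{wkv}_t = \frac{a_t + e^{u + k_t}\, v_t}{b_t + e^{u + k_t}}
```

The recurrent form keeps only two numbers (a_t, b_t) per channel, so the work grows linearly with U rather than quadratically, which is what makes this attention "linear".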
- Parameters:
- ctx – The context object used to store information for the backward pass.
- time_decay – Channel-wise time decay vector. Shape: (D_att).
- time_first – Channel-wise time first vector. Shape: (D_att).
- key – Key tensor. Shape: (B, U, D_att).
- value – Value tensor. Shape: (B, U, D_att).
- Returns: Weighted Key-Value tensor. Shape: (B, U, D_att).
- Return type: Tensor
- Raises: AssertionError – If the length U of key exceeds the kernel's context size, or if B * D_att is not a multiple of min(D_att, 32).
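The two documented assertions can be mirrored in plain Python to validate shapes before launching the kernel. `check_wkv_inputs` is a hypothetical helper, and `context_size` stands for the kernel's configured maximum sequence length; neither is part of the documented API.

```python
def check_wkv_inputs(batch: int, length: int, dim: int, context_size: int) -> None:
    """Hypothetical helper mirroring the two documented assertions."""
    # The kernel can process at most `context_size` time steps.
    assert length <= context_size, (
        f"key length {length} exceeds context_size {context_size}"
    )
    # batch * dim must be a multiple of min(dim, 32).
    block = min(dim, 32)
    assert (batch * dim) % block == 0, (
        f"batch ({batch}) * dim ({dim}) must be a multiple of {block}"
    )


check_wkv_inputs(batch=2, length=5, dim=2, context_size=1024)  # passes: 4 % 2 == 0
try:
    check_wkv_inputs(batch=3, length=5, dim=40, context_size=1024)  # 120 % 32 != 0
except AssertionError as err:
    print(err)  # batch (3) * dim (40) must be a multiple of 32
```

Note that `assert` statements are skipped when Python runs with `-O`, so a helper like this is a debugging aid rather than a guarantee.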
######### Examples
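>>> import torch
>>> from espnet2.asr_transducer.decoder.modules.rwkv.attention import WKVLinearAttention
>>> B, U, D_att = 2, 5, 32  # batch size, sequence length, attention dimension
>>> # The custom kernel runs on CUDA and must already be loaded (see the NOTE on forward).
>>> time_decay = torch.randn(D_att, device="cuda")
>>> time_first = torch.randn(D_att, device="cuda")
>>> key = torch.randn(B, U, D_att, device="cuda")
>>> value = torch.randn(B, U, D_att, device="cuda")
>>> output = WKVLinearAttention.apply(time_decay, time_first, key, value)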
static backward(ctx, grad_output: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
WKVLinearAttention function backward pass.
This method receives the gradient of the loss with respect to the forward output and computes the gradients with respect to the inputs. It uses the tensors saved in the forward context to calculate the gradients for the time decay, time first, key, and value tensors.
- Parameters:
- ctx – Context object containing the tensors saved during the forward pass.
- grad_output – Gradient of the loss with respect to the output. Shape: (B, U, D_att).
- Returns:
- grad_time_decay – Gradient for the channel-wise time decay vector. Shape: (D_att).
- grad_time_first – Gradient for the channel-wise time first vector. Shape: (D_att).
- grad_key – Gradient for the key tensor. Shape: (B, U, D_att).
- grad_value – Gradient for the value tensor. Shape: (B, U, D_att).
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
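As a worked illustration of what backward has to compute, take the gradient with respect to time_first (u) for a single channel. This assumes the RWKV-4 formulation: a_t and b_t are the decayed sums over past steps (they do not depend on u), y_t is the forward output at step t, and g_{n,t} is the entry of grad_output for batch item n at step t.

```latex
y_t = \frac{N_t}{Z_t}, \qquad
N_t = a_t + e^{u + k_t} v_t, \qquad
Z_t = b_t + e^{u + k_t}

\frac{\partial y_t}{\partial u}
  = \frac{e^{u + k_t} v_t \, Z_t - N_t \, e^{u + k_t}}{Z_t^{2}}
  = \frac{e^{u + k_t}\,(v_t - y_t)}{Z_t}

\frac{\partial \mathcal{L}}{\partial u}
  = \sum_{n=1}^{B} \sum_{t=1}^{U} g_{n,t}\,
    \frac{e^{u + k_{n,t}}\,(v_{n,t} - y_{n,t})}{Z_{n,t}}
```

The result depends on the forward output y_t, which is one reason it helps to keep forward tensors in the context, and the double sum over n and t is what reduces the (B, U) axes to a gradient of shape (D_att). The gradients for time_decay, key, and value follow the same chain rule but also flow through the recurrent state, so they involve sums over later time steps.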
######### Examples
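>>> import torch
>>> time_decay = torch.randn(4, device="cuda", requires_grad=True)
>>> time_first = torch.randn(4, device="cuda", requires_grad=True)
>>> key = torch.randn(2, 3, 4, device="cuda", requires_grad=True)
>>> value = torch.randn(2, 3, 4, device="cuda", requires_grad=True)
>>> output = WKVLinearAttention.apply(time_decay, time_first, key, value)
>>> output.backward(torch.randn(2, 3, 4, device="cuda"))  # autograd calls backward
>>> grad_time_decay, grad_key = time_decay.grad, key.grad  # shapes (4,) and (2, 3, 4)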
NOTE
Ensure that the context contains the necessary tensors saved during the forward pass, as they are crucial for computing the gradients.
- Raises: RuntimeError – If the context does not contain the expected tensors.
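PyTorch's autograd engine, not user code, invokes backward with its own ctx. The toy Function below (`SquareFn` is hypothetical and unrelated to the WKV math) shows the same `ctx.save_for_backward` / `ctx.saved_tensors` pattern, and the RuntimeError PyTorch raises when saved tensors are accessed after an earlier backward pass has freed them.

```python
import torch


class SquareFn(torch.autograd.Function):
    """Hypothetical toy Function; it only mirrors the ctx pattern, not the WKV math."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Keep the input on the graph node so backward can use it later.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Raises RuntimeError if the saved tensors have already been freed.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = SquareFn.apply(x)
y.sum().backward()  # autograd calls SquareFn.backward with its own ctx
print(x.grad)  # tensor([2., 4., 6.])

second_backward_failed = False
try:
    # The first backward freed the saved tensors (retain_graph defaults to False).
    y.sum().backward()
except RuntimeError as err:
    second_backward_failed = True
    print("RuntimeError:", str(err).splitlines()[0])
```

Passing `retain_graph=True` to the first `backward()` call keeps the saved tensors alive and avoids the error, at the cost of holding that memory longer.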
static forward(ctx, time_decay: Tensor, time_first: Tensor, key: Tensor, value: Tensor) → Tensor
WKVLinearAttention function forward pass.
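This method computes the forward pass of the WKV linear attention mechanism. For every time step it produces a weighted average of the values, in which each step is weighted by the exponential of its key, past steps are decayed according to time_decay, and the current step receives an additional bonus given by time_first.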
- Parameters:
- ctx – The context object to store information for the backward pass.
- time_decay – Channel-wise time decay vector of shape (D_att).
- time_first – Channel-wise time first vector of shape (D_att).
- key – Key tensor of shape (B, U, D_att).
- value – Value tensor of shape (B, U, D_att).
- Returns: Weighted Key-Value tensor of shape (B, U, D_att).
- Return type: Tensor
- Raises: AssertionError – If the length of the key tensor exceeds the context size, or if the product of batch size and dimension is not a multiple of the minimum of dimension and 32.
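The CUDA kernel itself is not reproduced here. As an illustration only, the pure-Python sketch below computes the same kind of forward pass for one batch element while avoiding overflow in `exp()`: it tracks a running maximum exponent `p` and stores the numerator and denominator scaled by `exp(-p)`, a max-tracking trick of the kind used in the RWKV-LM CUDA kernels. The function name `wkv_forward_stable`, the list-of-lists `(U, D_att)` layout, and the `log_decay = -exp(time_decay)` parameterization are assumptions.

```python
import math
from typing import List


def wkv_forward_stable(
    time_decay: List[float],
    time_first: List[float],
    key: List[List[float]],
    value: List[List[float]],
) -> List[List[float]]:
    """Hypothetical reference WKV forward for one batch element; key/value are (U, D_att)."""
    dim = len(time_decay)
    out = [[0.0] * dim for _ in key]
    for c in range(dim):
        log_decay = -math.exp(time_decay[c])  # assumed RWKV-4 style per-step log-decay
        u = time_first[c]
        # a, b: running numerator/denominator, both stored scaled by exp(-p).
        a, b, p = 0.0, 0.0, -math.inf
        for t in range(len(key)):
            k, v = key[t][c], value[t][c]
            # Output: combine the past state with the current step boosted by u.
            q = max(p, u + k)
            e1, e2 = math.exp(p - q), math.exp(u + k - q)
            out[t][c] = (e1 * a + e2 * v) / (e1 * b + e2)
            # State update: decay the past one step, then add the current step.
            q = max(p + log_decay, k)
            e1, e2 = math.exp(p + log_decay - q), math.exp(k - q)
            a, b, p = e1 * a + e2 * v, e1 * b + e2, q
    return out


# Keys of 1000 would overflow a naive exp(); here the result is simply the average.
print(wkv_forward_stable([0.0], [0.0], [[1000.0], [1000.0]], [[1.0], [3.0]]))
```

Each channel is processed independently in O(U) time, which matches the per-channel recurrence; a real kernel parallelizes over the B * D_att independent channels.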
######### Examples
>>> import torch
>>> time_decay = torch.tensor([0.1, 0.2], device="cuda")
>>> time_first = torch.tensor([0.3, 0.4], device="cuda")
>>> key = torch.rand(2, 5, 2, device="cuda")  # batch size 2, length 5, D_att 2
>>> value = torch.rand(2, 5, 2, device="cuda")
>>> output = WKVLinearAttention.apply(time_decay, time_first, key, value)
>>> print(output.shape)
torch.Size([2, 5, 2])
NOTE
Ensure that the WKV kernel is loaded before calling this function.