espnet2.asr_transducer.decoder.modules.rwkv.attention.load_wkv_kernel
espnet2.asr_transducer.decoder.modules.rwkv.attention.load_wkv_kernel(context_size: int) → None
Load WKV CUDA kernel.
This function loads the WKV (Weighted Key-Value) CUDA kernel for efficient computation in the RWKV model. The kernel is loaded using the PyTorch C++ extension loader and requires CUDA support.
- Parameters: context_size (int) – The context size used by the WKV kernel. It determines the maximum length of the input sequences that can be processed.
- Raises: ImportError – If the Ninja package is not installed or if CUDA is not available.
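Because context_size bounds the sequence length the kernel can process, it can help to derive it from the data before loading the kernel. The sketch below is illustrative only: required_context_size and fits_context are hypothetical helpers, not part of ESPnet, and they assume the bound is inclusive (a sequence whose length equals context_size is accepted).

```python
from typing import Sequence


def required_context_size(sequence_lengths: Sequence[int]) -> int:
    """Return the smallest context size that covers every given length.

    Hypothetical helper, not part of ESPnet.
    """
    if not sequence_lengths:
        raise ValueError("sequence_lengths must not be empty")
    return max(sequence_lengths)


def fits_context(sequence_length: int, context_size: int) -> bool:
    """Return True if a sequence of this length stays within context_size.

    Assumes the limit is inclusive, which is an assumption of this sketch.
    """
    return 0 < sequence_length <= context_size
```

A caller could compute required_context_size over the lengths it expects to see and pass the result as context_size.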
NOTE
Ensure that the ‘ninja’ package is installed in your Python environment to load the WKV kernel. You can install it via pip: pip install ninja.
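The two documented failure conditions can be checked up front, before any compilation is attempted. The following is a minimal sketch under stated assumptions: missing_wkv_prerequisites and try_load_wkv_kernel are hypothetical wrappers (not ESPnet functions), and the messages they produce are this sketch's own wording rather than the library's.

```python
import importlib.util
from typing import List


def missing_wkv_prerequisites() -> List[str]:
    """List reasons why loading the WKV kernel would likely fail.

    Hypothetical helper; it only mirrors the two documented conditions.
    """
    problems = []

    # Documented requirement: the Ninja package must be installed.
    if importlib.util.find_spec("ninja") is None:
        problems.append("ninja is not installed (pip install ninja)")

    # Documented requirement: CUDA must be available (queried via PyTorch).
    if importlib.util.find_spec("torch") is None:
        problems.append("torch is not installed")
    else:
        import torch

        if not torch.cuda.is_available():
            problems.append("CUDA is not available")

    return problems


def try_load_wkv_kernel(context_size: int) -> bool:
    """Load the kernel if the prerequisites look satisfied; report success."""
    problems = missing_wkv_prerequisites()
    if problems:
        print("Skipping WKV kernel load:", "; ".join(problems))
        return False

    try:
        # Imported lazily so this sketch can run without ESPnet installed.
        from espnet2.asr_transducer.decoder.modules.rwkv.attention import (
            load_wkv_kernel,
        )

        load_wkv_kernel(context_size=context_size)
    except ImportError as err:
        print(f"WKV kernel could not be loaded: {err}")
        return False
    return True
```

Returning a boolean instead of raising lets a script fall back to a CPU-only code path, if it has one, when the kernel is unavailable.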
Examples
To load the WKV kernel with a specific context size, you can use the following code:
```python
from espnet2.asr_transducer.decoder.modules.rwkv.attention import load_wkv_kernel

load_wkv_kernel(context_size=128)
```
This will prepare the WKV kernel for usage in further computations related to the RWKV model.