espnet2.speechlm.net_utils.install_kv_cache_hook
espnet2.speechlm.net_utils.install_kv_cache_hook(model, cache)
Install key-value cache hooks for the specified model layers.
This function registers forward hooks on the MultiHeadAttention layers of the provided model. During each forward pass, the hooks save the key and value projection outputs to a cache dictionary so they can be reused in subsequent decoding steps, avoiding recomputation during autoregressive generation with transformer models.
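The mechanism can be illustrated with a small, self-contained sketch. Everything below (the `ToyAttention` class, the `install_kv_cache_hook_sketch` helper, and the choice to concatenate cached tensors along the time dimension) is an illustrative assumption, not the actual ESPnet implementation; it only shows how PyTorch forward hooks can capture key/value projection outputs into a shared dictionary.

```python
import torch


class ToyAttention(torch.nn.Module):
    """Stand-in for an attention block with separate key/value projections (assumption)."""

    def __init__(self, dim):
        super().__init__()
        self.key = torch.nn.Linear(dim, dim)
        self.value = torch.nn.Linear(dim, dim)

    def forward(self, x):
        # Placeholder computation; a real block would also attend over queries.
        return self.key(x) + self.value(x)


def install_kv_cache_hook_sketch(model, cache):
    """Hook every ToyAttention's key/value projection and accumulate their outputs."""
    cache = {**cache} if cache is not None else {}
    hooks = []

    def save_to_cache(module, _inputs, output):
        if module not in cache:
            cache[module] = output  # first decoding step: store as-is
        else:
            # later steps: append the new step's K/V along the time axis
            cache[module] = torch.cat([cache[module], output], dim=1).detach()
        return cache[module]  # a non-None return value replaces the module's output

    def install(layer):
        if isinstance(layer, ToyAttention):
            hooks.append(layer.key.register_forward_hook(save_to_cache))
            hooks.append(layer.value.register_forward_hook(save_to_cache))

    model.apply(install)
    return cache, hooks


model = torch.nn.Sequential(ToyAttention(8))
cache, hooks = install_kv_cache_hook_sketch(model, {})
for _ in range(3):                       # simulate three autoregressive steps
    _ = model(torch.randn(1, 1, 8))      # one new token per step
print([v.shape for v in cache.values()])  # each cached K/V has grown to 3 time steps
for h in hooks:                          # always remove hooks when finished
    h.remove()
```

The cache is keyed by module object, so each key/value projection accumulates its own history independently; returning the concatenated tensor from the hook is what lets downstream layers see the full cached sequence instead of only the newest step.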
- Parameters:
- model (torch.nn.Module) – The model containing the layers on which to install the hooks.
- cache (dict, optional) – A dictionary to store the cached key-value outputs. If None, an empty dictionary will be created.
- Returns: A tuple containing:
  - cache (dict): The updated cache dictionary with stored key-value outputs.
  - hooks (list): A list of the registered hook handles, kept for later cleanup.
- Return type: Tuple[dict, list]
Examples
>>> from espnet2.speechlm.net_utils import install_kv_cache_hook
>>> from espnet2.speechlm.module.transformer import TransformerModel
>>> model = TransformerModel(...)
>>> cache = {}
>>> cache, hooks = install_kv_cache_hook(model, cache)
NOTE
To prevent memory leaks, clean up after use by calling remove() on each hook in the returned hooks list, as shown below.
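For example, continuing from the example above, where hooks is the list returned by install_kv_cache_hook:

>>> for hook in hooks:
...     hook.remove()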