espnet2.main_funcs.calculate_all_attentions.calculate_all_attentions
espnet2.main_funcs.calculate_all_attentions.calculate_all_attentions(model: AbsESPnetModel, batch: Dict[str, Tensor]) → Dict[str, List[Tensor]]
Derive the outputs from all attention layers in a given model.
This function registers forward hooks on the attention layers of the provided model, runs a forward pass on the given batch, and collects the attention outputs produced during that pass. The collected outputs are returned as a dictionary keyed by attention layer name.
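The hook mechanism can be sketched in plain PyTorch. This is a minimal, hypothetical illustration, not ESPnet's implementation: `ToyModel` and `collect_attentions` are made-up names, and PyTorch's built-in `torch.nn.MultiheadAttention` stands in for the model's attention layers. The sketch assumes `nn.MultiheadAttention` returns `(attn_output, attn_weights)` when `need_weights=True` (the default), with inputs laid out as `(seq_len, batch, dim)` because `batch_first` is left at `False`.

```python
from typing import Dict

import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical two-layer model standing in for an AbsESPnetModel."""

    def __init__(self, dim: int = 8, heads: int = 2):
        super().__init__()
        self.att1 = nn.MultiheadAttention(dim, heads)
        self.att2 = nn.MultiheadAttention(dim, heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch, dim); need_weights=True by default
        x, _ = self.att1(x, x, x)
        x, _ = self.att2(x, x, x)
        return x


@torch.no_grad()
def collect_attentions(model: nn.Module, x: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: capture attention weights via forward hooks."""
    outputs: Dict[str, torch.Tensor] = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            # Bind the current name as a default argument so each hook keeps its own.
            def hook(mod, inputs, output, name=name):
                # output is (attn_output, attn_weights); weights are averaged over heads
                outputs[name] = output[1].detach().cpu()

            handles.append(module.register_forward_hook(hook))
    try:
        model(x)
    finally:
        # Always detach the hooks, even if the forward pass raises.
        for handle in handles:
            handle.remove()
    return outputs


model = ToyModel()
att = collect_attentions(model, torch.randn(5, 2, 8))
print({k: tuple(v.shape) for k, v in att.items()})
# {'att1': (2, 5, 5), 'att2': (2, 5, 5)}
```

Binding `name` as a default argument matters: without it, every closure would see the last value of the loop variable, and all outputs would be stored under the same key.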
- Parameters:
- model (AbsESPnetModel) – The ESPnet model containing attention layers.
- batch (Dict[str, torch.Tensor]) – A dictionary of input tensors for the model’s forward pass. It should include all necessary input features and lengths.
- Returns: A dictionary mapping the name of each attention layer to a list of tensors holding that layer's attention outputs. The overall structure is key_names x batch x (D1, D2, …): one tensor per sample in the batch, where the per-tensor shape (D1, D2, …) depends on the specific attention mechanism.
- Return type: Dict[str, List[torch.Tensor]]
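Because the return value is a dictionary of lists, a quick way to inspect it is to map every tensor to its shape. The sketch below is hypothetical: `summarize_attentions` is not part of ESPnet, and the input dictionary (including the layer name `encoder.att`) is made-up toy data standing in for a real result. Since each layer maps to a list rather than one stacked tensor, the helper does not assume the per-sample entries share a shape.

```python
from typing import Dict, List, Tuple

import torch


def summarize_attentions(
    attentions: Dict[str, List[torch.Tensor]],
) -> Dict[str, List[Tuple[int, ...]]]:
    """Hypothetical helper: map each layer name to the shapes of its per-sample tensors."""
    return {
        name: [tuple(t.shape) for t in per_sample]
        for name, per_sample in attentions.items()
    }


# Toy stand-in for a real result: one layer, a batch of two samples.
toy = {"encoder.att": [torch.rand(4, 6, 6), torch.rand(4, 3, 3)]}
print(summarize_attentions(toy))
# {'encoder.att': [(4, 6, 6), (4, 3, 3)]}
```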
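Examples
>>> import torch
>>> from espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions
>>> model = MyESPnetModel()  # placeholder for any AbsESPnetModel subclass
>>> batch = {  # keys must match the arguments of model.forward()
...     'input': torch.randn(2, 10, 20),
...     'input_lengths': torch.tensor([10, 10]),
... }
>>> attentions = calculate_all_attentions(model, batch)
>>> print(attentions.keys())  # actual names depend on the model
dict_keys(['layer1_att', 'layer2_att', ...])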
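The example uses equal lengths for both samples. When samples differ in length, a common approach is to zero-pad the features and pass the true lengths alongside. A minimal sketch, assuming the model's forward takes `input` and `input_lengths` as in the example (these key names are illustrative; use whatever names your model's forward expects):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two samples of different lengths, each (frames, feature_dim).
feats = [torch.randn(10, 20), torch.randn(7, 20)]

batch = {
    # Zero-pad to the longest sample: (batch, max_frames, feature_dim).
    "input": pad_sequence(feats, batch_first=True),
    # True (unpadded) length of each sample.
    "input_lengths": torch.tensor([f.size(0) for f in feats]),
}
print(batch["input"].shape)             # torch.Size([2, 10, 20])
print(batch["input_lengths"].tolist())  # [10, 7]
```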
NOTE
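The function runs under torch.no_grad(), so no gradients are computed, which saves memory and time. Note that torch.no_grad() does not put the model into evaluation mode; call model.eval() beforehand if layers such as dropout should behave as they do at inference time.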
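The difference between disabling gradients and switching to evaluation mode can be checked directly. A small sketch on a toy model (the model and the helper name `forward_without_grad` are hypothetical): `torch.no_grad` stops autograd from recording, but the module's `training` flag, which controls dropout, is unchanged until `eval()` is called.

```python
import torch
from torch import nn

# Toy model with a dropout layer, whose behavior depends on train/eval mode.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))


@torch.no_grad()
def forward_without_grad(m: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside this function.
    return m(x)


out = forward_without_grad(model, torch.randn(3, 4))
print(out.requires_grad)  # False: gradients are not tracked
print(model.training)     # True: no_grad leaves the module in training mode

# Switch to evaluation mode explicitly so dropout is disabled.
model.eval()
print(model.training)     # False
```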