espnet2.s2st.aux_attention.abs_aux_attention.AbsS2STAuxAttention
class espnet2.s2st.aux_attention.abs_aux_attention.AbsS2STAuxAttention(*args, **kwargs)
Bases: Module, ABC
Base class for all S2ST auxiliary attention modules.
This class serves as an abstract base class for the auxiliary attention mechanisms used in Speech-to-Speech Translation (S2ST) models. Subclasses must override the abstract forward method to define their specific attention computation.
For background, refer to the paper: https://arxiv.org/abs/2107.08661
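Because forward is declared abstract and the class mixes in ABC, Python refuses to instantiate a subclass that does not override it. The sketch below illustrates that contract with a local stand-in that mirrors the documented bases (Module, ABC) instead of importing ESPnet; the stand-in and the subclass names Incomplete and Complete are hypothetical.

```python
from abc import ABC, abstractmethod

import torch


class AbsS2STAuxAttention(torch.nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented interface."""

    @abstractmethod
    def forward(self) -> torch.Tensor:
        raise NotImplementedError


class Incomplete(AbsS2STAuxAttention):
    """Forgets to override forward()."""


class Complete(AbsS2STAuxAttention):
    def forward(self) -> torch.Tensor:
        # Placeholder output with shape (batch,) for a batch of 3
        return torch.zeros(3)


try:
    Incomplete()
    instantiated = True
except TypeError:
    # ABCMeta rejects classes that still have unimplemented abstract methods
    instantiated = False

print(instantiated)        # False
print(Complete()().shape)  # torch.Size([3])
```

In other words, a missing forward surfaces as a TypeError at construction time, before any NotImplementedError could be raised.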
name
The name of the attention module, used as its key in the reporter. Subclasses are expected to implement it.
- Type: str
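The following is a minimal sketch of how a subclass might supply name so that its output can be logged under that key. It assumes name is a read-only property that subclasses override; a plain dict stands in for ESPnet's reporter, and the report helper, the ConstantAuxAttention class, and the key "constant_attn" are all hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterable

import torch


class AbsS2STAuxAttention(torch.nn.Module, ABC):
    """Hypothetical stand-in: subclasses supply `name` and `forward`."""

    @property
    def name(self) -> str:
        # Assumption: concrete subclasses override this property
        raise NotImplementedError

    @abstractmethod
    def forward(self) -> torch.Tensor:
        raise NotImplementedError


class ConstantAuxAttention(AbsS2STAuxAttention):
    """Toy subclass that returns ones with shape (batch,)."""

    def __init__(self, batch_size: int):
        super().__init__()
        self.batch_size = batch_size

    @property
    def name(self) -> str:
        return "constant_attn"

    def forward(self) -> torch.Tensor:
        return torch.ones(self.batch_size)


def report(modules: Iterable[AbsS2STAuxAttention], stats: Dict[str, float]) -> None:
    """Hypothetical reporter: store each module's mean output under its name."""
    for module in modules:
        stats[module.name] = module().mean().item()


stats: Dict[str, float] = {}
report([ConstantAuxAttention(batch_size=4)], stats)
print(stats)  # {'constant_attn': 1.0}
```

Keying by name lets several auxiliary attention modules report side by side without the caller knowing their concrete types.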
forward()
Runs the subclass-specific attention computation and returns a tensor.
- Raises: NotImplementedError – If the base implementation is called. Because forward is abstract, a subclass that does not override it cannot be instantiated.
####### Examples
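import torch
from espnet2.s2st.aux_attention.abs_aux_attention import AbsS2STAuxAttention

class MyAuxAttention(AbsS2STAuxAttention):
    def forward(self) -> torch.Tensor:
        # Implement the specific attention mechanism here
        return torch.tensor([1.0, 2.0, 3.0])

attention = MyAuxAttention()
output = attention()  # nn.Module.__call__ dispatches to forward()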
abstract forward() → Tensor
Executes the forward pass of the auxiliary attention module.
Subclasses of AbsS2STAuxAttention must override this method to define the forward computation and return the resulting tensor.
- Returns: The output tensor of the forward computation, with shape (batch,).
- Return type: torch.Tensor
- Raises: NotImplementedError – If the base implementation is called. Because this method is abstract, a subclass that does not override it cannot be instantiated.
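To illustrate the (batch,) return shape, here is a hedged sketch of a toy subclass that reduces a batch of attention maps to one value per utterance. The abstract signature takes no arguments, so we assume a subclass may accept its own inputs; the attn argument, the AttentionEntropy class, and the choice of mean attention entropy as the statistic are hypothetical illustrations, not ESPnet's method. A local stand-in replaces the real base class.

```python
from abc import ABC, abstractmethod

import torch


class AbsS2STAuxAttention(torch.nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented interface."""

    @abstractmethod
    def forward(self) -> torch.Tensor:
        raise NotImplementedError


class AttentionEntropy(AbsS2STAuxAttention):
    """Toy subclass: mean attention entropy per utterance (illustrative only)."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # avoids log(0) for zero attention weights

    @property
    def name(self) -> str:
        return "attn_entropy"

    def forward(self, attn: torch.Tensor) -> torch.Tensor:
        # attn: (batch, tgt_len, src_len); each row is a distribution over src_len
        entropy = -(attn * torch.log(attn + self.eps)).sum(dim=-1)  # (batch, tgt_len)
        return entropy.mean(dim=-1)  # (batch,)


batch, tgt_len, src_len = 2, 3, 4
uniform = torch.full((batch, tgt_len, src_len), 1.0 / src_len)
out = AttentionEntropy()(uniform)
print(out.shape)  # torch.Size([2])
```

For uniform attention over 4 source frames, each entry is close to log(4), about 1.386; the reduction over tgt_len is what collapses the output to one value per batch element.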
####### Examples
Example usage in a subclass:
import torch
from espnet2.s2st.aux_attention.abs_aux_attention import AbsS2STAuxAttention

class MyAttention(AbsS2STAuxAttention):
    def forward(self) -> torch.Tensor:
        # Implement the forward logic here
        return torch.tensor([1.0, 2.0, 3.0])  # Example output

attention = MyAttention()
output = attention()  # nn.Module.__call__ dispatches to forward()
print(output)  # tensor([1., 2., 3.])
property name : str