espnet2.enh.layers.dnn_beamformer.AttentionReference
class espnet2.enh.layers.dnn_beamformer.AttentionReference(bidim, att_dim, eps=1e-06)
Bases: Module
Attention-based reference for estimating spatial filters.
This module uses an attention mechanism to generate a reference vector from the power spectral density (PSD) matrices of the input. It computes per-channel attention weights from the PSD features and outputs a softmax-normalized reference vector, which beamformers such as MVDR use to softly select the reference channel.
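Conceptually, the forward computation can be sketched as follows. This is a hedched, minimal paraphrase of the attention mechanism, not the exact ESPnet source; the function name attention_reference_sketch and its arguments are illustrative:

import torch
import torch.nn.functional as F

def attention_reference_sketch(psd_in, mlp_psd, gvec, scaling=2.0):
    # psd_in: complex PSD matrices of shape (B, F, C, C)
    # mlp_psd: torch.nn.Linear(F, att_dim); gvec: torch.nn.Linear(att_dim, 1)
    C = psd_in.size(-1)
    # Zero the diagonal, then average the off-diagonal entries: (B, F, C, C) -> (B, C, F)
    eye = torch.eye(C, dtype=torch.bool, device=psd_in.device)
    psd = psd_in.masked_fill(eye, 0)
    psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)
    # Magnitude features from the complex PSD: (B, C, F)
    psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5
    # Linear -> tanh -> linear gives one attention logit per channel: (B, C)
    e = gvec(torch.tanh(mlp_psd(psd_feat))).squeeze(-1)
    # Softmax over channels yields the normalized reference vector u: (B, C)
    return F.softmax(scaling * e, dim=-1)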
mlp_psd
Linear layer for mapping input PSD to attention features.
- Type: torch.nn.Linear
gvec
Linear layer for generating the attention weights.
- Type: torch.nn.Linear
eps
Small constant to avoid division by zero in calculations.
- Type: float
Parameters:
- bidim (int) – Input feature dimension of the PSD, i.e. the number of frequency bins F.
- att_dim (int) – Dimension of the attention features.
- eps (float, optional) – Small constant added for numerical stability (default: 1e-6).
Returns: A tuple containing the normalized reference vector of shape (B, C) and the input lengths.
Return type: Tuple[torch.Tensor, torch.LongTensor]
Examples
>>> attention_ref = AttentionReference(bidim=256, att_dim=128)
>>> psd_input = torch.randn(10, 256, 4, 4, dtype=torch.complex64)  # (B, F, C, C)
>>> ilens = torch.tensor([100] * 10)  # Example input lengths
>>> u, ilens_out = attention_ref(psd_input, ilens)
>>> print(u.shape)  # Should output torch.Size([10, 4])
NOTE
The input psd_in should have a shape of (B, F, C, C), where B is the batch size, F is the number of frequency bins, and C is the number of channels.
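The PSD matrices themselves are typically formed from a multi-channel STFT as time-averaged outer products over channels. A minimal sketch of such a construction, assuming an STFT of shape (B, T, C, F); the function name naive_psd is illustrative and not part of ESPnet's API:

import torch

def naive_psd(stft):
    # stft: complex STFT of shape (B, T, C, F)
    x = stft.permute(0, 3, 1, 2)  # (B, F, T, C)
    # Outer product over channels, averaged over time: (B, F, C, C)
    psd = torch.einsum("bftc,bftd->bfcd", x, x.conj())
    return psd / x.size(2)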
forward(psd_in: Tensor | ComplexTensor, ilens: LongTensor, scaling: float = 2.0) → Tuple[Tensor, LongTensor]
AttentionReference forward function.
This method performs the forward pass of the attention-based reference estimator: it derives per-channel magnitude features from the input PSD matrices, scores each channel with a small MLP, and returns a softmax-normalized reference vector over the channels.
Notation: B: Batch, C: Channel, F: Frequency.
- Parameters:
- psd_in (torch.complex64/ComplexTensor) – Power spectral density matrices of shape (B, F, C, C).
- ilens (torch.LongTensor) – Input lengths of shape (B,).
- scaling (float, optional) – Scaling factor applied to the attention logits before the softmax (default: 2.0).
- Returns: Tuple[torch.Tensor, torch.LongTensor]: A tuple containing:
- u (torch.Tensor): Normalized reference vector of shape (B, C).
- ilens (torch.LongTensor): Input lengths of shape (B,), passed through unchanged.
Examples
>>> attention_ref = AttentionReference(bidim=256, att_dim=128)
>>> psd_in = torch.randn(2, 256, 4, 4, dtype=torch.complex64)
>>> ilens = torch.tensor([100, 90])
>>> u, ilens = attention_ref(psd_in, ilens, scaling=2.0)
>>> u.shape
torch.Size([2, 4])
NOTE
The input psd_in can be either a native complex PyTorch tensor (torch.complex64) or a ComplexTensor from the torch_complex package; the attention weights are computed from the magnitudes of the PSD entries.
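For instance, both of the following would be valid inputs, assuming the torch_complex package is installed for the ComplexTensor case:

>>> import torch
>>> from torch_complex.tensor import ComplexTensor
>>> psd_native = torch.randn(2, 256, 4, 4, dtype=torch.complex64)
>>> psd_ct = ComplexTensor(torch.randn(2, 256, 4, 4), torch.randn(2, 256, 4, 4))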
- Raises:
- AssertionError – If the last two dimensions of psd_in differ, i.e. the PSD matrices are not square over the channel dimension.