espnet2.enh.layers.fasnet.BF_module
class espnet2.enh.layers.fasnet.BF_module(input_dim, feature_dim, hidden_dim, output_dim, num_spk=2, layer=4, segment_size=100, bidirectional=True, dropout=0.0, fasnet_type='ifasnet')
Bases: Module
Beamforming module for FaSNet.
This module implements the beamforming filter estimation using a Dual-Path Recurrent Neural Network (DPRNN) as described in: Y. Luo, et al. “FaSNet: Low-Latency Adaptive Beamforming for Multi-Microphone Audio Processing”. The implementation is based on the repository: https://github.com/yluo42/TAC and is licensed under CC BY-NC-SA 3.0 US.
input_dim
The dimension of the input features.
- Type: int
feature_dim
The dimension of the feature representation.
- Type: int
hidden_dim
The dimension of the hidden layers.
- Type: int
output_dim
The dimension of the output features.
- Type: int
num_spk
The number of speakers (default is 2).
- Type: int
layer
The number of layers in the DPRNN (default is 4).
- Type: int
segment_size
The size of the segments to process (default is 100).
- Type: int
bidirectional
Whether to use a bidirectional RNN (default is True).
- Type: bool
dropout
The dropout rate (default is 0.0).
- Type: float
fasnet_type
Type of FaSNet to use (‘fasnet’ or ‘ifasnet’).
- Type: str
Parameters:
- input_dim (int) – Dimension of the input features.
- feature_dim (int) – Dimension of the feature representation.
- hidden_dim (int) – Dimension of the hidden layers.
- output_dim (int) – Dimension of the output features.
- num_spk (int, optional) – Number of speakers. Defaults to 2.
- layer (int, optional) – Number of layers in the DPRNN. Defaults to 4.
- segment_size (int, optional) – Size of the segments to process. Defaults to 100.
- bidirectional (bool, optional) – Whether to use a bidirectional RNN. Defaults to True.
- dropout (float, optional) – The dropout rate. Defaults to 0.0.
- fasnet_type (str, optional) – Type of FaSNet to use (‘fasnet’ or ‘ifasnet’). Defaults to ‘ifasnet’.
Returns: The estimated beamforming filter of shape (B, ch, nspk, K, L) for ‘ifasnet’ and (B, ch, nspk, L, K) for ‘fasnet’, where B is the batch size, ch is the number of channels, nspk is the number of speakers, K is the output dimension, and L is the segment length.
Return type: torch.Tensor
Raises: AssertionError – If fasnet_type is not ‘fasnet’ or ‘ifasnet’.
####### Examples
>>> bf = BF_module(input_dim=64, feature_dim=64, hidden_dim=128,
... output_dim=64, num_spk=2, layer=4,
... segment_size=100, bidirectional=True,
... dropout=0.0, fasnet_type='fasnet')
>>> input_tensor = torch.rand(2, 4, 64, 320) # (B, ch, N, T)
>>> num_mic = torch.tensor([3, 2]) # number of microphones
>>> output = bf(input_tensor, num_mic)
>>> print(output.shape) # Output shape will depend on fasnet_type
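The same call works for the default ‘ifasnet’ variant; only the layout of the returned filters changes (see Returns above). A minimal sketch with the imports made explicit (parameter values are illustrative assumptions, not prescribed settings):
>>> import torch
>>> from espnet2.enh.layers.fasnet import BF_module
>>> bf = BF_module(input_dim=64, feature_dim=64, hidden_dim=128,
...                output_dim=64, num_spk=2, fasnet_type='ifasnet')
>>> input_tensor = torch.rand(2, 4, 64, 320)  # (B, ch, N, T), N == input_dim
>>> num_mic = torch.tensor([3, 2])            # number of active microphones
>>> output = bf(input_tensor, num_mic)        # expected: (B, ch, nspk, K, L)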
NOTE
The module requires the dprnn layer from espnet2.enh.layers.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(input, num_mic)
Forward pass for the beamforming filter estimation.
This method processes the input tensor, which contains features from multiple microphones, and estimates the beamforming filters for each speaker based on the number of active microphones. The input is reshaped, passed through the DPRNN, and projected to the output filter dimension.
- Parameters:
- input (torch.Tensor) – Input tensor of shape (B, ch, N, T) where: B is the batch size, ch is the number of channels (microphones), N is the number of features, T is the sequence length.
- num_mic (torch.Tensor) – Tensor of shape (B,) indicating the number of channels for each input. A value of zero indicates a fixed geometry configuration.
- Returns: Output beamforming filters of shape (B, ch, nspk, K, L) for ‘ifasnet’ or (B, ch, nspk, L, K) for ‘fasnet’, where nspk is the number of speakers, K is the output dimension, and L is the segment length.
- Return type: torch.Tensor
####### Examples
>>> model = BF_module(input_dim=64, feature_dim=64, hidden_dim=128,
... output_dim=32, num_spk=2)
>>> input_tensor = torch.randn(2, 4, 64, 320) # (batch, ch, N, T)
>>> num_mic = torch.tensor([3, 2]) # Number of active microphones
>>> output_filters = model.forward(input_tensor, num_mic)
>>> print(output_filters.shape) # Shape depends on fasnet_type and num_spk
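As noted in the Parameters above, an all-zero num_mic tensor selects the fixed-geometry configuration. A minimal sketch of that case (otherwise identical to the example above; parameter values are illustrative assumptions):
>>> import torch
>>> from espnet2.enh.layers.fasnet import BF_module
>>> model = BF_module(input_dim=64, feature_dim=64, hidden_dim=128,
...                   output_dim=32, num_spk=2)
>>> input_tensor = torch.randn(2, 4, 64, 320)   # (batch, ch, N, T)
>>> num_mic = torch.zeros(2)                    # all zeros -> fixed geometry
>>> output_filters = model(input_tensor, num_mic)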
NOTE
The function assumes that the input tensor has been properly formatted and that the model has been initialized with valid parameters.
- Raises: AssertionError – If the shape of the input tensor does not match the expected dimensions or if the num_mic tensor is not properly defined.