espnet2.enh.layers.fasnet.FaSNet_TAC
class espnet2.enh.layers.fasnet.FaSNet_TAC(*args, **kwargs)
Bases: FaSNet_base
Single-stage FaSNet with transform-average-concatenate (TAC).
This class implements the FaSNet model described in “FaSNet: Low-Latency Adaptive Beamforming for Multi-Microphone Audio Processing” by Y. Luo et al. It uses the DPRNN (dual-path recurrent neural network) architecture, combined with TAC, to estimate the beamforming filters.
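TAC (transform-average-concatenate) updates per-channel features in three steps: a shared layer transforms every channel, the transformed features are averaged across channels, and the average is concatenated back onto each channel before a final projection. Because the average works for any number of channels in any order, the same weights can serve arrays of different sizes. The pure-Python sketch below illustrates one TAC step on a single frame; the toy weights, the residual connection, and the names `linear_relu` and `tac_step` are hypothetical and do not reproduce the batched tensor code inside ESPnet's `BF_module`.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def linear_relu(weight: Matrix, x: Vector) -> Vector:
    """Bias-free linear layer followed by ReLU (toy stand-in for a learned layer)."""
    return [max(0.0, sum(w * v for w, v in zip(row, x))) for row in weight]


def tac_step(channels: List[Vector], transform: Matrix,
             average: Matrix, concat: Matrix) -> List[Vector]:
    """One transform-average-concatenate step over any number of channels.

    channels: one feature vector (size D) per valid microphone.
    transform: H x D, average: H x H, concat: D x 2H (hypothetical shapes).
    """
    # 1) Transform: the same layer is applied to every channel.
    transformed = [linear_relu(transform, ch) for ch in channels]
    # 2) Average across channels; the mean does not depend on channel order
    #    and is defined for any channel count.
    n = len(transformed)
    mean = [sum(values) / n for values in zip(*transformed)]
    pooled = linear_relu(average, mean)
    # 3) Concatenate each channel's features with the pooled vector, project
    #    back to size D, and add a residual connection.
    outputs = []
    for ch, t in zip(channels, transformed):
        projected = linear_relu(concat, t + pooled)
        outputs.append([c + p for c, p in zip(ch, projected)])
    return outputs


# Toy example: feature size D = 2, hidden size H = 2.
T = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 0.0], [0.0, 1.0]]
C = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
three_mics = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
out = tac_step(three_mics, T, A, C)
print(out)  # [[2.0, 0.5], [0.5, 2.0], [3.5, 3.5]]
```

Reversing the channel list reverses the outputs without changing their values; this permutation equivariance is what lets one set of weights handle microphones in any order.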
Attributes:
context
Context length (on each side of an analysis window) used when segmenting the input.
- Type: int
filter_dim
Length (number of taps) of each estimated beamforming filter.
- Type: int
all_BF
Beamforming module (DPRNN with TAC) that estimates the filters.
- Type: BF_module
encoder
Convolutional layer that encodes the waveform segments.
- Type: nn.Conv1d
enc_LN
Group normalization layer applied to the encoder output.
- Type: nn.GroupNorm
Parameters:
- enc_dim (int) – Number of output channels of the waveform encoder.
- feature_dim (int) – Feature dimension used inside the beamforming module.
- hidden_dim (int) – Hidden dimension of the RNN layers.
- layer (int) – Number of DPRNN blocks.
- segment_size (int, optional) – Segment size used by the DPRNN to chunk the frame sequence (default=24).
- nspk (int, optional) – Number of speakers (default=2).
- win_len (int, optional) – Length of the analysis window (default=16).
- context_len (int, optional) – Length of the context on each side of a window (default=16).
- dropout (float, optional) – Dropout rate (default=0.0).
- sr (int, optional) – Sampling rate in Hz (default=16000).
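The parameter list does not state the units of `win_len` and `context_len`. In the reference FaSNet code they are durations in milliseconds that are converted to sample counts with `sr`, the hop is half a window, and each filter has 2 × context + 1 taps. The sketch below assumes the same conventions hold for this class and shows how the frame sizes would follow from the constructor arguments; `fasnet_frame_sizes` and `FrameSizes` are hypothetical names, not ESPnet APIs.

```python
from typing import NamedTuple


class FrameSizes(NamedTuple):
    window: int      # samples per analysis window
    stride: int      # hop size in samples (assumed 50% overlap)
    context: int     # context samples on each side of a window
    filter_dim: int  # taps of each estimated beamforming filter


def fasnet_frame_sizes(sr: int, win_len: float, context_len: float) -> FrameSizes:
    """Hypothetical helper: convert millisecond settings to sample counts.

    Assumes win_len and context_len are in milliseconds, as in the
    reference FaSNet code; this is not an ESPnet function.
    """
    window = int(sr * win_len / 1000)
    context = int(sr * context_len / 1000)
    return FrameSizes(
        window=window,
        stride=window // 2,
        context=context,
        filter_dim=2 * context + 1,
    )


# Settings from the class example: 4 ms windows, 16 ms context at 16 kHz.
print(fasnet_frame_sizes(16000, 4, 16))
# FrameSizes(window=64, stride=32, context=256, filter_dim=513)
```

Under this assumption, the documented defaults (`win_len=16`, `context_len=16`) would give 256-sample windows and 256 samples of context at 16 kHz.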
Examples:
>>> import torch
>>> from espnet2.enh.layers.fasnet import FaSNet_TAC
>>> model_TAC = FaSNet_TAC(
...     enc_dim=64,
...     feature_dim=64,
...     hidden_dim=128,
...     layer=4,
...     segment_size=50,
...     nspk=2,
...     win_len=4,
...     context_len=16,
...     sr=16000,
... )
>>> x = torch.rand(2, 4, 32000)  # (batch, num_mic, length)
>>> num_mic = torch.tensor([3, 2])  # valid microphones per example
>>> output = model_TAC(x, num_mic.long())
>>> print(output.shape)  # (batch, nspk, length)
NOTE
This model expects input of shape (batch, num_mic, length), where num_mic is the number of microphone channels and length is the number of samples per channel.
- Raises:
- AssertionError – If the dimensions of the input data do not match the expected dimensions.
forward(input, num_mic)
Forward pass of the FaSNet_TAC model.
This method passes the input through the FaSNet architecture and returns one beamformed signal per speaker. The input is arranged as (batch size, number of channels, sequence length), and num_mic gives the number of microphones actually used in each example.
- Parameters:
- input (torch.Tensor) – Input tensor of shape (batch, max_num_ch, T), where batch is the batch size, max_num_ch is the maximum number of channels (microphones), and T is the length of the input sequence.
- num_mic (torch.Tensor) – A tensor of shape (batch,) indicating the number of channels for each input. A value of zero indicates a fixed geometry configuration.
- Returns: The beamformed output signal of shape (B, nspk, T), where B is the batch size, nspk is the number of speakers, and T is the length of the output signal.
- Return type: torch.Tensor
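The description of num_mic says that zero marks a fixed-geometry configuration but not how the counts are applied. In the reference FaSNet code, every channel yields its own beamformed estimate; for an ad-hoc array only the first num_mic[b] channels of example b are averaged, and when every entry is zero all channels are averaged. The sketch below mirrors that final pooling step on plain Python lists, assuming this class behaves the same way; `pool_channels` is a hypothetical name.

```python
from typing import List

# Per-channel estimates for one example: [channel][speaker][sample].
Estimates = List[List[List[float]]]


def pool_channels(per_channel: List[Estimates], num_mic: List[int]) -> List[List[List[float]]]:
    """Hypothetical sketch of the final channel pooling.

    per_channel: shape (B, max_num_ch, nspk, T) as nested lists.
    num_mic: valid channel count per example; all zeros means fixed geometry.
    Returns the averaged estimates with shape (B, nspk, T).
    """
    fixed_geometry = max(num_mic) == 0
    pooled = []
    for estimates, n in zip(per_channel, num_mic):
        # Fixed geometry: use every channel. Ad-hoc array: first n channels only.
        valid = estimates if fixed_geometry else estimates[:n]
        pooled.append([
            [sum(samples) / len(valid) for samples in zip(*(ch[spk] for ch in valid))]
            for spk in range(len(valid[0]))
        ])
    return pooled


# Two examples, three (padded) channels, one speaker, two samples each.
batch = [
    [[[1.0, 2.0]], [[3.0, 4.0]], [[5.0, 6.0]]],  # three real channels
    [[[2.0, 2.0]], [[4.0, 0.0]], [[0.0, 0.0]]],  # two real channels + padding
]
print(pool_channels(batch, [3, 2]))  # [[[3.0, 4.0]], [[3.0, 1.0]]]
```

With this scheme, zero-padded channels beyond num_mic[b] never reach the output, so their content does not matter.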
Examples:
>>> import torch
>>> from espnet2.enh.layers.fasnet import FaSNet_TAC
>>> model = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128,
...                    layer=4, segment_size=50, nspk=2,
...                    win_len=4, context_len=16, sr=16000)
>>> input_tensor = torch.rand(2, 4, 32000)  # (batch, num_mic, length)
>>> num_mic = torch.tensor([3, 2])  # valid microphones per example
>>> output = model(input_tensor, num_mic)
>>> print(output.shape)  # expected: (2, 2, length)
NOTE
Prepare the input tensor in the format described above. Each entry of num_mic should equal the number of real microphone channels in the corresponding example and must not exceed the channel dimension of the input tensor.
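One way to meet this requirement when recordings have different channel counts is to zero-pad each recording to the largest count in the batch and store the true count in num_mic. The sketch below does this with plain lists, assuming all channels share the same length T and that padding channels are appended after the real ones (so the real microphones are the first num_mic[b] channels). `pack_multichannel` is a hypothetical helper, not part of ESPnet; its outputs could then be converted with `torch.tensor` to form `input` and `num_mic`.

```python
from typing import List, Tuple

Recording = List[List[float]]  # [channel][sample]


def pack_multichannel(recordings: List[Recording]) -> Tuple[List[Recording], List[int]]:
    """Zero-pad recordings to a common channel count (hypothetical helper).

    Returns (batch, num_mic): batch has shape (B, max_num_ch, T) and
    num_mic[b] is the number of real channels in recordings[b].
    """
    lengths = {len(ch) for rec in recordings for ch in rec}
    if len(lengths) != 1:
        raise ValueError("all channels must have the same length T")
    (length,) = lengths
    max_num_ch = max(len(rec) for rec in recordings)
    batch, num_mic = [], []
    for rec in recordings:
        # Real channels first, then all-zero padding channels.
        padding = [[0.0] * length for _ in range(max_num_ch - len(rec))]
        batch.append([list(ch) for ch in rec] + padding)
        num_mic.append(len(rec))
    return batch, num_mic


three_mic = [[0.1, 0.2, 0.3]] * 3
two_mic = [[0.5, 0.5, 0.5]] * 2
batch, num_mic = pack_multichannel([three_mic, two_mic])
print(num_mic)                     # [3, 2]
print(len(batch), len(batch[1]))   # 2 3
print(batch[1][2])                 # [0.0, 0.0, 0.0]
```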