espnet2.enh.layers.fasnet.FaSNet_base
class espnet2.enh.layers.fasnet.FaSNet_base(enc_dim, feature_dim, hidden_dim, layer, segment_size=24, nspk=2, win_len=16, context_len=16, dropout=0.0, sr=16000)
Bases: Module
Base module for FaSNet.
This class serves as the base for the FaSNet architecture, which is designed for low-latency adaptive beamforming in multi-microphone audio processing. It provides methods for signal segmentation and context extraction, as well as cosine similarity calculations between reference and target microphone signals.
win_len
Length of the window in milliseconds for segmentation.
- Type: int
window
Size of the window in samples.
- Type: int
stride
Stride size for segmentation.
- Type: int
sr
Sampling rate in Hz.
- Type: int
context_len
Length of the context in milliseconds.
- Type: int
dropout
Dropout rate for regularization.
- Type: float
enc_dim
Dimensionality of the encoder input.
- Type: int
feature_dim
Dimensionality of the features.
- Type: int
hidden_dim
Dimensionality of the hidden layers.
- Type: int
segment_size
Size of the segments for processing.
- Type: int
layer
Number of layers in the model.
- Type: int
num_spk
Number of speakers to be processed.
- Type: int
eps
Small constant to avoid division by zero.
- Type: float
Parameters:
- enc_dim (int) – Dimensionality of the encoder input.
- feature_dim (int) – Dimensionality of the features.
- hidden_dim (int) – Dimensionality of the hidden layers.
- layer (int) – Number of layers in the model.
- segment_size (int , optional) – Size of the segments for processing. Default is 24.
- nspk (int , optional) – Number of speakers to be processed. Default is 2.
- win_len (int , optional) – Length of the window in milliseconds. Default is 16.
- context_len (int , optional) – Length of the context in milliseconds. Default is 16.
- dropout (float , optional) – Dropout rate for regularization. Default is 0.0.
- sr (int , optional) – Sampling rate in Hz. Default is 16000.
pad_input(input, window)
Zero-padding input according to window/stride size.
seg_signal_context(x, window, context)
Segment the signal into chunks with specific context.
signal_context(x, context)
Gather surrounding context frames for each frame of the input signal.
seq_cos_sim(ref, target)
Compute cosine similarity between reference and target microphone signals.
forward(input, num_mic)
Abstract forward function to be implemented in derived classes.
############### Examples
Creating an instance of the FaSNet_base class:
>>> fasnet = FaSNet_base(enc_dim=64, feature_dim=64, hidden_dim=128,
...                      layer=4, segment_size=50, nspk=2,
...                      win_len=4, context_len=16, sr=16000)
forward(input, num_mic)
Abstract forward function for FaSNet base model.
This method defines the forward pass for the FaSNet base model. It takes the input audio signal and the number of microphones as arguments. The expected shape of the input is (batch, max_num_ch, T), where ‘batch’ is the batch size, ‘max_num_ch’ is the maximum number of channels, and ‘T’ is the length of the audio signal. The ‘num_mic’ parameter is a tensor of shape (batch,) that indicates the number of channels for each input, where zero denotes a fixed geometry configuration.
- Parameters:
- input (torch.Tensor) – Input tensor of shape (batch, max_num_ch, T).
- num_mic (torch.Tensor) – Tensor of shape (batch,) indicating the number of channels for each input. Zero indicates fixed geometry.
- Returns: Output tensor from the forward pass, the shape of which depends on the specific implementation of the derived class.
- Return type: torch.Tensor
############### Examples
>>> model = FaSNet_TAC(enc_dim=64, feature_dim=64, hidden_dim=128,
... layer=4, segment_size=50, nspk=2,
... win_len=4, context_len=16, sr=16000)
>>> input_data = torch.rand(2, 4, 32000) # (batch, num_mic, length)
>>> num_mic = torch.tensor([3, 2]) # Number of active microphones
>>> output = model(input_data, num_mic)
>>> print(output.shape) # Shape depends on the implementation
pad_input(input, window)
Zero-padding input according to window/stride size.
This method pads the input tensor such that its length matches the required window size for processing. It adds padding at the beginning and end of the input signal as needed to ensure compatibility with the window and stride parameters.
stride
The stride size calculated from the window size.
- Type: int
Parameters:
- input (torch.Tensor) – The input tensor of shape (B, nmic, nsample), where B is the batch size, nmic is the number of microphones, and nsample is the number of samples.
- window (int) – The window size used for padding.
Returns: A tuple containing the padded input tensor and the number of samples added as padding at the end.
Return type: Tuple[torch.Tensor, int]
############### Examples
>>> input_tensor = torch.randn(2, 4, 320)  # (batch, num_mic, length)
>>> padded_input, padding_rest = model.pad_input(input_tensor, window=64)
>>> print(padded_input.shape)  # output length depends on the padding
>>> print(padding_rest)  # number of samples added as padding at the end
######## NOTE The padding is performed using zero values, which may impact signal processing tasks if not handled appropriately downstream.
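For concreteness, here is a minimal sketch of how such window/stride padding can be computed. It assumes the common 50%-overlap convention (stride = window // 2) used elsewhere in this class; `pad_input_sketch` is a hypothetical name, and the exact arithmetic inside the real method may differ.

```python
import torch

def pad_input_sketch(input, window):
    # Hypothetical re-implementation for illustration only.
    # input: (B, nmic, nsample); window: frame size in samples.
    B, nmic, nsample = input.shape
    stride = window // 2  # 50% overlap between frames
    # Extra samples needed at the end so frames tile the signal exactly.
    rest = window - (stride + nsample % window) % window
    if rest > 0:
        pad = torch.zeros(B, nmic, rest, dtype=input.dtype, device=input.device)
        input = torch.cat([input, pad], dim=2)
    # Half-window zeros at both ends so edge samples get full frame coverage.
    pad_aux = torch.zeros(B, nmic, stride, dtype=input.dtype, device=input.device)
    input = torch.cat([pad_aux, input, pad_aux], dim=2)
    return input, rest
```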
seg_signal_context(x, window, context)
Segmenting the signal into chunks with specific context.
This method segments the input signal x into overlapping chunks of a specified window size, while also incorporating additional context frames before and after each chunk. The context allows for better handling of signal dependencies in subsequent processing steps.
- Parameters:
- x (torch.Tensor) – Input signal of shape (B, ch, T), where B is the batch size, ch is the number of channels, and T is the length of the signal.
- window (int) – The size of each segment/chunk.
- context (int) – The number of context frames to include before and after each chunk.
- Returns: A tuple containing:
  - center_frame (torch.Tensor): The center frames of the chunks, of shape (B, ch, L, window), where L is the number of chunks.
  - chunks (torch.Tensor): The complete set of chunks, of shape (B, ch, L, 2 * context + window), including the context frames.
  - rest (int): The number of zero samples appended to the end of the signal so that it can be chunked evenly.
- Return type: tuple
############### Examples
>>> x = torch.rand(2, 4, 320) # Batch of 2, 4 channels, 320 samples
>>> window = 16
>>> context = 4
>>> center_frame, chunks, rest = model.seg_signal_context(x, window, context)
>>> center_frame.shape  # (B, ch, L, window); L depends on the padding
######## NOTE The input signal is padded to ensure that it can be segmented correctly based on the specified window size and context.
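As an illustration of the chunking described above, the sketch below builds the context-augmented chunks with `torch.Tensor.unfold`, assuming a 50% chunk overlap and symmetric zero padding for the context; the real method additionally pads the raw signal via `pad_input` and returns the residual padding `rest`, so its chunk count may differ.

```python
import torch

def seg_signal_context_sketch(x, window, context):
    # Hypothetical sketch: x is (B, ch, T), assumed already padded so
    # that (T - window) is divisible by the stride.
    B, ch, T = x.shape
    stride = window // 2  # 50% overlap between consecutive chunks
    # Zero-pad the ends so every center frame has full context.
    pad = torch.zeros(B, ch, context, dtype=x.dtype, device=x.device)
    x_pad = torch.cat([pad, x, pad], dim=2)  # (B, ch, T + 2*context)
    # Chunks including context: length 2*context + window, hop = stride.
    chunks = x_pad.unfold(2, 2 * context + window, stride)
    # Center frames are the chunks with the context stripped off.
    center_frame = chunks[..., context:context + window]
    return center_frame, chunks  # (B, ch, L, window), (B, ch, L, 2c+w)
```

With `x = torch.rand(2, 4, 320)`, `window=16`, and `context=4`, this sketch returns `center_frame` of shape (2, 4, 39, 16) and `chunks` of shape (2, 4, 39, 24).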
seq_cos_sim(ref, target)
Compute the cosine similarity between the reference microphone signals and the target microphone signals.
This function takes two input tensors representing signals from different microphones and calculates the cosine similarity across their segments. It ensures that the input tensors have compatible dimensions for the calculation.
- Parameters:
- ref (torch.Tensor) – A tensor of shape (nmic1, L, seg1) representing the reference microphone signals.
- target (torch.Tensor) – A tensor of shape (nmic2, L, seg2) representing the target microphone signals.
- Returns: A tensor of shape (larger_ch, L, seg1 - seg2 + 1) containing the cosine similarity values between the reference and target microphones, where larger_ch is the larger of nmic1 and nmic2.
- Return type: torch.Tensor
- Raises: AssertionError – If the chunk counts (L) of the reference and target tensors do not match, or if the reference segment length seg1 is smaller than the target segment length seg2.
############### Examples
>>> ref = torch.rand(3, 100, 50)  # 3 mics, 100 chunks, segment length 50
>>> target = torch.rand(2, 100, 30)  # 2 mics, 100 chunks, segment length 30
>>> cos_sim = model.seq_cos_sim(ref, target)
>>> print(cos_sim.shape)  # Output: torch.Size([3, 100, 21])
######## NOTE This function uses the PyTorch library for tensor operations and requires that the input tensors be of type torch.Tensor.
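Conceptually, the output length seg1 - seg2 + 1 comes from sliding the target segment across the longer reference segment one sample at a time and scoring each alignment. Below is a minimal sketch of that computation, assuming equal channel counts for the two inputs (the real method additionally broadcasts a mismatched channel dimension); `seq_cos_sim_sketch` is a hypothetical name.

```python
import torch

def seq_cos_sim_sketch(ref, target, eps=1e-8):
    # Hypothetical sketch; ref: (nmic, L, seg1), target: (nmic, L, seg2).
    assert ref.size(1) == target.size(1), "chunk counts must match"
    assert ref.size(2) >= target.size(2), "reference segment must be longer"
    seg2 = target.size(2)
    # All length-seg2 sliding windows of ref: (nmic, L, seg1-seg2+1, seg2).
    windows = ref.unfold(2, seg2, 1)
    # Cosine similarity between every window and the target segment.
    num = (windows * target.unsqueeze(2)).sum(dim=-1)
    denom = windows.norm(dim=-1) * target.norm(dim=-1, keepdim=True) + eps
    return num / denom  # (nmic, L, seg1 - seg2 + 1)
```

For `ref` of shape (3, 100, 50) and `target` of shape (3, 100, 30), this returns a (3, 100, 21) tensor, matching the documented output shape.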
signal_context(x, context)
Signal context function that segments the signal into chunks with context.

For every frame of the input, this method gathers `context` frames from the past and `context` frames from the future, so that each frame can be processed together with its surrounding temporal context.
- Parameters:
  - x (torch.Tensor) – Input signal of shape (B, dim, nframe).
  - context (int) – Number of context frames to include on each side of every frame.
- Returns: Context-expanded signal of shape (B, dim, nframe, 2 * context + 1).
- Return type: torch.Tensor
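A minimal sketch of such frame-level context gathering follows, assuming the shapes documented above; `signal_context_sketch` is a hypothetical name, and the actual method may construct the context differently (e.g., by explicit shifting and concatenation).

```python
import torch

def signal_context_sketch(x, context):
    # Hypothetical sketch; x: (B, dim, nframe), context: int.
    B, dim, nframe = x.shape
    # Zero-pad the frame axis so edge frames still get full context.
    pad = torch.zeros(B, dim, context, dtype=x.dtype, device=x.device)
    x_pad = torch.cat([pad, x, pad], dim=2)  # (B, dim, nframe + 2*context)
    # One length-(2*context + 1) sliding window per original frame.
    return x_pad.unfold(2, 2 * context + 1, 1)  # (B, dim, nframe, 2*context+1)
```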
############### Examples
>>> # Create a FaSNet_base model
>>> model = FaSNet_base(enc_dim=64, feature_dim=64, hidden_dim=128, layer=4)
>>> # Pad an input signal
>>> padded_input, rest = model.pad_input(torch.randn(2, 4, 32000), window=512)
>>> # Segment a signal into chunks with context
>>> center_frame, chunks, rest = model.seg_signal_context(
...     torch.randn(2, 4, 32000), window=512, context=16)
>>> # Compute cosine similarity between microphone signals
>>> cos_sim = model.seq_cos_sim(torch.randn(3, 512, 100), torch.randn(2, 512, 100))
######## NOTE The model is designed for processing audio signals and may require specific configurations based on the application.