espnet2.enh.separator.tfgridnetv3_separator.GridNetV3Block
class espnet2.enh.separator.tfgridnetv3_separator.GridNetV3Block(emb_dim, emb_ks, emb_hs, hidden_channels, n_head=4, qk_output_channel=4, activation='prelu', eps=1e-05)
Bases: Module
GridNetV3 Block for processing audio features.
This class implements one block of the GridNetV3 architecture for audio signal processing. It combines intra and inter recurrent neural network (RNN) modules with a multi-head attention mechanism for feature extraction.
emb_dim
The embedding dimension.
- Type: int
emb_ks
Kernel size for embedding.
- Type: int
emb_hs
Hop size for embedding.
- Type: int
n_head
Number of heads in the attention mechanism.
- Type: int
Parameters:
- emb_dim (int) – The embedding dimension.
- emb_ks (int) – Kernel size for embedding.
- emb_hs (int) – Hop size for embedding.
- hidden_channels (int) – Number of hidden channels in LSTM.
- n_head (int , optional) – Number of heads in the attention mechanism. Defaults to 4.
- qk_output_channel (int , optional) – Output channels of point-wise conv2d for key and query. Defaults to 4.
- activation (str , optional) – Activation function to use. Only “prelu” is accepted. Defaults to “prelu”.
- eps (float , optional) – Small value for numerical stability in normalization layers. Defaults to 1e-5.
Raises: AssertionError – If the activation function is not “prelu”.
####### Examples
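>>> import torch
>>> from espnet2.enh.separator.tfgridnetv3_separator import GridNetV3Block
>>> block = GridNetV3Block(emb_dim=64, emb_ks=3, emb_hs=1,
...                        hidden_channels=128)
>>> # Assumed layout [B, emb_dim, T, Q]: the channel dimension is set to emb_dim (64)
>>> x = torch.randn(32, 64, 100, 50)
>>> output = block(x)
>>> print(output.shape)  # Output shape should match input shape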
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Perform the forward pass of the TFGridNetV3 model.
This method processes the input multi-channel audio tensor and applies the model architecture to separate the sources. It takes a batch of audio signals, applies convolutional layers, and processes them through multiple GridNetV3 blocks before returning the enhanced audio signals.
Parameters:
- input (torch.Tensor) – Batched multi-channel audio tensor with M audio channels and N samples shaped as [B, N, M].
- ilens (torch.Tensor) – Input lengths for each batch element shaped as [B].
- additional (Dict or None) – Additional data, currently unused in this model.
Returns:
- enhanced (List[torch.Tensor]) – A list of length n_srcs containing mono audio tensors shaped as [(B, T), …] with T samples each.
- ilens (torch.Tensor) – Input lengths shaped as (B,).
- additional (OrderedDict) – The additional data returned in the output, currently unused in this model.
Return type: Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]
####### Examples
>>> model = TFGridNetV3(n_srcs=2, n_imics=1)
>>> input_tensor = torch.randn(4, 256)  # 4 utterances, 256 samples each; single-channel mixture to match n_imics=1 (assumed shape [B, N])
>>> ilens = torch.tensor([256, 256, 256, 256]) # lengths for each input
>>> enhanced, lengths, _ = model(input_tensor, ilens)
NOTE
Ensure that the input tensor is normalized as described in the model’s notes for optimal performance. The model works best with variance normalized mixture input and target signals.
Raises: AssertionError – If the input tensor does not have the expected shape or if the input tensor is not a single-channel mixture.
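The note above recommends variance-normalized mixture and target signals but does not show the arithmetic. The following is a minimal sketch, assuming that "variance normalized" means dividing the mixture by its own standard deviation and scaling every target by the same factor. The helper name `variance_normalize` and the `eps` guard are hypothetical and not part of ESPnet. Plain Python lists are used to keep the sketch dependency-free; the same arithmetic carries over to tensors.

```python
import math


def variance_normalize(mixture, targets, eps=1e-8):
    """Scale one utterance so the mixture has (approximately) unit std.

    Hypothetical helper, not an ESPnet API.
    mixture: list of float samples (single-channel).
    targets: list of per-source sample lists, each as long as mixture.
    Returns the scaled mixture, the scaled targets, and the original std.
    """
    n = len(mixture)
    mean = sum(mixture) / n
    # Population standard deviation of the mixture.
    std = math.sqrt(sum((s - mean) ** 2 for s in mixture) / n)
    # eps avoids division by zero for an all-constant (e.g. silent) mixture.
    scale = 1.0 / (std + eps)
    norm_mix = [s * scale for s in mixture]
    # Targets share the mixture's scale so their relative levels are preserved.
    norm_tgts = [[s * scale for s in tgt] for tgt in targets]
    return norm_mix, norm_tgts, std


# Toy utterance: two identical sources that sum to the mixture.
mix = [0.0, 2.0, -2.0, 4.0, -4.0]
tgts = [[0.0, 1.0, -1.0, 2.0, -2.0], [0.0, 1.0, -1.0, 2.0, -2.0]]
norm_mix, norm_tgts, std = variance_normalize(mix, tgts)
```

Returning `std` lets the caller multiply the separated outputs by it afterwards to restore the original signal level.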