espnet2.enh.separator.tfgridnetv3_separator.GridNetV3Block
class espnet2.enh.separator.tfgridnetv3_separator.GridNetV3Block(emb_dim, emb_ks, emb_hs, hidden_channels, n_head=4, qk_output_channel=4, activation='prelu', eps=1e-05)
Bases: Module
GridNetV3 Block for processing audio features.
This class implements one block of the GridNetV3 architecture, which is designed for audio signal processing. It combines intra and inter recurrent neural networks (RNNs) with an attention mechanism for feature extraction.
emb_dim
The embedding dimension.
- Type: int
emb_ks
Kernel size for embedding.
- Type: int
emb_hs
Hop size for embedding.
- Type: int
n_head
Number of heads in the attention mechanism.
- Type: int
Parameters:
- emb_dim (int) – The embedding dimension.
- emb_ks (int) – Kernel size for embedding.
- emb_hs (int) – Hop size for embedding.
- hidden_channels (int) – Number of hidden channels in LSTM.
- n_head (int, optional) – Number of heads in the attention mechanism. Defaults to 4.
- qk_output_channel (int, optional) – Output channels of point-wise conv2d for key and query. Defaults to 4.
- activation (str, optional) – Activation function to use. Defaults to “prelu”.
- eps (float, optional) – Small value for numerical stability in normalization layers. Defaults to 1e-5.
Raises: AssertionError – If the activation function is not “prelu”.
####### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnetv3_separator import GridNetV3Block
>>> block = GridNetV3Block(emb_dim=64, emb_ks=3, emb_hs=1,
...                        hidden_channels=128)
>>> x = torch.randn(32, 64, 100, 50)  # example input; channel dim equals emb_dim
>>> output = block(x)
>>> print(output.shape)  # output shape should match input shape
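The kernel size (emb_ks) and hop size (emb_hs) parameters suggest that the block frames its input into overlapping windows. The following is a minimal sketch of that framing arithmetic only, assuming plain sliding windows with right-padding; the helper names `padded_length` and `num_windows` are hypothetical and are not part of ESPnet, and the padding the block actually applies may differ.

```python
import math


def padded_length(length: int, ks: int, hs: int) -> int:
    """Smallest length >= `length` that windows of size ks and hop hs tile exactly.

    Hypothetical helper: illustrates the arithmetic, not ESPnet's own padding.
    """
    if length <= ks:
        return ks
    return math.ceil((length - ks) / hs) * hs + ks


def num_windows(length: int, ks: int, hs: int) -> int:
    """Number of windows of size ks taken every hs steps after padding."""
    return (padded_length(length, ks, hs) - ks) // hs + 1


# With emb_ks=3 and emb_hs=1, a sequence of 100 steps yields 98 windows.
print(num_windows(100, ks=3, hs=1))   # 98
# With a larger hop, 101 steps are padded to 102 before framing.
print(padded_length(101, ks=4, hs=2))  # 102
```

A smaller hop relative to the kernel size gives more overlap between neighbouring windows, and so more windows for the recurrent layers to process.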
forward(x)
Perform the forward pass of the TFGridNetV3 model.
This method processes the input multi-channel audio tensor and applies the model architecture to separate the sources. It takes a batch of audio signals, applies convolutional layers, and processes them through multiple GridNetV3 blocks before returning the enhanced audio signals.
Parameters:
- input (torch.Tensor) – Batched multi-channel audio tensor with M audio channels and N samples shaped as [B, T, F].
- ilens (torch.Tensor) – Input lengths for each batch element shaped as [B].
- additional (Dict or None) – Additional data, currently unused in this model.
Returns:
- enhanced (List[torch.Tensor]) – A list of length n_srcs containing mono audio tensors, each shaped as (B, T) with T samples.
- ilens (torch.Tensor) – Input lengths shaped as (B,).
- additional (OrderedDict) – The additional data returned in the output, currently unused in this model.
####### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnetv3_separator import TFGridNetV3
>>> model = TFGridNetV3(n_srcs=2, n_imics=1)
>>> input_tensor = torch.randn(4, 256, 2) # 4 batches, 256 time steps, 2 channels
>>> ilens = torch.tensor([256, 256, 256, 256]) # lengths for each input
>>> enhanced, lengths, _ = model(input_tensor, ilens)
NOTE
Ensure that the input tensor is normalized as described in the model’s notes for optimal performance. The model works best with variance normalized mixture input and target signals.
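The normalization mentioned in the note can be sketched as follows, assuming it means dividing the mixture and every target by the standard deviation of the mixture. The function `variance_normalize` is a hypothetical helper written with plain Python lists for illustration; it is not an ESPnet API, and the same arithmetic would apply to tensors.

```python
import statistics


def variance_normalize(mixture, targets):
    """Scale a mixture and its targets by the mixture's standard deviation.

    Hypothetical helper: assumes "variance normalized" means division by
    the population standard deviation of the mixture signal.
    """
    std = statistics.pstdev(mixture)
    if std == 0:
        raise ValueError("mixture is constant; cannot normalize")
    norm_mixture = [s / std for s in mixture]
    norm_targets = [[s / std for s in target] for target in targets]
    # Return std so that model outputs can be rescaled afterwards.
    return norm_mixture, norm_targets, std


# Two toy sources that sum to the mixture.
src1 = [1.0, 0.0, 1.0, -1.0]
src2 = [-0.5, -0.5, 0.5, -0.5]
mix = [a + b for a, b in zip(src1, src2)]

norm_mix, norm_tgts, std = variance_normalize(mix, [src1, src2])
print(round(statistics.pstdev(norm_mix), 6))  # 1.0
```

Because the mixture and the targets share one scale factor, the normalized targets still sum to the normalized mixture, and multiplying separated outputs by `std` restores the original level.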
Raises: AssertionError – If the input tensor does not have the expected shape or if the input tensor is not a single-channel mixture.
