espnet2.enh.separator.tfgridnet_separator.GridNetBlock
class espnet2.enh.separator.tfgridnet_separator.GridNetBlock(emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels, n_head=4, approx_qk_dim=512, activation='prelu', eps=1e-05)
Bases: Module
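A single block of the TFGridNet architecture, operating on time-frequency (TF) embeddings.
This class implements one GridNetBlock, which consists of three sub-modules applied in sequence, each with a residual connection: an intra-frame module (a BLSTM running along frequency within each frame), an inter-frame module (a BLSTM running along time within each frequency bin), and a multi-head self-attention module across frames. Before each BLSTM, the layer-normalized features are unfolded into windows of emb_ks bins or frames with hop emb_hs, and a 1-D transposed convolution projects the BLSTM output back to emb_dim channels.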
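Because each BLSTM sees windows of emb_ks positions taken every emb_hs positions, and a transposed convolution maps the windows back, the sequence lengths have to line up. The plain-integer sketch below shows that bookkeeping. It assumes, from my reading of the ESPnet source rather than anything stated on this page, that T and Q are zero-padded with the formula in `padded_length` before unfolding; all three helper names are hypothetical.

```python
import math


def padded_length(length, emb_ks, emb_hs):
    # Assumed padding rule: round `length` up so that
    # (padded - emb_ks) is a multiple of emb_hs.
    return math.ceil((length - emb_ks) / emb_hs) * emb_hs + emb_ks


def num_windows(length, emb_ks, emb_hs):
    # Windows produced by unfolding with kernel emb_ks and hop emb_hs.
    return (length - emb_ks) // emb_hs + 1


def conv_transpose_length(n_windows, emb_ks, emb_hs):
    # ConvTranspose1d output length with padding=0, output_padding=0, dilation=1.
    return (n_windows - 1) * emb_hs + emb_ks


# Example: 65 frequency bins, emb_ks=4, emb_hs=2.
Q_pad = padded_length(65, 4, 2)   # 66: one zero bin appended
n = num_windows(Q_pad, 4, 2)      # 32 windows fed to the BLSTM
assert conv_transpose_length(n, 4, 2) == Q_pad  # projection restores the padded length
```

With emb_hs=1 (hop of one position) the padding is a no-op, which is why the padded and original sizes coincide in the common configuration.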
emb_dim
The embedding dimension for the input features.
- Type: int
emb_ks
The kernel size used to unfold the features before each BLSTM and in the transposed convolutions that follow.
- Type: int
emb_hs
The hop size (stride) of the unfolding and of the transposed convolutions.
- Type: int
n_head
The number of attention heads in the attention mechanism.
- Type: int
Parameters:
- emb_dim (int) – The number of embedding channels C of the input features.
- emb_ks (int) – The kernel size used to unfold the features before each BLSTM and in the transposed convolutions that follow.
- emb_hs (int) – The hop size (stride) of the unfolding and of the transposed convolutions.
- n_freqs (int) – The number of frequency bins Q.
- hidden_channels (int) – The number of hidden units in each direction of the intra- and inter-frame BLSTMs.
- n_head (int, optional) – The number of heads in the self-attention module. Defaults to 4.
- approx_qk_dim (int, optional) – Approximate per-head dimension of the query and key tensors. Defaults to 512.
- activation (str, optional) – The activation function to use. Defaults to 'prelu'.
- eps (float, optional) – Small value for numerical stability in the layer normalizations. Defaults to 1e-5.
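The constructor arguments determine several internal sizes that are not listed above. The sketch below computes them under assumptions taken from my reading of the ESPnet source, not from this page: each BLSTM input is emb_dim * emb_ks features, each BLSTM outputs 2 * hidden_channels, each head's query/key map has E = ceil(approx_qk_dim / n_freqs) channels (so E * n_freqs, the flattened per-frame size, lands near approx_qk_dim), and each head's value map has emb_dim // n_head channels. `BlockSizes` and `derived_sizes` are hypothetical names, and raising ValueError for a non-divisible emb_dim is this sketch's choice; the real class may handle it differently.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockSizes:
    """Derived layer sizes of one GridNetBlock (illustrative only)."""

    lstm_input: int            # features per unfolded window: emb_dim * emb_ks
    lstm_output: int           # forward + backward hidden states
    qk_channels_per_head: int  # E, channels of each query/key map
    v_channels_per_head: int   # channels of each value map


def derived_sizes(emb_dim, emb_ks, n_freqs, hidden_channels,
                  n_head=4, approx_qk_dim=512):
    # Assumed constraint: value channels are split evenly across heads.
    if emb_dim % n_head != 0:
        raise ValueError("emb_dim must be divisible by n_head")
    return BlockSizes(
        lstm_input=emb_dim * emb_ks,
        lstm_output=2 * hidden_channels,
        # E is chosen so that E * n_freqs stays close to approx_qk_dim,
        # whatever the number of frequency bins is.
        qk_channels_per_head=math.ceil(approx_qk_dim / n_freqs),
        v_channels_per_head=emb_dim // n_head,
    )
```

For example, `derived_sizes(48, 4, 65, 192)` gives E = 8, so the flattened query/key size per frame is 8 * 65 = 520, close to the requested 512.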
####### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnet_separator import GridNetBlock
>>> block = GridNetBlock(emb_dim=64, emb_ks=3, emb_hs=1, n_freqs=128,
...                      hidden_channels=128, n_head=4)
>>> input_tensor = torch.randn(8, 64, 16, 128)  # [B, emb_dim, T, n_freqs]
>>> output_tensor = block(input_tensor)
>>> output_tensor.shape
torch.Size([8, 64, 16, 128])
- Raises: ValueError – If the input tensor does not have 4 dimensions.
forward(x)
Forward pass for the GridNetBlock.
Applies the three sub-modules in sequence, each followed by a residual connection: the intra-frame BLSTM (along frequency), the inter-frame BLSTM (along time), and the multi-head self-attention module (across frames). The input is first zero-padded along T and Q so that unfolding with emb_ks and emb_hs covers every position, and the output is cropped back to the original size.
- Parameters:
- x (torch.Tensor) – Input TF embedding shaped as [B, C, T, Q], where C = emb_dim, T is the number of frames, and Q = n_freqs is the number of frequency bins.
- Returns:
- out (torch.Tensor): The processed embedding, with the same shape [B, C, T, Q] as the input.
- Return type: torch.Tensor
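To make the intra/inter distinction concrete, the following pure-Python sketch computes the [batch, sequence, feature] shapes each BLSTM would receive for a given input shape. It assumes, from my reading of the ESPnet source rather than this page, that T and Q are zero-padded so the unfolding fits exactly and that each window is flattened to C * emb_ks features; `lstm_input_shapes` is a hypothetical helper, not part of ESPnet.

```python
import math


def lstm_input_shapes(x_shape, emb_ks, emb_hs):
    """Return (intra_shape, inter_shape) for an input of shape [B, C, T, Q].

    Assumed layout (not an ESPnet API): pad T and Q, unfold with
    kernel emb_ks and hop emb_hs, flatten each window to C * emb_ks.
    """
    B, C, T, Q = x_shape  # unpacking raises ValueError unless 4-D

    def pad(n):
        return math.ceil((n - emb_ks) / emb_hs) * emb_hs + emb_ks

    def windows(n):
        return (n - emb_ks) // emb_hs + 1

    T_pad, Q_pad = pad(T), pad(Q)
    # Intra-frame: each (batch item, frame) pair is one sequence along frequency.
    intra = (B * T_pad, windows(Q_pad), C * emb_ks)
    # Inter-frame: each (batch item, frequency bin) pair is one sequence along time.
    inter = (B * Q_pad, windows(T_pad), C * emb_ks)
    return intra, inter
```

For a [4, 48, 100, 65] input with emb_ks=4 and emb_hs=1, the intra-frame BLSTM would see 400 sequences of 62 windows, and the inter-frame BLSTM 260 sequences of 97 windows, each window carrying 192 features.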
####### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnet_separator import GridNetBlock
>>> block = GridNetBlock(emb_dim=48, emb_ks=4, emb_hs=1, n_freqs=65,
...                      hidden_channels=192)
>>> x = torch.randn(4, 48, 100, 65)  # [B, C, T, Q]
>>> block(x).shape
torch.Size([4, 48, 100, 65])
NOTE
This note concerns the full TFGridNet separator rather than GridNetBlock itself. The model works best when trained with variance-normalized mixture input and target. Normalize the input by dividing it by torch.std(mixture, (1, 2)), and do the same for the target signals. This is encouraged when not using scale-invariant loss functions such as SI-SDR.
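The normalization above can be sketched for a single training example in plain Python. This is a minimal illustration, assuming the mixture is laid out as N samples by M channels (one batch item of [B, N, M]) and that "the same" for targets means dividing them by the mixture's standard deviation, which keeps their level relative to the mixture unchanged. `normalize_by_mixture_std` is a hypothetical helper; `statistics.stdev` uses the n - 1 estimate, matching the default of torch.std.

```python
import statistics


def normalize_by_mixture_std(mixture, targets):
    """Scale one training example by the mixture's standard deviation.

    mixture: N samples x M channels (nested lists), one item of [B, N, M].
    targets: list of target signals, each a flat list of samples.
    Returns (scaled_mixture, scaled_targets, std) so the scaling can be undone.
    """
    flat = [s for frame in mixture for s in frame]
    # Sample (n - 1) standard deviation over all samples and channels,
    # the same reduction as torch.std(mixture, (1, 2)) for one batch item.
    std = statistics.stdev(flat)
    if std == 0:
        raise ValueError("mixture is constant; cannot normalize")
    scaled_mixture = [[s / std for s in frame] for frame in mixture]
    # Assumption: targets are divided by the *mixture* std, not their own.
    scaled_targets = [[s / std for s in tgt] for tgt in targets]
    return scaled_mixture, scaled_targets, std
```

Returning `std` lets the caller multiply the model's estimates back to the original scale after inference.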