espnet2.enh.separator.tfgridnetv2_separator.GridNetV2Block
class espnet2.enh.separator.tfgridnetv2_separator.GridNetV2Block(emb_dim, emb_ks, emb_hs, n_freqs, hidden_channels, n_head=4, approx_qk_dim=512, activation='prelu', eps=1e-05)
Bases: Module
GridNetV2Block is a neural network block used within the TFGridNetV2 model. It processes input features with an intra-frame recurrent neural network (RNN) across frequency, an inter-frame RNN across time, and a multi-head self-attention mechanism.
emb_dim
The embedding dimension for input features.
- Type: int
emb_ks
Kernel size for convolutions and unfolding.
- Type: int
emb_hs
Hop size for convolutions and deconvolutions.
- Type: int
n_head
Number of heads in the multi-head attention mechanism.
- Type: int
Parameters:
- emb_dim (int) – Dimension of the input embeddings.
- emb_ks (int) – Kernel size for unfolding and deconvolution (illustrated in the sketch after this list).
- emb_hs (int) – Hop size for unfolding and deconvolution.
- n_freqs (int) – Number of frequency bins in the input features.
- hidden_channels (int) – Number of hidden channels in the RNN.
- n_head (int, optional) – Number of attention heads (default is 4).
- approx_qk_dim (int, optional) – Approximate dimension for Q and K tensors in attention (default is 512; see the sketch after this list).
- activation (str, optional) – Activation function to use (default is “prelu”).
- eps (float, optional) – Small value for numerical stability in normalization layers (default is 1e-5).
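To make these sizes concrete, below is a minimal standalone sketch. It is not ESPnet code: the unfold call only illustrates the patch geometry, and the ceil(approx_qk_dim / n_freqs) derivation mirrors the V1 GridNetBlock and is assumed to carry over to V2.

import math
import torch

emb_dim, emb_ks, emb_hs = 48, 4, 1
n_freqs, approx_qk_dim = 65, 512
x = torch.randn(2, emb_dim, 100, n_freqs)  # [B, C, T, Q]

# emb_ks / emb_hs: slide a window of length emb_ks with hop emb_hs along the
# frequency axis; each patch then holds emb_dim * emb_ks values, the input
# size of the block's LSTMs.
patches = x.unfold(dimension=3, size=emb_ks, step=emb_hs)
print(patches.shape)  # torch.Size([2, 48, 100, 62, 4]); 62 = (65 - 4) // 1 + 1

# approx_qk_dim is only approximate: the per-head Q/K channel count E is
# derived as ceil(approx_qk_dim / n_freqs), so the effective Q/K dimension is
# E * n_freqs, close to (but not exactly) approx_qk_dim.
E = math.ceil(approx_qk_dim / n_freqs)
print(E, E * n_freqs)  # 8 520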
####### Examples
Create a GridNetV2Block instance and run a forward pass:

import torch

block = GridNetV2Block(
    emb_dim=48, emb_ks=4, emb_hs=1, n_freqs=65, hidden_channels=192
)
input_tensor = torch.randn(10, 48, 100, 65)  # [B, C, T, Q] example input
output_tensor = block(input_tensor)          # same shape as the input
- Raises: AssertionError – If the specified activation function is not “prelu”.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Forward pass of the GridNetV2Block.
This method passes the input feature tensor through the intra-frame RNN (across frequency), the inter-frame RNN (across time), and the multi-head self-attention module, applying a residual connection around each stage; see the sketch after the Returns entry below.
Parameters:
- x (torch.Tensor) – Input feature tensor of shape [B, C, T, Q], where B is the batch size, C equals emb_dim, T is the number of time frames, and Q is the number of frequency bins (n_freqs).
Returns: Output tensor of shape [B, C, T, Q], the same shape as the input.
Return type: torch.Tensor
(The parent TFGridNetV2.forward, by contrast, returns a list of n_srcs mono audio tensors of shape (B, T), the input lengths ilens of shape (B,), and an OrderedDict of additional data, currently unused.)
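As a rough, self-contained sketch of this residual pattern (identity placeholders stand in for the real normalization, RNN, and attention submodules; this is not the actual ESPnet implementation):

import torch
import torch.nn as nn

class BlockSketch(nn.Module):
    # Hypothetical skeleton of the GridNetV2Block forward pattern.
    def __init__(self):
        super().__init__()
        self.intra = nn.Identity()  # stands in for norm + BLSTM across frequency
        self.inter = nn.Identity()  # stands in for norm + BLSTM across time
        self.attn = nn.Identity()   # stands in for cross-frame self-attention

    def forward(self, x):      # x: [B, C, T, Q]
        x = x + self.intra(x)  # intra-frame (full-band) module, residual
        x = x + self.inter(x)  # inter-frame (temporal) module, residual
        x = x + self.attn(x)   # attention module, residual
        return x               # [B, C, T, Q], shape preserved

x = torch.randn(2, 48, 100, 65)
assert BlockSketch()(x).shape == x.shape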
####### Examples
In practice, the block is applied inside the full TFGridNetV2 separator:
>>> model = TFGridNetV2(n_srcs=2)
>>> input_tensor = torch.randn(8, 16000, 1) # [B, N, M]
>>> ilens = torch.tensor([16000] * 8) # Lengths for each batch
>>> enhanced, lengths, _ = model(input_tensor, ilens)
NOTE
When training TFGridNetV2, it is recommended to normalize the input and target signals using RMS normalization as follows:

std_ = torch.std(mix, (1, 2), keepdim=True)
mix = mix / std_
tgt = tgt / std_
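Because std_ is a per-utterance factor, the same value can restore the original scale of the separated outputs at inference time. A minimal sketch under the shape assumptions above; all variable names here are illustrative:

import torch

mix = torch.randn(4, 16000, 1)               # [B, N, M] input mixture
std_ = torch.std(mix, (1, 2), keepdim=True)  # [B, 1, 1] per-utterance scale
mix = mix / std_                             # normalized model input

# Rescale a separated mono output (shape [B, T]) back to the input level:
enhanced = torch.randn(4, 16000)             # stand-in for one model output
enhanced = enhanced * std_.squeeze(-1)       # [B, T] * [B, 1] broadcasts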
- Raises: ValueError – Raised by TFGridNetV2.forward if the input tensor does not have the expected dimensions, or if n_imics is not equal to 1 when the input has two dimensions.