espnet2.enh.separator.tfgridnetv3_separator.TFGridNetV3
class espnet2.enh.separator.tfgridnetv3_separator.TFGridNetV3(input_dim, n_srcs=2, n_imics=1, n_layers=6, lstm_hidden_units=192, attn_n_head=4, attn_qk_output_channel=4, emb_dim=48, emb_ks=4, emb_hs=1, activation='prelu', eps=1e-05)
Bases: AbsSeparator
TFGridNetV3 is an advanced model for offline time-frequency (TF) audio source
separation, extending the capabilities of TFGridNetV2. It is designed to be sampling-frequency-independent (SFI) by ensuring that all layers are independent of the input’s time and frequency dimensions.
References:
- Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe,
“TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation”, in TASLP, 2023.
- Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, “TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural Speaker Separation”, in ICASSP, 2023.
Notes: This model performs optimally when trained with variance-normalized mixture inputs and targets. For a mixture tensor of shape [batch, samples, microphones], normalize it using:

std_ = std(mixture, (1, 2))
mixture = mixture / std_
target = target / std_
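The normalization above can be sketched in PyTorch as follows. The shapes and the random tensors are illustrative assumptions, not values prescribed by the model:

```python
import torch

# Hypothetical batch: [batch, samples, microphones]
mixture = torch.randn(4, 1600, 2)
target = torch.randn(4, 1600, 2)

# Per-utterance standard deviation over time and channel dims,
# kept as [batch, 1, 1] so it broadcasts over samples and mics.
std_ = torch.std(mixture, dim=(1, 2), keepdim=True)
mixture = mixture / std_
target = target / std_
```

Note that the same `std_` (computed from the mixture) is applied to the targets, so the mixture/target scale relationship is preserved.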
n_srcs
Number of output sources/speakers.
- Type: int
n_layers
Number of TFGridNetV3 blocks.
- Type: int
n_imics
Number of microphone channels (only fixed-array geometry supported).
- Type: int
Parameters:
- input_dim (int) – Placeholder, not used.
- n_srcs (int) – Number of output sources/speakers (default: 2).
- n_fft (int) – STFT window size.
- stride (int) – STFT stride.
- window (str or None) – STFT window type, can be ‘hamming’, ‘hanning’, or None.
- n_imics (int) – Number of microphones channels (default: 1).
- n_layers (int) – Number of TFGridNetV3 blocks (default: 6).
- lstm_hidden_units (int) – Number of hidden units in LSTM (default: 192).
- attn_n_head (int) – Number of heads in self-attention (default: 4).
- attn_qk_output_channel (int) – Output channels for point-wise conv2d for getting key and query (default: 4).
- emb_dim (int) – Embedding dimension (default: 48).
- emb_ks (int) – Kernel size for unfolding and deconv1D (default: 4).
- emb_hs (int) – Hop size for unfolding and deconv1D (default: 1).
- activation (str) – Activation function to use in the model, can be any torch-supported activation (default: ‘prelu’).
- eps (float) – Small epsilon for normalization layers (default: 1.0e-5).
- use_builtin_complex (bool) – Whether to use built-in complex type or not.
####### Examples
Instantiate the model:
>>> model = TFGridNetV3(n_srcs=3, n_layers=4, n_imics=2)

Prepare the input tensor:
>>> input_tensor = torch.randn(8, 512, 2)  # batch of 8, 512 samples, 2 channels
>>> ilens = torch.tensor([512] * 8)  # input lengths for each batch

Forward pass:
>>> enhanced, ilens, additional = model(input_tensor, ilens)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(input: Tensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor], Tensor, OrderedDict]
Forward pass of the TFGridNetV3 model.
This method takes a batched multi-channel audio tensor as input and processes it through the model to produce enhanced audio signals for the specified number of sources (speakers).
Parameters:
- input (torch.Tensor) – Batched multi-channel audio tensor with M audio channels and N samples, of shape [B, N, M].
- ilens (torch.Tensor) – Input lengths of shape [B].
- additional (Dict or None) – Other data, currently unused in this model.
Returns:
- enhanced (List[torch.Tensor]): A list of length n_srcs containing mono audio tensors of shape [(B, T), …], where T is the number of samples.
- ilens (torch.Tensor): The input lengths, returned as shape (B,).
- additional (OrderedDict): Other data, currently unused in this model, returned as output.
####### Examples
>>> model = TFGridNetV3(n_srcs=2, n_imics=2)
>>> input_tensor = torch.randn(4, 16000, 2)  # batch of 4, 16000 samples, 2 channels
>>> ilens = torch.tensor([16000, 16000, 16000, 16000]) # lengths
>>> enhanced, ilens_out, _ = model(input_tensor, ilens)
NOTE
This model works best when trained with variance-normalized mixture input and target. Normalize the mixture by dividing it by torch.std(mixture, (1, 2)), and do the same for the target signals. Specifically, use:

std_ = std(mix)
mix = mix / std_
tgt = tgt / std_
Raises:
- AssertionError – If the input tensor is not in the expected shape or the number of channels is not equal to 2.
property num_spk
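The num_spk property follows the AbsSeparator convention of exposing the configured number of output sources. A minimal standalone sketch (the class name and attribute below are illustrative, not the actual ESPnet implementation):

```python
class SeparatorSketch:
    """Illustrative stand-in for an AbsSeparator subclass."""

    def __init__(self, n_srcs: int = 2):
        # Stored at construction, analogous to TFGridNetV3's n_srcs argument.
        self._n_srcs = n_srcs

    @property
    def num_spk(self) -> int:
        # Number of output sources/speakers the separator produces.
        return self._n_srcs
```

Callers (e.g. loss wrappers) read `model.num_spk` rather than the constructor argument, so the property decouples downstream code from the separator's internals.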