espnet2.enh.separator.tfgridnet_separator.TFGridNet
class espnet2.enh.separator.tfgridnet_separator.TFGridNet(input_dim, n_srcs=2, n_fft=128, stride=64, window='hann', n_imics=1, n_layers=6, lstm_hidden_units=192, attn_n_head=4, attn_approx_qk_dim=512, emb_dim=48, emb_ks=4, emb_hs=1, activation='prelu', eps=1e-05, use_builtin_complex=False, ref_channel=-1)
Bases: AbsSeparator
Offline TFGridNet for speech separation.
This class implements the TF-GridNet model, which integrates full- and sub-band modeling for speech separation, as described in the following references:
[1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, “TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation”, in arXiv preprint arXiv:2211.12433, 2022.
[2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, “TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural Speaker Separation”, in arXiv preprint arXiv:2209.03952, 2022.
#### NOTE
The model works best when trained with variance-normalized mixture input and target signals. For example, for a mixture tensor of shape [batch, samples, microphones], normalize it by dividing by torch.std(mixture, (1, 2)). Apply the same normalization to the target signals, especially when not using a scale-invariant loss function such as SI-SDR.
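The normalization described in the note can be sketched as follows. This is a minimal sketch, not ESPnet code: the helper name `normalize_pair` is hypothetical, the targets are assumed to be stacked as [batch, n_srcs, samples], and "the same normalization" is read as dividing the targets by the mixture's standard deviation, which keeps a single-microphone mixture equal to the sum of its normalized sources.

```python
import torch


def normalize_pair(mixture: torch.Tensor, targets: torch.Tensor, eps: float = 1e-8):
    """Scale mixture and targets by the mixture's standard deviation.

    Hypothetical helper (not an ESPnet API).
    mixture: [batch, samples, microphones]
    targets: [batch, n_srcs, samples]  (assumed layout)
    Returns the normalized mixture, the normalized targets and the
    per-example scale, so outputs can be mapped back to the original level.
    """
    # One std per example, computed over samples and microphones (dims 1 and 2).
    std = torch.std(mixture, dim=(1, 2), keepdim=True)  # [batch, 1, 1]
    std = std.clamp(min=eps)  # guard against silent (all-zero) mixtures
    # Assumption: targets share the mixture's scale factor, so the relative
    # level between mixture and sources is preserved.
    return mixture / std, targets / std, std


torch.manual_seed(0)
targets = torch.randn(4, 2, 16000) * 0.1      # two sources per example
mixture = targets.sum(dim=1).unsqueeze(-1)    # single-mic mixture [B, N, 1]
mix_n, tgt_n, scale = normalize_pair(mixture, targets)
print(torch.std(mix_n, dim=(1, 2)))           # approximately 1 for every example
```

Multiplying the separated signals by `scale` maps them back to the original signal level.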
- Parameters:
- input_dim (int) – Placeholder, not used.
- n_srcs (int) – Number of output sources/speakers.
- n_fft (int) – STFT window size.
- stride (int) – STFT stride.
- window (str) – STFT window type, e.g., ‘hann’ (default) or ‘hamming’, or None for no window.
- n_imics (int) – Number of microphone channels (only fixed-array geometry supported).
- n_layers (int) – Number of TFGridNet blocks.
- lstm_hidden_units (int) – Number of hidden units in LSTM.
- attn_n_head (int) – Number of heads in self-attention.
- attn_approx_qk_dim (int) – Approximate dimension of the frame-level query and key tensors in self-attention.
- emb_dim (int) – Embedding dimension.
- emb_ks (int) – Kernel size for unfolding and deconvolution (deconv1D).
- emb_hs (int) – Hop size for unfolding and deconvolution (deconv1D).
- activation (str) – Activation function used throughout the TFGridNet model; any activation supported by torch can be used, e.g., ‘relu’ or ‘elu’.
- eps (float) – Small epsilon for normalization layers.
- use_builtin_complex (bool) – Whether to use built-in complex type or not.
- ref_channel (int) – Reference channel for the input signals, default is -1.
- Returns: A tuple containing:
- enhanced (List[torch.Tensor]): A list of n_srcs mono audio tensors, each of shape (B, T).
- ilens (torch.Tensor): Input lengths of shape (B,).
- additional (OrderedDict): Currently unused; returned as part of the output.
- Return type: Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]
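To make the n_fft and stride parameters concrete, the sketch below computes the spectrogram size they imply for a one-second, 16 kHz signal. It is a minimal sketch that calls torch.stft directly with a Hann window and center=True; it assumes, without verifying it against ESPnet, that TFGridNet's STFT front end produces the same number of frequency bins and frames.

```python
import torch

n_fft, stride = 128, 64   # TFGridNet defaults
n_samples = 16000         # e.g. 1 s of 16 kHz audio

# Number of frequency bins of a one-sided STFT.
n_freqs = n_fft // 2 + 1

# With center=True, torch.stft pads n_fft // 2 samples on each side,
# which for an even n_fft gives 1 + n_samples // stride frames.
n_frames = 1 + n_samples // stride

x = torch.randn(n_samples)
spec = torch.stft(
    x,
    n_fft=n_fft,
    hop_length=stride,
    window=torch.hann_window(n_fft),
    center=True,
    return_complex=True,
)
print(spec.shape)  # torch.Size([65, 251])
```

Halving stride roughly doubles the number of frames, and with it the sequence length processed by the frame-level attention.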
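#### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnet_separator import TFGridNet
>>> model = TFGridNet(None, n_srcs=2, n_fft=256)  # input_dim is a placeholder
>>> input_tensor = torch.randn(10, 16000)  # [B, N]; single-channel input (n_imics=1)
>>> ilens = torch.tensor([16000] * 10)  # Lengths
>>> enhanced, ilens, _ = model(input_tensor, ilens)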
forward(input: Tensor, ilens: Tensor, additional: Dict | None = None) → Tuple[List[Tensor], Tensor, OrderedDict]
Forward pass for the TFGridNet model.
This method processes the input multi-channel audio tensor to perform speech separation using the TFGridNet architecture.
- Parameters:
- input (torch.Tensor) – Batched audio tensor of shape [B, N, M] with N samples and M microphone channels. When n_imics is 1, the input is expected as a single-channel tensor of shape [B, N].
- ilens (torch.Tensor) – Input lengths of shape [B].
- additional (Dict or None) – Other data, currently unused in this model.
- Returns:
- enhanced (List[torch.Tensor]): A list of n_srcs mono audio tensors, each of shape (B, T) with T samples.
- ilens (torch.Tensor): Input lengths of shape (B,).
- additional (OrderedDict): Other data, currently unused in this model, returned as part of the output.
- Return type: Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]
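Because enhanced is a Python list with one (B, T) tensor per source, it is often convenient to stack it into a single tensor, for example before computing a permutation-invariant loss. The sketch below uses random tensors as stand-ins for the model outputs (an assumption made so that it runs without ESPnet installed) and shows the stacking and the inverse split.

```python
import torch

batch, n_srcs, n_samples = 4, 2, 16000

# Stand-ins for the model output: a list of n_srcs tensors, each (B, T).
enhanced = [torch.randn(batch, n_samples) for _ in range(n_srcs)]

# Stack along a new source axis -> (B, n_srcs, T).
stacked = torch.stack(enhanced, dim=1)
print(stacked.shape)  # torch.Size([4, 2, 16000])

# The inverse: split (B, n_srcs, T) back into a list of (B, T) tensors.
restored = list(stacked.unbind(dim=1))
```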
#### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnet_separator import TFGridNet
>>> model = TFGridNet(None, n_srcs=2, n_fft=128, n_imics=2)
>>> input_tensor = torch.randn(4, 16000, 2)  # [B, N, M] with M == n_imics
>>> ilens = torch.tensor([16000, 16000, 16000, 16000])  # [B]
>>> enhanced, ilens_out, _ = model(input_tensor, ilens)
#### NOTE
As outlined in the model documentation, this model works best when trained with variance-normalized mixture input and target signals. For a mixture of shape [batch, samples, microphones], normalize it by dividing by torch.std(mixture, (1, 2)). Do the same for the target signals, especially when not using a scale-invariant loss function such as SI-SDR.
property num_spk
Returns the number of output sources/speakers.
static pad2(input_tensor, target_len)
Pads or trims the input tensor along its last dimension to the specified target length.
This method calls torch.nn.functional.pad with a pad amount of target_len minus the current length of the last dimension. If the tensor is shorter than target_len, zeros are appended to its end; if it is longer, the negative pad amount trims it to target_len.
- Parameters:
- input_tensor (torch.Tensor) – The tensor to be padded, typically of shape [B, C, T], where B is the batch size, C is the number of channels, and T is the length along the last dimension. Only the last dimension is padded.
- target_len (int) – The desired length of the last dimension after padding.
- Returns: The padded tensor with the shape [B, C, target_len].
- Return type: torch.Tensor
#### Examples
>>> import torch
>>> from espnet2.enh.separator.tfgridnet_separator import TFGridNet
>>> input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> target_len = 5
>>> padded_tensor = TFGridNet.pad2(input_tensor, target_len)
>>> print(padded_tensor)
tensor([[1, 2, 3, 0, 0],
        [4, 5, 6, 0, 0]])
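#### NOTE
If the input tensor’s last dimension already equals target_len, the output is identical to the input. If it is greater than target_len, torch.nn.functional.pad receives a negative pad amount and trims the tensor, so the output always has length target_len.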
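The pad-or-trim behavior can be checked with torch.nn.functional.pad alone. The sketch below reimplements the idea as a hypothetical standalone function, `pad_to_length`, assuming pad2 passes a pad amount of target_len minus the current length; it illustrates the technique and is not the ESPnet source.

```python
import torch
import torch.nn.functional as F


def pad_to_length(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Zero-pad or trim the last dimension of x to target_len.

    Hypothetical helper illustrating the pad2 idea. A positive pad amount
    appends zeros; a negative amount makes F.pad drop samples from the end.
    """
    return F.pad(x, (0, target_len - x.shape[-1]))


x = torch.arange(1.0, 8.0).view(1, 1, 7)  # [B, C, T] with T = 7

longer = pad_to_length(x, 9)   # two zeros appended
shorter = pad_to_length(x, 5)  # last two samples dropped
same = pad_to_length(x, 7)     # unchanged

print(longer.squeeze().tolist())   # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0]
print(shorter.squeeze().tolist())  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

Forcing the output to exactly target_len is useful after an inverse STFT, whose output length can differ slightly from the original number of samples.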