espnet2.enh.layers.bsrnn.BandSplit
class espnet2.enh.layers.bsrnn.BandSplit(input_dim, target_fs=48000, channels=128, norm_type='GN')
Bases: Module
Splits the input tensor into frequency subbands for processing.
This class implements the band-splitting operation, dividing the input frequency bins into several subbands for further processing in neural network architectures, particularly in the context of speech enhancement.
subbands
A tuple representing the number of frequency bins in each subband.
- Type: tuple
subband_freqs
Frequencies corresponding to the subbands calculated from the FFT bins.
- Type: torch.Tensor
norm
One normalization layer per subband.
- Type: nn.ModuleList
fc
One 1D convolutional layer per subband, projecting that subband's features to channels output channels.
- Type: nn.ModuleList
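The per-subband attributes above (one norm layer and one fc layer per subband) can be illustrated with a small NumPy sketch. It is a hypothetical re-implementation, not ESPnet's code. The function name `band_split_sketch`, the random weights standing in for learned parameters, the single-group normalization without affine terms, and the order in which bins and real/imaginary parts are flattened are all assumptions.

```python
import numpy as np


def band_split_sketch(x, subbands, channels, seed=0):
    # Hypothetical helper. x: (B, T, F, 2) real/imaginary STFT parts.
    # subbands: width (in bins) of each subband; must sum to F.
    batch, frames, n_bins, parts = x.shape
    assert parts == 2 and sum(subbands) == n_bins, (sum(subbands), n_bins)
    rng = np.random.default_rng(seed)  # random stand-in for learned weights
    outputs = []
    start = 0
    for width in subbands:
        # Slice one subband and merge (bins, real/imag) into 2 * width features.
        sub = x[:, :, start:start + width, :].reshape(batch, frames, 2 * width)
        sub = sub.transpose(0, 2, 1)  # (B, 2 * width, T), the Conv1d layout
        # Single-group GroupNorm without affine terms: normalize over (features, time).
        mean = sub.mean(axis=(1, 2), keepdims=True)
        var = sub.var(axis=(1, 2), keepdims=True)
        sub = (sub - mean) / np.sqrt(var + 1e-5)
        # A kernel-size-1 Conv1d is a per-frame linear map: 2 * width -> channels.
        weight = rng.standard_normal((channels, 2 * width)) / np.sqrt(2 * width)
        bias = np.zeros((channels, 1))
        outputs.append(np.einsum("nc,bct->bnt", weight, sub) + bias)
        start += width
    # Stack the per-subband features on a new last axis: (B, N, T, K).
    return np.stack(outputs, axis=-1)


x = np.random.default_rng(1).standard_normal((2, 10, 17, 2))
z = band_split_sketch(x, subbands=(3, 4, 4, 6), channels=8)
print(z.shape)  # (2, 8, 10, 4)
```

Every subband, whatever its width, ends up with the same number of features (channels), which is what allows the K subbands to be stacked into a single (B, N, T, K) tensor.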
Parameters:
- input_dim (int) – Maximum number of frequency bins corresponding to target_fs, i.e. n_fft // 2 + 1 for an n_fft-point STFT. Must be an odd number.
- target_fs (int) – Maximum sampling frequency supported by the model.
- channels (int) – Number of output channels (N) produced by each subband's convolution.
- norm_type (str) – Type of normalization layer to use (e.g., “GN”, “BN”, etc.).
Raises:
- AssertionError – If input_dim is not an odd number or if the sum of subbands does not equal input_dim.
- NotImplementedError – If the specified input_dim and target_fs do not match predefined configurations.
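The constraints above (an odd input_dim, and subband widths that sum to input_dim) follow from how a real FFT lays out its bins. The sketch below checks a candidate split and computes the upper-edge frequency of each subband, one plausible reading of what subband_freqs stores. The helper `subband_layout` and the particular widths are illustrative assumptions, not necessarily the configuration BandSplit hard-codes for 481 bins at 48 kHz.

```python
from itertools import accumulate


def subband_layout(input_dim, target_fs, subbands):
    # Hypothetical helper: validate a split, return each subband's upper edge (Hz).
    # An odd number of bins corresponds to an even FFT size: F = n_fft // 2 + 1.
    assert input_dim % 2 == 1, input_dim
    n_fft = (input_dim - 1) * 2
    resolution = target_fs / n_fft  # spacing between adjacent FFT bins, in Hz
    assert sum(subbands) == input_dim, (sum(subbands), input_dim)
    # The last bin index of each subband is its cumulative width minus one.
    return [(end - 1) * resolution for end in accumulate(subbands)]


# Illustrative layout for input_dim=481, target_fs=48000 (n_fft=960, 50 Hz bins):
# 20 bands of about 200 Hz, 6 of 500 Hz, 7 of 2 kHz, and a final 3 kHz band.
widths = [5] + [4] * 19 + [10] * 6 + [40] * 7 + [60]
edges = subband_layout(481, 48000, widths)
print(len(widths), edges[0], edges[19], edges[-1])  # 34 200.0 4000.0 24000.0
```

Narrow subbands at low frequencies and wide ones at high frequencies is a common design choice, since most speech detail sits in the lower bands.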
####### Examples
>>> import torch
>>> from espnet2.enh.layers.bsrnn import BandSplit
>>> band_split = BandSplit(input_dim=481, target_fs=48000, channels=128)
>>> input_tensor = torch.randn(10, 100, 481, 2)  # (B, T, F, 2)
>>> output = band_split(input_tensor)
>>> print(output.shape)  # (B, N, T, K') = (10, 128, 100, len(band_split.subbands))
forward(x, fs=None)
BandSplit forward.
- Parameters:
- x (torch.Tensor) – Input tensor of shape (B, T, F, 2), where B is the batch size, T is the number of time steps, F is the number of frequency bins, and 2 represents the real and imaginary parts of the complex input.
- fs (int, optional) – Sampling rate of the input signal. If given, only the subbands that lie within the signal bandwidth (up to the Nyquist frequency fs / 2) are processed. If None, the input is assumed to be already truncated to the effective frequency subbands.
- Returns: Output tensor of shape (B, N, T, K'), where N is the number of output channels (channels) and K' is the number of processed subbands. K' may be smaller than the total number of subbands when fs is lower than target_fs.
- Return type: out (torch.Tensor)
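The fs argument acts as a cut-off: subbands above the Nyquist frequency fs / 2 carry no signal and are skipped. A minimal sketch of that selection follows, assuming subbands are ordered by frequency and a subband is kept when its upper edge is at or below fs / 2 (the exact comparison BandSplit uses is an assumption here). `effective_subbands` is a hypothetical helper, and the subband layout is illustrative.

```python
from itertools import accumulate


def effective_subbands(subband_freqs, fs=None):
    # Hypothetical helper: count the leading subbands at or below fs / 2.
    if fs is None:
        # Input is assumed to be pre-truncated, so every subband is used.
        return len(subband_freqs)
    count = 0
    for upper_edge in subband_freqs:
        if upper_edge > fs / 2:
            break  # subbands are ordered, so all later ones are out of range too
        count += 1
    return count


# Illustrative layout with 50 Hz bins (target_fs=48000, n_fft=960).
widths = [5] + [4] * 19 + [10] * 6 + [40] * 7 + [60]
freqs = [(end - 1) * 50.0 for end in accumulate(widths)]
print(effective_subbands(freqs, fs=16000))  # 26
print(effective_subbands(freqs, fs=48000))  # 34
print(effective_subbands(freqs))            # 34
```

Under this rule a 16 kHz input keeps only the first 26 of the 34 illustrative subbands, which is how K' can end up smaller than the full subband count.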
####### Examples
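>>> import torch
>>> from espnet2.enh.layers.bsrnn import BandSplit
>>> band_split = BandSplit(input_dim=481, target_fs=48000, channels=128)
>>> input_tensor = torch.randn(4, 100, 481, 2)  # (B, T, F, 2)
>>> output = band_split(input_tensor, fs=48000)
>>> output.shape  # (B, N, T, K') = (4, 128, 100, K')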