espnet2.enh.encoder.stft_encoder.STFTEncoder
class espnet2.enh.encoder.stft_encoder.STFTEncoder(n_fft: int = 512, win_length: int | None = None, hop_length: int = 128, window='hann', center: bool = True, normalized: bool = False, onesided: bool = True, use_builtin_complex: bool = True, default_fs: int = 16000, spec_transform_type: str | None = None, spec_factor: float = 0.15, spec_abs_exponent: float = 0.5)
Bases: AbsEncoder
Short-Time Fourier Transform (STFT) encoder for speech enhancement and separation.
This encoder transforms mixed speech input into frequency domain representations using the Short-Time Fourier Transform. It can be configured with various parameters including the number of FFT points, window length, hop length, and window type. The encoder also supports spectral transformations to modify the output spectrum.
output_dim
The dimension of the output spectrum.
- Type: int
stft
An instance of the Stft class that performs the Short-Time Fourier Transform.
- Type: Stft
use_builtin_complex
Flag indicating whether to use built-in complex number support in PyTorch.
- Type: bool
win_length
The length of the window used in the STFT.
- Type: int
hop_length
The number of samples between successive frames.
- Type: int
window
The type of window to use (e.g., ‘hann’).
- Type: str
n_fft
The number of FFT points.
- Type: int
center
Whether to pad the input signal so that the frame is centered at the point of analysis.
- Type: bool
default_fs
The default sampling rate in Hz.
- Type: int
spec_transform_type
Type of spectral transformation (‘exponent’, ‘log’, or ‘none’).
- Type: str
spec_factor
Scaling factor for the output spectrum.
- Type: float
spec_abs_exponent
Exponent factor used in the “exponent” transformation.
- Type: float
Parameters:
- n_fft (int) – Number of FFT points. Defaults to 512.
- win_length (int) – Length of the window. Defaults to None, which sets it to n_fft.
- hop_length (int) – Number of samples between frames. Defaults to 128.
- window (str) – Type of window to use. Defaults to ‘hann’.
- center (bool) – If True, pad input so that the frame is centered. Defaults to True.
- normalized (bool) – If True, normalize the output. Defaults to False.
- onesided (bool) – If True, compute a one-sided spectrum. Defaults to True.
- use_builtin_complex (bool) – If True, use PyTorch’s built-in complex type. Defaults to True.
- default_fs (int) – Default sampling rate in Hz. Defaults to 16000.
- spec_transform_type (str) – Type of spectral transformation. Defaults to None.
- spec_factor (float) – Scaling factor for the spectrum. Defaults to 0.15.
- spec_abs_exponent (float) – Exponent for the absolute value in “exponent” transformation. Defaults to 0.5.
Returns: A tuple containing:
- spectrum (ComplexTensor): The transformed spectrum of shape [Batch, T, (C,) F].
- flens (torch.Tensor): The lengths of the output sequences, of shape [Batch].
Return type: tuple
Raises: AssertionError – If the input to forward_streaming is not a single-channel tensor.
############### Examples
encoder = STFTEncoder(n_fft=1024, win_length=512, hop_length=256)
mixed_speech = torch.randn(8, 16000)  # Example batch of audio
ilens = torch.tensor([16000] * 8)  # Example lengths
spectrum, flens = encoder(mixed_speech, ilens)

streaming_input = torch.randn(1, 512)  # Example single-channel frame
feature = encoder.forward_streaming(streaming_input)

audio = torch.randn(1, 16000)  # Continuous audio signal
chunks = encoder.streaming_frame(audio)
######### NOTE The spectral transformation can be customized using spec_transform_type. For example, setting it to “log” will apply a logarithmic transformation to the output spectrum.
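As a concrete illustration of the note above, the following sketch (not part of the original docstring; parameter values are illustrative) configures a log-compressed spectrum and runs the encoder on a batch of waveforms:

import torch
from espnet2.enh.encoder.stft_encoder import STFTEncoder

# Log-compressed complex spectrum; spec_factor scales the transformed output.
encoder = STFTEncoder(n_fft=512, hop_length=128,
                      spec_transform_type="log", spec_factor=0.15)

mixed_speech = torch.randn(4, 16000)               # [Batch, sample]
ilens = torch.full((4,), 16000, dtype=torch.long)  # [Batch]

# With n_fft=512 and onesided=True, F = 257 (see output_dim below).
spectrum, flens = encoder(mixed_speech, ilens)     # spectrum: [4, T, 257], flens: [4]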
forward(input: Tensor, ilens: Tensor, fs: int = None)
Perform the forward pass of the STFT encoder.
This method computes the Short-Time Fourier Transform (STFT) of the input mixed speech signal and returns the resulting spectrum along with the frame lengths. The STFT can be reconfigured based on a new sampling rate if provided.
- Parameters:
- input (torch.Tensor) – Mixed speech signal with shape [Batch, sample].
- ilens (torch.Tensor) – Input lengths with shape [Batch].
- fs (int , optional) – Sampling rate in Hz. If not None, the STFT window and hop lengths are reconfigured for the new sampling rate while keeping their duration fixed.
- Returns: A tuple containing:
  - spectrum (ComplexTensor): The computed STFT spectrum with shape [Batch, T, (C,) F].
  - flens (torch.Tensor): Frame lengths with shape [Batch].
- Return type: tuple
- Raises: ValueError – If the input dimensions are incorrect or if the provided sampling rate is invalid.
############### Examples
>>> encoder = STFTEncoder()
>>> input_tensor = torch.randn(2, 16000) # Batch of 2 samples
>>> ilens_tensor = torch.tensor([16000, 16000]) # Input lengths
>>> spectrum, flens = encoder.forward(input_tensor, ilens_tensor)
>>> print(spectrum.shape) # Output shape: [2, T, F]
######### NOTE Ensure that the input tensor is of the correct shape before calling this method. The input should represent mixed speech signals for the STFT computation to be valid.
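A hedged sketch of the fs argument described above (values are illustrative): the same encoder is applied to audio resampled to 8 kHz, and the STFT is reconfigured internally so that the window and hop durations in seconds stay fixed.

import torch
from espnet2.enh.encoder.stft_encoder import STFTEncoder

encoder = STFTEncoder(n_fft=512, win_length=512, hop_length=128, default_fs=16000)

# The same utterance resampled to 8 kHz; passing fs=8000 reconfigures the STFT
# so that the window/hop durations match the 16 kHz configuration.
speech_8k = torch.randn(2, 8000)     # [Batch, sample]
ilens = torch.tensor([8000, 8000])   # [Batch]
spectrum, flens = encoder(speech_8k, ilens, fs=8000)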
forward_streaming(input: Tensor)
Perform the STFT on a single frame of input for streaming processing.

This method computes the windowed FFT of a single frame of single-channel speech and returns the frequency-domain feature for that frame. It is intended for frame-by-frame (streaming) inference, typically on frames produced by streaming_frame().

- Parameters: input (torch.Tensor) – A single frame of mixed speech with shape [Batch, frame_length].
- Returns: The spectrum of the input frame with shape [Batch, 1, F].
- Return type: ComplexTensor (a built-in complex torch.Tensor when use_builtin_complex is True)
############### Examples
encoder = STFTEncoder(n_fft=1024, win_length=512, hop_length=256)
streaming_input = torch.randn(10, 512)  # A batch of 10 single frames
feature = encoder.forward_streaming(streaming_input)  # shape: [10, 1, F]
######### NOTE This encoder requires PyTorch version 1.9.0 or later for certain functionalities.
- Raises: AssertionError – If the input tensor does not have the correct dimensions (it must be a 2-dimensional, single-channel tensor).
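Relating to the note on PyTorch 1.9.0 above, the following sketch contrasts the two output representations. It is illustrative only; the exact output types are inferred from the use_builtin_complex flag and should be verified against the installed versions.

import torch
from espnet2.enh.encoder.stft_encoder import STFTEncoder

frame = torch.randn(1, 512)  # single-channel frame matching win_length

enc_native = STFTEncoder(n_fft=512, win_length=512, use_builtin_complex=True)
enc_compat = STFTEncoder(n_fft=512, win_length=512, use_builtin_complex=False)

print(enc_native.forward_streaming(frame).dtype)  # expected: a complex dtype, e.g. torch.complex64
print(type(enc_compat.forward_streaming(frame)))  # expected: torch_complex ComplexTensor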
property output_dim : int
Output dimension of the STFT encoder.
This property returns the output dimension, which is calculated based on the number of FFT points and whether the STFT is one-sided or not. If the STFT is one-sided, the output dimension is equal to half of the FFT points plus one, otherwise it equals the number of FFT points.
- Returns: The output dimension of the STFT encoder.
- Return type: int
############### Examples
>>> encoder = STFTEncoder(n_fft=512, onesided=True)
>>> encoder.output_dim
257
>>> encoder = STFTEncoder(n_fft=512, onesided=False)
>>> encoder.output_dim
512
spec_transform_func(spec)
Applies the specified spectral transformation to the input spectrum.
This function modifies the input spectral representation based on the specified transformation type. The available transformation types are: “exponent”, “log”, and “none”. The transformations can help in various tasks such as speech enhancement and separation.
spec_transform_type
Type of transformation to apply. It can be “exponent”, “log”, or “none”.
- Type: str
spec_factor
Factor by which to scale the output spectrum.
- Type: float
spec_abs_exponent
Exponent factor used in the “exponent” transformation.
- Type: float
Parameters: spec (ComplexTensor) – The input spectrum to be transformed.
Returns: The transformed spectrum after applying the specified transformation.
Return type: ComplexTensor
############### Examples
>>> encoder = STFTEncoder(spec_transform_type="log", spec_factor=0.1)
>>> input_spec = ComplexTensor(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
>>> output_spec = encoder.spec_transform_func(input_spec)
>>> print(output_spec)
>>> encoder = STFTEncoder(spec_transform_type="exponent",
... spec_abs_exponent=2.0)
>>> output_spec = encoder.spec_transform_func(input_spec)
>>> print(output_spec)
######### NOTE Ensure that the spec is a valid ComplexTensor before calling this function to avoid runtime errors.
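For reference, the sketch below gives one plausible form of the three transformations, consistent with the attributes described above (magnitude compression with the phase preserved, followed by scaling with spec_factor). The exact expressions are an assumption; consult spec_transform_func in the ESPnet source for the authoritative definitions.

import torch

def spec_transform_sketch(spec, transform_type="exponent", factor=0.15, abs_exponent=0.5):
    # Illustrative stand-in operating on a built-in complex torch.Tensor.
    if transform_type == "exponent":
        if abs_exponent != 1:
            # Compress the magnitude, keep the phase.
            spec = spec.abs() ** abs_exponent * torch.exp(1j * spec.angle())
        spec = spec * factor
    elif transform_type == "log":
        # Logarithmic magnitude compression, phase preserved.
        spec = torch.log1p(spec.abs()) * torch.exp(1j * spec.angle())
        spec = spec * factor
    # "none" (or None): spectrum returned unchanged.
    return spec

x = torch.randn(2, 100, 257, dtype=torch.complex64)
y = spec_transform_sketch(x, transform_type="log")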
streaming_frame(audio)
Splits continuous audio into frame-level chunks for streaming simulation.
This function takes the entire long audio as input for a streaming simulation. It is designed to help manage your streaming input buffer in a real streaming application.
- Parameters:
- audio (torch.Tensor) – Input tensor of shape (B, T), where B is the batch size and T is the length of the audio.
- Returns: A list of tensors, each of shape (B, frame_size), representing the chunked audio frames.
- Return type: List[torch.Tensor]
######### NOTE The function assumes that the audio input has at least one dimension for the batch size and one for the audio length.
############### Examples
>>> encoder = STFTEncoder()
>>> audio_input = torch.randn(2, 16000) # Batch of 2, 16000 samples
>>> frames = encoder.streaming_frame(audio_input)
>>> print(len(frames)) # Number of frames produced
>>> print(frames[0].shape) # Shape of the first frame (B, frame_size)
- Raises: AssertionError – If the input audio tensor does not have at least 2 dimensions.
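Putting streaming_frame() and forward_streaming() together, a minimal streaming sketch (illustrative, not from the original docstring) looks as follows. With the default use_builtin_complex=True, the per-frame spectra are built-in complex tensors, so torch.cat can assemble them into a spectrogram.

import torch
from espnet2.enh.encoder.stft_encoder import STFTEncoder

encoder = STFTEncoder(n_fft=512, win_length=512, hop_length=128)

audio = torch.randn(1, 16000)            # [B, T] single-channel audio
frames = encoder.streaming_frame(audio)  # list of [B, frame_size] chunks

# Encode each frame independently, as a real-time front end would.
features = [encoder.forward_streaming(frame) for frame in frames]

# Each feature has shape [B, 1, F]; concatenating along the time axis yields
# a [B, T_frames, F] spectrogram comparable to the offline forward() output.
spectrogram = torch.cat(features, dim=1)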