espnet2.enh.decoder.stft_decoder.STFTDecoder
class espnet2.enh.decoder.stft_decoder.STFTDecoder(n_fft: int = 512, win_length: int | None = None, hop_length: int = 128, window='hann', center: bool = True, normalized: bool = False, onesided: bool = True, default_fs: int = 16000, spec_transform_type: str | None = None, spec_factor: float = 0.15, spec_abs_exponent: float = 0.5)
Bases: AbsDecoder
STFTDecoder is a subclass of AbsDecoder that implements the Short-Time Fourier Transform (STFT) decoder for speech enhancement and separation tasks.
This class takes in complex spectrograms and reconstructs the time-domain waveforms. It includes functionality for various spectral transformations and supports both batch and streaming processing.
n_fft
Number of FFT points.
- Type: int
win_length
Length of each window segment.
- Type: int
hop_length
Number of samples between successive frames.
- Type: int
window
Type of window to use (e.g., “hann”).
- Type: str
center
If True, the signal is padded to center the window.
- Type: bool
default_fs
Default sampling rate in Hz.
- Type: int
spec_transform_type
Type of spectral transformation (“exponent”, “log”, or “none”).
- Type: str
spec_factor
Scaling factor for the output spectrum.
- Type: float
spec_abs_exponent
Exponent factor used in the “exponent” transformation.
- Type: float
Parameters:
- n_fft (int) – Number of FFT points. Default is 512.
- win_length (int, optional) – Length of each window segment. If None, defaults to n_fft.
- hop_length (int) – Number of samples between successive frames. Default is 128.
- window (str) – Type of window to use (e.g., “hann”). Default is “hann”.
- center (bool) – If True, the signal is padded to center the window. Default is True.
- normalized (bool) – If True, the window is normalized. Default is False.
- onesided (bool) – If True, the output will be one-sided. Default is True.
- default_fs (int) – Default sampling rate in Hz. Default is 16000.
- spec_transform_type (str, optional) – Type of spectral transformation (“exponent”, “log”, or “none”). Default is None.
- spec_factor (float) – Scaling factor for the output spectrum. Default is 0.15.
- spec_abs_exponent (float) – Exponent factor used in the “exponent” transformation. Default is 0.5.
Returns: Reconstructed waveforms and their lengths.
Return type: Tuple[torch.Tensor, torch.Tensor]
Raises: TypeError – If the input tensor is not a complex tensor.
############# Examples
>>> import torch
>>> from espnet2.enh.encoder.stft_encoder import STFTEncoder
>>> input_audio = torch.randn((1, 100))
>>> ilens = torch.LongTensor([100])
>>> nfft = 32
>>> win_length = 28
>>> hop = 10
>>> encoder = STFTEncoder(n_fft=nfft, win_length=win_length,
... hop_length=hop, onesided=True,
... spec_transform_type="exponent")
>>> decoder = STFTDecoder(n_fft=nfft, win_length=win_length,
... hop_length=hop, onesided=True,
... spec_transform_type="exponent")
>>> frames, flens = encoder(input_audio, ilens)
>>> wav, ilens = decoder(frames, ilens)
######## NOTE

The class supports half-precision training for compatible input types.
forward(input: ComplexTensor, ilens: Tensor, fs: int = None)
Forward method for the STFTDecoder class, which processes the input spectrum
and reconstructs the time-domain waveform.
This method takes a complex spectrum as input and uses the inverse Short-Time Fourier Transform (iSTFT) to convert it back to the time-domain waveform. The input can be configured for different sampling rates.
- Parameters:
- input (ComplexTensor) – Spectrum tensor of shape [Batch, T, (C,) F], where T is the number of time frames, C is the number of channels, and F is the number of frequency bins.
- ilens (torch.Tensor) – A tensor containing the lengths of each input sequence in the batch. Shape [Batch].
- fs (int , optional) – The sampling rate in Hz. If not None, the iSTFT window and hop lengths are reconfigured for the new sampling rate while keeping their duration fixed.
- Returns: A tuple containing:
  - wav (torch.Tensor): The reconstructed waveform of shape [Batch, Nsamples, (C,)].
  - wav_lens (torch.Tensor): The lengths of the reconstructed waveforms, shape [Batch].
- Return type: tuple
- Raises:
  - TypeError – If the input is neither a ComplexTensor nor (for PyTorch 1.9.0 and later) a native complex tensor.
############# Examples
>>> import torch
>>> from torch_complex.tensor import ComplexTensor
>>> decoder = STFTDecoder(n_fft=512, hop_length=128)
>>> input_spectrum = ComplexTensor(torch.randn(1, 100, 1, 257)) # Example spectrum
>>> ilens = torch.tensor([100]) # Example input lengths
>>> wav, wav_lens = decoder(input_spectrum, ilens)
>>> print(wav.shape, wav_lens.shape) # Output shapes
######## NOTE

The input tensor must be a complex tensor to perform the inverse STFT operation.
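When fs is given, the window and hop lengths are rescaled so that each frame keeps the same physical duration in seconds. A minimal sketch of that rescaling in plain Python (the helper name and the integrality check are assumptions for illustration, not the actual ESPnet internals):

```python
def rescale_stft_params(n_fft, win_length, hop_length, default_fs, fs):
    # Hypothetical helper: rescale frame sizes (counted in samples) so that
    # their duration in seconds stays fixed when the sampling rate changes.
    assert fs * win_length % default_fs == 0, "rescaled lengths must be integral"
    scale = fs / default_fs
    return int(n_fft * scale), int(win_length * scale), int(hop_length * scale)

# Halving the sampling rate halves every length in samples,
# so each frame still spans the same duration in seconds.
print(rescale_stft_params(512, 512, 128, default_fs=16000, fs=8000))
# (256, 256, 64)
```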
forward_streaming(input_frame: Tensor)
Process a single frame of complex spectrum input to produce audio.
This method performs an inverse short-time Fourier transform (iSTFT) on the input frame and returns the corresponding audio waveform. The input is expected to be a complex tensor representing the spectrum of a single frame.
- Parameters: input_frame (torch.Tensor) – Spectrum of shape [Batch, 1, F], where F is the number of frequency bins.
- Returns: The reconstructed audio waveform of shape : [Batch, 1, self.win_length].
- Return type: torch.Tensor
############# Examples
>>> decoder = STFTDecoder(n_fft=512, win_length=512, hop_length=128)
>>> input_frame = torch.randn((1, 1, 257), dtype=torch.complex64)  # F = n_fft // 2 + 1 (onesided)
>>> output_wav = decoder.forward_streaming(input_frame)
>>> output_wav.shape
torch.Size([1, 1, 512])
spec_back(spec)
Invert the spectral transformation applied to the spectrum.

This method reverses the transformation selected by spec_transform_type (“exponent”, “log”, or “none”): the spectrum is divided by spec_factor, and then the magnitude compression (if any) is undone while the phase is kept unchanged.

- Parameters: spec (ComplexTensor) – Transformed spectrum of shape [Batch, T, (C,) F].
- Returns: Spectrum with the spectral transformation inverted.
- Return type: ComplexTensor
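The “exponent” transformation that spec_back inverts can be sketched on a single complex value in plain Python. This is a minimal sketch of the math only (magnitude compression by spec_abs_exponent, scaling by spec_factor, phase preserved); the actual ESPnet implementation operates on complex tensors:

```python
import cmath

def spec_transform(z, factor=0.15, exponent=0.5):
    # Forward "exponent" transform: compress the magnitude (|z| ** exponent),
    # keep the phase unchanged, then scale by the spectral factor.
    return factor * (abs(z) ** exponent) * cmath.exp(1j * cmath.phase(z))

def spec_back(z, factor=0.15, exponent=0.5):
    # Inverse: undo the scaling, then undo the magnitude compression.
    z = z / factor
    return (abs(z) ** (1 / exponent)) * cmath.exp(1j * cmath.phase(z))

# Round trip: spec_back recovers the original spectral value.
z = 0.3 + 0.4j
assert abs(spec_back(spec_transform(z)) - z) < 1e-9
```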
streaming_merge(chunks, ilens=None)
Merge frame-level processed audio chunks in a streaming simulation.
This method merges audio chunks processed at the frame level. It is important to note that, in real applications, the processed audio should be sent to the output channel frame by frame. This function can be referred to for managing the streaming output buffer.
- Parameters:
- chunks (List[torch.Tensor]) – A list of audio chunks, each of shape (B, frame_size).
- ilens (torch.Tensor , optional) – Input lengths of shape [B]. If provided, it will be used to trim the merged audio.
- Returns: Merged audio of shape [B, T].
- Return type: torch.Tensor
############# Examples
>>> decoder = STFTDecoder(win_length=256, hop_length=128)
>>> chunks = [torch.randn(2, 256) for _ in range(5)] # 5 chunks
>>> ilens = torch.tensor([256, 256]) # Lengths for each batch
>>> merged_audio = decoder.streaming_merge(chunks, ilens)
>>> print(merged_audio.shape)
torch.Size([2, 256]) # trimmed to ilens; without ilens, T depends on the number of chunks
######## NOTE

The output audio is normalized based on the applied windowing function.
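At its core, merging frame-level outputs spaced hop_length samples apart is an overlap-add. A minimal overlap-add sketch in plain Python (ignoring batching and the window-based normalization mentioned in the note above):

```python
def overlap_add(chunks, hop):
    # Merge equal-length frames spaced `hop` samples apart by summing
    # their overlapping regions into one output buffer.
    frame = len(chunks[0])
    out = [0.0] * (hop * (len(chunks) - 1) + frame)
    for i, chunk in enumerate(chunks):
        for j, sample in enumerate(chunk):
            out[i * hop + j] += sample
    return out

# Three all-ones frames of length 4 with hop 2: overlapping samples sum to 2.
print(overlap_add([[1.0] * 4] * 3, hop=2))
# [1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]
```

The output length follows the same formula as the merged audio above: hop * (n_chunks - 1) + frame_size.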