espnet2.layers.stft.Stft
class espnet2.layers.stft.Stft(n_fft: int = 512, win_length: int | None = None, hop_length: int = 128, window: str | None = 'hann', center: bool = True, normalized: bool = False, onesided: bool = True)
Bases: Module, InversibleInterface
Stft is a PyTorch module for computing the Short-Time Fourier Transform (STFT)
and its inverse. It provides an efficient implementation for transforming time-domain signals into the frequency domain, supporting multi-channel inputs.
n_fft
Number of FFT points.
- Type: int
win_length
Length of the window. If None, defaults to n_fft.
- Type: int
hop_length
Number of samples between each STFT frame.
- Type: int
window
Name of the window to apply, for example "hann" or "hamming". The name must correspond to a PyTorch window function (torch.<window>_window).
- Type: str
center
If True, pads the input signal to ensure that the frames are centered.
- Type: bool
normalized
If True, normalizes the output by the window length.
- Type: bool
onesided
If True, returns a one-sided spectrum.
- Type: bool
Parameters:
- n_fft (int) – Number of FFT points (default: 512).
- win_length (Optional[int]) – Length of the window (default: None).
- hop_length (int) – Number of samples between frames (default: 128).
- window (Optional[str]) – Type of window (default: “hann”).
- center (bool) – Center the signal before processing (default: True).
- normalized (bool) – Normalize the output (default: False).
- onesided (bool) – Return one-sided spectrum (default: True).
Returns:
- output: The STFT output tensor of shape (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2).
- ilens: Optional tensor indicating the lengths of the input signals.
Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
Raises:
- ValueError – If the specified window is not implemented in PyTorch.
- NotImplementedError – If called in training mode on a device where the STFT is computed with librosa, which does not support training mode.
########### Examples
>>> import torch
>>> from espnet2.layers.stft import Stft
>>> # Create an instance of Stft
>>> stft = Stft(n_fft=1024, hop_length=256)
>>> # Compute the STFT
>>> input_tensor = torch.randn(8, 16000) # Batch of 8 waveforms, 16000 samples each
>>> output, ilens = stft(input_tensor)
>>> # Compute the inverse STFT
>>> reconstructed_wavs, ilens = stft.inverse(output, ilens)
NOTE
The STFT implementation is compatible with librosa’s STFT regarding padding and scaling. Note that it differs from scipy.signal.stft.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
extra_repr()
Returns a string representation of the STFT parameters for logging.
This method provides a summary of the key parameters used in the STFT (Short-Time Fourier Transform) configuration. The output is useful for debugging and understanding the current setup of the STFT instance.
n_fft
The number of FFT points.
- Type: int
win_length
The length of each windowed segment.
- Type: int
hop_length
The number of samples between successive frames.
- Type: int
center
Whether to pad the input signal on both sides.
- Type: bool
normalized
Whether to normalize the output.
- Type: bool
onesided
Whether to return a one-sided spectrum.
- Type: bool
Returns: A string representation of the STFT parameters.
Return type: str
########### Examples
>>> stft = Stft(n_fft=1024, win_length=512, hop_length=256)
>>> print(stft.extra_repr())
n_fft=1024, win_length=512, hop_length=256, center=True,
normalized=False, onesided=True
forward(input, ilens=None)
STFT forward function.
Computes the Short-Time Fourier Transform (STFT) of the input tensor.
- Parameters:
- input – A tensor of shape (Batch, Nsamples) or (Batch, Nsamples, Channels) representing the audio signal.
- ilens – An optional tensor of shape (Batch,) that specifies the length of each input signal in samples. If provided, it is used to mask the output.
- Returns: A tuple containing:
- A tensor of shape (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) holding the STFT output as real and imaginary components.
- An optional tensor of shape (Batch,) that contains the number of frames of each output after the STFT.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
NOTE
The output tensor contains the STFT results with real and imaginary parts represented in the last dimension. The input tensor can be either a single channel or multi-channel audio signal.
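The output shapes described above can be sanity-checked with a little arithmetic. The following is a minimal standalone sketch, not ESPnet code: stft_output_shape and frame_lengths are hypothetical helpers, and they assume the usual torch.stft framing convention (padding of n_fft // 2 samples on each side when center=True, and one frame every hop_length samples).

```python
from typing import List, Optional, Sequence, Tuple


def stft_output_shape(
    batch: int,
    n_samples: int,
    n_fft: int = 512,
    hop_length: int = 128,
    center: bool = True,
    onesided: bool = True,
    channels: Optional[int] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper: predict the shape of the STFT output tensor."""
    # Assumption: center=True pads n_fft // 2 samples on both sides.
    padded = n_samples + 2 * (n_fft // 2) if center else n_samples
    frames = (padded - n_fft) // hop_length + 1
    # A one-sided spectrum keeps only the non-negative frequency bins.
    freq = n_fft // 2 + 1 if onesided else n_fft
    # The trailing 2 holds the real and imaginary parts.
    if channels is None:
        return (batch, frames, freq, 2)
    return (batch, frames, channels, freq, 2)


def frame_lengths(
    ilens: Sequence[int],
    n_fft: int = 512,
    hop_length: int = 128,
    center: bool = True,
) -> List[int]:
    """Hypothetical helper: number of frames for each input length."""
    pad = 2 * (n_fft // 2) if center else 0
    return [(n + pad - n_fft) // hop_length + 1 for n in ilens]


print(stft_output_shape(10, 16000))    # (10, 126, 257, 2)
print(frame_lengths([16000, 12000]))   # [126, 94]
```

Under these assumptions, a batch padded to 16000 samples yields 126 frames, and a shorter item of 12000 samples contributes only 94 of them; the remaining frames are the ones that masking with ilens would cover.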
########### Examples
>>> stft_layer = Stft(n_fft=512, hop_length=128)
>>> audio_input = torch.randn(10, 16000) # Batch of 10 waveforms, 16000 samples each
>>> output, output_lengths = stft_layer(audio_input)
inverse(input, ilens=None)
Inverse STFT.
This function computes the inverse Short-Time Fourier Transform (iSTFT) of the given input tensor, which can be a standard tensor or a complex tensor. The inverse STFT is used to reconstruct the time-domain signal from its frequency-domain representation.
- Parameters:
- input – A tensor of shape (batch, T, F, 2) representing the complex STFT output, or a ComplexTensor of shape (batch, T, F).
- ilens – A tensor of shape (batch,) containing the lengths of the original signals. If provided, it will be used to set the output lengths accordingly.
- Returns:
- wavs: A tensor of shape (batch, samples) containing the reconstructed time-domain waveforms.
- ilens: A tensor of shape (batch,) containing the lengths of the reconstructed signals.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
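To relate the number of frames T to the length of the reconstructed waveform, the sketch below computes how many samples an inverse STFT produces when no target length is requested. It is a standalone illustration rather than ESPnet code: istft_output_length is a hypothetical helper, and it assumes the torch.istft convention in which center=True trims n_fft // 2 samples from each end of the overlap-added signal.

```python
def istft_output_length(
    n_frames: int,
    n_fft: int = 512,
    hop_length: int = 128,
    center: bool = True,
) -> int:
    """Hypothetical helper: samples reconstructed from n_frames frames."""
    # Overlap-add of n_frames windows, each n_fft long, spaced hop_length apart.
    full = n_fft + hop_length * (n_frames - 1)
    # Assumption: center=True removes the n_fft // 2 padding on both sides.
    return full - 2 * (n_fft // 2) if center else full


print(istft_output_length(126))  # 16000
print(istft_output_length(94))   # 11904
```

Under the same framing assumption, a 12000-sample signal gives 94 frames, so the frame-derived length (11904) can fall short of the original by up to hop_length - 1 samples. Passing the original lengths as ilens is how the caller states the intended output length.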
########### Examples
>>> stft_layer = Stft()
>>> input_tensor = torch.randn(2, 100, 257, 2) # Example STFT output; Freq = n_fft // 2 + 1 = 257 for the default n_fft=512
>>> lengths = torch.tensor([12672, 10000]) # Example waveform lengths in samples
>>> reconstructed_wavs, output_lengths = stft_layer.inverse(input_tensor, lengths)