espnet2.asr.frontend.default.DefaultFrontend
class espnet2.asr.frontend.default.DefaultFrontend(fs: int | str = 16000, n_fft: int = 512, win_length: int | None = None, hop_length: int = 128, window: str | None = 'hann', center: bool = True, normalized: bool = False, onesided: bool = True, n_mels: int = 80, fmin: int | None = None, fmax: int | None = None, htk: bool = False, frontend_conf: dict | None = {'badim': 320, 'bdropout_rate': 0.0, 'blayers': 3, 'bnmask': 2, 'bprojs': 320, 'btype': 'blstmp', 'bunits': 300, 'delay': 3, 'ref_channel': -1, 'taps': 5, 'use_beamformer': False, 'use_dnn_mask_for_wpe': True, 'use_wpe': False, 'wdropout_rate': 0.0, 'wlayers': 3, 'wprojs': 320, 'wtype': 'blstmp', 'wunits': 300}, apply_stft: bool = True)
Bases: AbsFrontend
DefaultFrontend is a conventional frontend structure for automatic speech recognition (ASR). It converts audio signals into features through a series of transformations: the Short-Time Fourier Transform (STFT), optional Weighted Prediction Error (WPE) dereverberation, optional Minimum Variance Distortionless Response (MVDR) beamforming, power spectrum calculation, and finally Log-Mel filterbank features. WPE and MVDR are controlled by frontend_conf and are disabled by default (use_wpe=False, use_beamformer=False).
The processing flow is: STFT -> WPE -> MVDR-Beamformer -> Power-spec -> Log-Mel-Fbank
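The single-channel path (no WPE or MVDR) can be sketched in NumPy as below. This is a minimal illustration, not ESPnet's implementation: it assumes reflect padding for center=True, a periodic Hann window, triangular filters on the HTK mel scale, and a 1e-10 floor before the log. ESPnet's LogMel layer uses librosa's mel filters, whose default (htk=False) is the Slaney scale, so exact values will differ. All function names here are hypothetical.

```python
import numpy as np


def stft_power(wave: np.ndarray, n_fft: int = 512, hop_length: int = 128) -> np.ndarray:
    """Power spectrum of a centered, Hann-windowed, one-sided STFT: (frames, n_fft // 2 + 1)."""
    pad = n_fft // 2
    # Assumption: center=True reflect-pads so frame t is centered on sample t * hop_length
    padded = np.pad(wave, (pad, pad), mode="reflect")
    n_frames = 1 + (len(padded) - n_fft) // hop_length
    window = np.hanning(n_fft + 1)[:-1]  # periodic Hann window, like torch.hann_window
    frames = np.stack(
        [padded[t * hop_length : t * hop_length + n_fft] * window for t in range(n_frames)]
    )
    spec = np.fft.rfft(frames, axis=-1)  # one-sided: n_fft // 2 + 1 bins
    return spec.real**2 + spec.imag**2  # Power-spec stage


def hz_to_mel(f):
    return 2595.0 * np.log10(1.0 + np.asarray(f) / 700.0)  # Assumption: HTK mel formula


def mel_to_hz(m):
    return 700.0 * (10.0 ** (np.asarray(m) / 2595.0) - 1.0)


def mel_filterbank(fs: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax=None) -> np.ndarray:
    """Triangular filters on the HTK mel scale, shape (n_mels, n_fft // 2 + 1)."""
    fmax = fs / 2 if fmax is None else fmax
    bin_freqs = np.fft.rfftfreq(n_fft, d=1.0 / fs)  # frequency of each STFT bin in Hz
    hz_points = mel_to_hz(np.linspace(hz_to_mel(fmin), hz_to_mel(fmax), n_mels + 2))
    fb = np.zeros((n_mels, len(bin_freqs)))
    for m in range(n_mels):
        left, center, right = hz_points[m], hz_points[m + 1], hz_points[m + 2]
        rising = (bin_freqs - left) / (center - left)
        falling = (right - bin_freqs) / (right - center)
        fb[m] = np.maximum(0.0, np.minimum(rising, falling))
    return fb


def log_mel(wave: np.ndarray, fs: int = 16000, n_fft: int = 512,
            hop_length: int = 128, n_mels: int = 80) -> np.ndarray:
    """Single-channel STFT -> Power-spec -> Log-Mel-Fbank, shape (frames, n_mels)."""
    power = stft_power(wave, n_fft, hop_length)
    mel = power @ mel_filterbank(fs, n_fft, n_mels).T
    return np.log(np.maximum(mel, 1e-10))  # Assumption: floor avoids log(0)


feats = log_mel(np.random.default_rng(0).standard_normal(16000))
print(feats.shape)  # (126, 80)
```

One second of 16 kHz audio with the default n_fft and hop_length yields 126 frames of 80 log-mel values.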
Attributes:
- hop_length (int) – The number of audio samples between adjacent STFT frames.
- apply_stft (bool) – Flag indicating whether the STFT is applied to the input.
- frontend (Frontend) – The frontend model for speech enhancement, if applied.
- logmel (LogMel) – The Log-Mel filterbank layer.
- n_mels (int) – Number of Mel frequency bins.
Parameters:
- fs (Union[int, str]) – Sampling frequency; may also be given as a string (default is 16000).
- n_fft (int) – Number of FFT points (default is 512).
- win_length (Optional[int]) – Length of the window; if None, n_fft is used (default is None).
- hop_length (int) – Number of samples between adjacent frames (default is 128).
- window (Optional[str]) – Window function (default is “hann”).
- center (bool) – If True, pads the input so that each frame is centered at its original time index (default is True).
- normalized (bool) – If True, normalizes the STFT output (default is False).
- onesided (bool) – If True, returns only the non-negative frequency components, i.e. n_fft // 2 + 1 bins (default is True).
- n_mels (int) – Number of Mel bands to generate (default is 80).
- fmin (Optional[int]) – Minimum frequency of the Mel filterbank in Hz; if None, 0 is used (default is None).
- fmax (Optional[int]) – Maximum frequency of the Mel filterbank in Hz; if None, fs / 2 is used (default is None).
- htk (bool) – If True, uses the HTK formula for the Mel scale (default is False).
- frontend_conf (Optional[dict]) – Configuration for the speech enhancement frontend model; if None, no enhancement frontend is built (default is a copy of Frontend’s default kwargs).
- apply_stft (bool) – Whether to apply the STFT to the input (default is True).
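The shape arithmetic behind n_fft, hop_length, center and onesided can be checked with a few lines of plain Python. The helper stft_shape is hypothetical and assumes the torch.stft frame-count convention (center=True pads n_fft // 2 samples on each side). win_length does not affect the shape, because a window shorter than n_fft is zero-padded to n_fft.

```python
def stft_shape(num_samples: int, n_fft: int = 512, hop_length: int = 128,
               center: bool = True, onesided: bool = True) -> tuple:
    """Return (num_frames, num_freq_bins) for an STFT with the given settings."""
    if center:
        num_samples += 2 * (n_fft // 2)  # center=True pads n_fft // 2 samples on each side
    num_frames = 1 + (num_samples - n_fft) // hop_length
    num_bins = n_fft // 2 + 1 if onesided else n_fft  # onesided keeps non-negative freqs
    return num_frames, num_bins


fs = 16000
print(stft_shape(fs))                # (126, 257)
print(stft_shape(fs, center=False))  # (122, 257)
print(fs / 512, 128 / fs)            # 31.25 Hz per bin, 0.008 s between frames
```

With the defaults, one second of 16 kHz audio gives 126 frames of 257 bins: each bin spans 16000 / 512 = 31.25 Hz and consecutive frames are 128 / 16000 = 8 ms apart.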
Returns: The processed features and their lengths.
Return type: Tuple[torch.Tensor, torch.Tensor]
Raises: AssertionError – If input dimensions do not match expected shapes.
Examples:
>>> import torch
>>> from espnet2.asr.frontend.default import DefaultFrontend
>>> # Initialize the frontend
>>> frontend = DefaultFrontend(fs=16000, n_fft=512, n_mels=80)
>>> # Process a batch of 2 one-second waveforms
>>> input_tensor = torch.randn(2, 16000)
>>> input_lengths = torch.tensor([16000, 16000])  # length of each input in samples
>>> features, lengths = frontend(input_tensor, input_lengths)
Note: Ensure that the input tensor has the correct shape and type before processing. The input should be a float tensor of shape (batch_size, num_samples) for single-channel audio, or (batch_size, num_samples, num_channels) for multi-channel audio.
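A hypothetical pre-flight check for these shape requirements is sketched below in NumPy for portability; the same tests carry over to torch tensors via .dim() and .shape. check_frontend_input is an illustrative name, not part of the DefaultFrontend API.

```python
import numpy as np


def check_frontend_input(waveforms: np.ndarray, lengths: np.ndarray) -> int:
    """Validate a (batch, samples[, channels]) waveform batch; return the channel count.

    Hypothetical helper: DefaultFrontend itself does not expose this function.
    """
    if waveforms.ndim not in (2, 3):
        raise ValueError(f"expected (batch, samples[, channels]), got {waveforms.shape}")
    if lengths.shape != (waveforms.shape[0],):
        raise ValueError(f"lengths must have shape ({waveforms.shape[0]},), got {lengths.shape}")
    if np.any(lengths <= 0) or np.any(lengths > waveforms.shape[1]):
        raise ValueError("each length must be between 1 and the padded number of samples")
    return 1 if waveforms.ndim == 2 else waveforms.shape[2]


mono = np.zeros((2, 16000), dtype=np.float32)
multi = np.zeros((2, 16000, 4), dtype=np.float32)
lens = np.array([16000, 12000])
print(check_frontend_input(mono, lens), check_frontend_input(multi, lens))  # 1 4
```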
forward(input: Tensor, input_lengths: Tensor) → Tuple[Tensor, Tensor]
Forward pass for the DefaultFrontend class.
This method processes the input tensor through various stages of the frontend pipeline, which includes domain conversion via Short-Time Fourier Transform (STFT), optional speech enhancement, channel selection for multi-channel input, and transformation to Log-Mel features.
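The channel-selection step can be sketched as follows. As we read the ESPnet source, a multi-channel STFT of shape (batch, frames, channels, freq) is reduced by picking a random channel during training and the first channel at inference; treat both the layout and the policy as assumptions to verify against your installed version. select_channel is a hypothetical name.

```python
import numpy as np


def select_channel(stft: np.ndarray, training: bool, rng: np.random.Generator) -> np.ndarray:
    """Reduce a (batch, frames, channels, freq) STFT to (batch, frames, freq).

    Assumption: random channel in training, first channel at inference.
    """
    if stft.ndim != 4:
        return stft  # single-channel input passes through unchanged
    if training:
        ch = int(rng.integers(stft.shape[2]))  # random channel, acting as data augmentation
    else:
        ch = 0  # deterministic: always the first channel at inference
    return stft[:, :, ch, :]


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 126, 4, 257))
print(select_channel(x, training=True, rng=rng).shape)  # (2, 126, 257)
print(np.array_equal(select_channel(x, training=False, rng=rng), x[:, :, 0, :]))  # True
```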
- Parameters:
- input (torch.Tensor) – The input tensor containing audio waveforms. The expected shape is (Batch, Length) for single-channel audio or (Batch, Length, Channels) for multi-channel audio.
- input_lengths (torch.Tensor) – A tensor of shape (Batch,) containing the lengths of each input sequence. This is used to handle variable-length inputs.
- Returns: A tuple containing:
  - torch.Tensor: The extracted features of shape (Batch, Frames, Dim), where Dim is the number of Mel bands (n_mels).
  - torch.Tensor: The number of valid frames in each sequence, of shape (Batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises: AssertionError – If the dimensions of the input STFT do not meet the expected requirements.
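Because a batch is padded to its longest utterance, the returned lengths tell downstream code which frames are real. A minimal sketch, using hypothetical helper names and NumPy arrays in place of torch tensors, of building a padding mask from those lengths and averaging features over valid frames only:

```python
import numpy as np


def valid_frame_mask(feat_lens: np.ndarray, max_frames: int) -> np.ndarray:
    """(batch, frames) boolean mask, True on real frames and False on padding."""
    return np.arange(max_frames)[None, :] < feat_lens[:, None]


def masked_mean(feats: np.ndarray, feat_lens: np.ndarray) -> np.ndarray:
    """Mean over valid frames only: (batch, frames, dim) -> (batch, dim)."""
    mask = valid_frame_mask(feat_lens, feats.shape[1])[:, :, None]
    totals = np.where(mask, feats, 0.0).sum(axis=1)  # padding contributes nothing
    return totals / feat_lens[:, None]


feats = np.ones((2, 126, 80))
feats[1, 63:] = 100.0            # padded tail of the shorter utterance
feat_lens = np.array([126, 63])  # e.g. the lengths returned by forward
print(np.allclose(masked_mean(feats, feat_lens), 1.0))  # True
```

Averaging without the mask would let the padded tail leak into the second utterance's statistics.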
Examples:
>>> import torch
>>> from espnet2.asr.frontend.default import DefaultFrontend
>>> frontend = DefaultFrontend()
>>> audio_input = torch.randn(2, 16000)  # Batch of 2, 1 second audio
>>> input_lengths = torch.tensor([16000, 16000])  # Lengths of inputs in samples
>>> features, lengths = frontend(audio_input, input_lengths)
>>> features.shape  # 126 frames with the default hop_length=128 and center=True
torch.Size([2, 126, 80])
Note: The apply_stft argument in the constructor determines whether the STFT is applied to the input. If it is set to False, the STFT step is skipped, so the input must already be a time-frequency (STFT) representation rather than a raw waveform.
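When apply_stft=False the caller supplies the spectrum, and the frontend only needs its real and imaginary parts to form the power spectrum real**2 + imag**2. In the ESPnet source we have seen, such input is split into real and imaginary parts along its last axis; depending on the version a complex tensor may be expected instead, so treat the packed layout below as an assumption. The helper names are hypothetical, and input_lengths should then count frames rather than samples.

```python
import numpy as np


def pack_real_imag(spec: np.ndarray) -> np.ndarray:
    """Complex (..., freq) spectrum -> real (..., freq, 2) with [real, imag] on the last axis."""
    return np.stack([spec.real, spec.imag], axis=-1)  # Assumption: packed layout


def power_from_packed(packed: np.ndarray) -> np.ndarray:
    """Power spectrum real**2 + imag**2 from a packed (..., 2) spectrum."""
    return packed[..., 0] ** 2 + packed[..., 1] ** 2


rng = np.random.default_rng(0)
frames = rng.standard_normal((2, 126, 512))  # random stand-in for windowed frames
spec = np.fft.rfft(frames, axis=-1)          # (2, 126, 257) complex STFT
packed = pack_real_imag(spec)                # (2, 126, 257, 2)
print(packed.shape, np.allclose(power_from_packed(packed), np.abs(spec) ** 2))
```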
output_size() → int
Return the output size of the frontend.
This method returns the number of Mel frequency bands produced by the frontend’s log-mel layer, which equals the n_mels parameter set during initialization.
- Returns: The number of Mel frequency bands produced by the frontend.
- Return type: int
Examples:
>>> from espnet2.asr.frontend.default import DefaultFrontend
>>> frontend = DefaultFrontend(n_mels=40)
>>> frontend.output_size()
40
Note: This method is particularly useful for determining the feature dimension of the output tensor after feature extraction, for example when sizing the input layer of a downstream ASR encoder.