espnet2.asr.frontend.whisper.WhisperFrontend
class espnet2.asr.frontend.whisper.WhisperFrontend(whisper_model: str = 'small', fs: int | str = 16000, freeze_weights: bool = True, download_dir: str | None = None)
Bases: AbsFrontend
WhisperFrontend is a speech representation frontend built on OpenAI’s Whisper model: it converts raw audio into log-mel spectrograms and encodes them with the Whisper encoder.
This class inherits from AbsFrontend. Its output is the sequence of Whisper encoder states, which downstream ESPnet modules (for example, an ASR encoder) can consume as input features.
For more information on the Whisper model, please visit: https://github.com/openai/whisper
n_fft
The number of FFT components.
- Type: int
win_length
The window length for the STFT.
- Type: int
hop_length
The hop length for the STFT.
- Type: int
n_mels
The number of mel filter banks.
- Type: int
mel_filters
Function to generate mel filters.
- Type: callable
pad_or_trim
Function to pad or trim audio inputs.
- Type: callable
whisper
The loaded Whisper model for feature extraction.
- Type: Model
freeze_weights
If True, the weights of the model are frozen.
- Type: bool
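For orientation, the sketch below converts the STFT attributes above into time units. It assumes the constants Whisper defines in `whisper.audio` (N_FFT = 400, HOP_LENGTH = 160, with win_length equal to n_fft) and a 16 kHz sampling rate. `stft_geometry` is a hypothetical helper, not part of ESPnet.

```python
# Hypothetical helper: express Whisper's STFT constants in time units.
# Assumed values mirror whisper.audio (N_FFT=400, HOP_LENGTH=160).
SAMPLE_RATE = 16000
N_FFT = 400        # FFT size; also used as win_length
HOP_LENGTH = 160   # samples between consecutive STFT frames


def stft_geometry(sample_rate: int = SAMPLE_RATE) -> dict:
    """Return window/hop durations (ms), frame rate (Hz) and FFT bin count."""
    return {
        "window_ms": 1000 * N_FFT / sample_rate,
        "hop_ms": 1000 * HOP_LENGTH / sample_rate,
        "frames_per_second": sample_rate / HOP_LENGTH,
        "n_freq_bins": N_FFT // 2 + 1,  # one-sided spectrum of a real signal
    }


print(stft_geometry())
# {'window_ms': 25.0, 'hop_ms': 10.0, 'frames_per_second': 100.0, 'n_freq_bins': 201}
```

Under these assumptions each log-mel frame covers a 25 ms window every 10 ms, and the mel filter bank maps the 201 frequency bins down to n_mels bands.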
Parameters:
- whisper_model (str) – The name of the Whisper model to use (default: “small”).
- fs (Union[int, str]) – The sampling frequency of the audio (default: 16000).
- freeze_weights (bool) – Whether to freeze the weights of the Whisper model during feature extraction (default: True).
- download_dir (Optional[str]) – Directory to download the Whisper model if not available locally (default: None).
Returns: None
Raises:
- ImportError – If the whisper package is not properly installed.
- AssertionError – If whisper_model is not one of the names returned by whisper.available_models().
Examples
>>> import torch
>>> from espnet2.asr.frontend.whisper import WhisperFrontend
>>> # Initialize the frontend with a specific Whisper model
>>> frontend = WhisperFrontend(whisper_model="base")
>>> # Process a batch of one 1-second waveform sampled at 16 kHz
>>> audio_tensor = torch.randn(1, 16000)
>>> input_lengths = torch.tensor([16000])
>>> features, lengths = frontend(audio_tensor, input_lengths)
NOTE: All Whisper models are trained on 16 kHz audio. If fs is set to another value, the frontend only logs a warning; it does not resample the input.
forward(input: Tensor, input_lengths: Tensor) → Tuple[Tensor, Tensor]
Processes the input audio tensor and computes the log-mel spectrogram followed by encoding through the Whisper model.
- Parameters:
- input (torch.Tensor) – The input waveform tensor with shape (B, T), where B is the batch size and T is the number of audio samples.
- input_lengths (torch.Tensor) – A tensor of shape (B,) containing the number of valid samples in each waveform.
- Returns: A tuple containing:
  - feats (torch.Tensor): The encoded features from the Whisper encoder with shape (B, L’, D), where L’ is the output sequence length and D is the feature dimension (see output_size()).
  - feats_lens (torch.Tensor): A tensor of shape (B,) containing the valid length of each encoded sequence.
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples
>>> frontend = WhisperFrontend()
>>> audio_tensor = torch.randn(2, 16000)  # 2 waveforms, 1 s each at 16 kHz
>>> lengths = torch.tensor([16000, 16000])  # Valid samples per waveform
>>> features, feats_lens = frontend(audio_tensor, lengths)
>>> print(features.shape)  # (2, L', D)
>>> print(feats_lens)  # Encoded length of each waveform
NOTE: When freeze_weights is True, the encoder forward pass runs under torch.no_grad(), so no gradients reach the Whisper weights.
log_mel_spectrogram(audio: Tensor, ilens: Tensor | None = None) → Tensor
Computes the log-mel spectrogram of the given audio input.
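This method applies a short-time Fourier transform (STFT) with a Hann window to the input audio, computes the mel spectrogram from the power spectrum, and converts it to a log10 scale. Following Whisper’s preprocessing, the last STFT frame is dropped and the log values are range-compressed and rescaled, producing the input expected by the Whisper encoder.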
- Parameters:
- audio (torch.Tensor) – A tensor of audio waveforms with shape (N, T), where N is the batch size and T is the number of audio samples.
- ilens (torch.Tensor, optional) – A tensor containing the lengths of the audio sequences in the batch, in samples. If provided, the output lengths are computed as ilens // hop_length. Default is None.
- Returns:
- log_spec (torch.Tensor): A tensor containing the log-mel spectrogram with shape (N, n_mels, T’).
- olens (Optional[torch.Tensor]): A tensor containing the output lengths of the log-mel spectrogram sequences, with shape (N,). Returns None if ilens is not provided.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
- Raises: ValueError – If the audio tensor is empty or has an invalid shape.
Examples
>>> frontend = WhisperFrontend()
>>> audio_tensor = torch.randn(1, 16000)  # 1 s of audio at 16 kHz
>>> ilens = torch.tensor([16000])
>>> log_mel_spec, output_lengths = frontend.log_mel_spectrogram(audio_tensor, ilens)
>>> print(log_mel_spec.shape)  # (1, n_mels, T'), e.g. (1, 80, T') for "small"
>>> print(output_lengths)  # ilens // hop_length, i.e. tensor([100])
NOTE: The input audio should be sampled at 16 kHz for optimal results, as the Whisper model is trained on this sampling rate.
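The range compression mentioned above follows Whisper’s reference `log_mel_spectrogram` in `whisper/audio.py`: mel power is floored at 1e-10 and converted to log10, values more than 8 below the utterance maximum are raised to that floor, and the result is mapped with (x + 4) / 4. The pure-Python sketch below reproduces only that arithmetic on a flat list of mel-bin powers for a single utterance; it assumes this frontend uses the same constants and is not the ESPnet code itself. `normalize_log_mel` is a hypothetical name.

```python
import math
from typing import List


def normalize_log_mel(mel_power: List[float]) -> List[float]:
    """Whisper-style log compression for one utterance (flattened mel bins)."""
    # 1) Floor the power to avoid log10(0), then move to a log10 scale.
    log_spec = [math.log10(max(p, 1e-10)) for p in mel_power]
    # 2) Limit the dynamic range to 8 log10 units (80 dB) below the peak.
    floor = max(log_spec) - 8.0
    log_spec = [max(x, floor) for x in log_spec]
    # 3) Shift and rescale with the constants from Whisper's reference code.
    return [(x + 4.0) / 4.0 for x in log_spec]


print(normalize_log_mel([0.0, 1.0, 100.0]))
# [-0.5, 1.0, 1.5]
```

In the tensor version the maximum is taken per utterance over all mel bins and frames, so a silent utterance is not compressed relative to a loud one elsewhere in the batch.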
output_size() → int
Returns the output size of the Whisper model’s encoder.
The output size is the feature dimension D of the encoder output, i.e. the width of the encoder’s final layer. Downstream modules that consume these features (for example, an ASR encoder) typically need this value as their input size.
- Parameters: None
- Returns: The output size of the Whisper model’s encoder.
- Return type: int
Examples
>>> frontend = WhisperFrontend(whisper_model='small')
>>> size = frontend.output_size()
>>> print(size)  # Depends on the model used
768
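When writing a config before the model is loaded, it can help to know the expected width in advance. The lookup below lists the encoder widths (n_audio_state) of the published Whisper checkpoints; it is a hypothetical helper for sanity checks, and `frontend.output_size()` remains the authoritative value at runtime.

```python
# Hypothetical lookup of Whisper encoder widths (n_audio_state), useful for
# sanity-checking a downstream config before loading the model.
WHISPER_ENCODER_DIM = {
    "tiny": 384,
    "base": 512,
    "small": 768,
    "medium": 1024,
    "large": 1280,
}


def expected_output_size(whisper_model: str) -> int:
    """Map a model name such as 'base.en' or 'large-v2' to its encoder width."""
    # English-only ('.en') and versioned ('-v2', '-v3') names share the base size.
    family = whisper_model.split(".")[0].split("-")[0]
    return WHISPER_ENCODER_DIM[family]


print(expected_output_size("small"))     # 768
print(expected_output_size("base.en"))   # 512
print(expected_output_size("large-v2"))  # 1280
```

Unknown names raise KeyError, which is deliberate: a typo should fail loudly rather than silently pick a width.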
whisper_encode(input: Tensor, ilens: Tensor | None = None) → Tensor
Encodes input audio features using the Whisper model’s encoder.
This method passes the input through the Whisper encoder’s two convolutional layers, adds positional embeddings, and applies the transformer blocks and final layer norm. It returns the encoded output together with the optional output lengths, which give the number of valid frames produced by the encoder.
- Parameters:
- input (torch.Tensor) – The log-mel features to be encoded, with shape (batch_size, n_mels, num_frames), e.g. the output of log_mel_spectrogram().
- ilens (torch.Tensor , optional) – The lengths of the input sequences, used to calculate output lengths. If not provided, output lengths will not be computed.
- Returns: A tuple containing:
  - torch.Tensor: The encoded output from the Whisper encoder, with shape (batch_size, num_output_frames, D).
  - Optional[torch.Tensor]: The lengths of the output sequences if ilens was provided; otherwise None.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
- Raises: RuntimeError – If the input tensor has an invalid shape or if the Whisper model encounters an error during encoding.
Examples
>>> frontend = WhisperFrontend()
>>> audio_features = torch.randn(2, 80, 100)  # Batch of 2, 80 mel bins, 100 frames
>>> output, output_lengths = frontend.whisper_encode(audio_features)
>>> print(output.shape)  # (2, 50, D): the stride-2 convolution halves the frame count
>>> print(output_lengths)  # None, because ilens was not given
NOTE: The input tensor should contain log-mel spectrogram features computed from audio sampled at 16 kHz, as expected by the Whisper model.
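The length bookkeeping of log_mel_spectrogram and whisper_encode can be traced by hand. The sketch below assumes Whisper’s defaults: a hop length of 160 samples, a second encoder convolution with kernel size 3, stride 2 and padding 1 (the first convolution keeps the length unchanged), and 1500 positional embeddings, which is the frame count of a 30 s window. These constants are hard-coded for illustration rather than read from a loaded model, and `encoder_lengths` is a hypothetical name.

```python
SAMPLE_RATE = 16000
HOP_LENGTH = 160          # samples per log-mel frame (assumed Whisper default)
CONV2_KERNEL, CONV2_STRIDE, CONV2_PADDING = 3, 2, 1
MAX_POSITIONS = 1500      # encoder positional embeddings (30 s of audio)


def conv1d_out_len(n: int, kernel: int, stride: int, padding: int) -> int:
    """Standard Conv1d output length with dilation 1."""
    return (n + 2 * padding - kernel) // stride + 1


def encoder_lengths(num_samples: int) -> tuple:
    """Return (log-mel frames, encoder frames) for a waveform length."""
    mel_frames = num_samples // HOP_LENGTH
    enc_frames = conv1d_out_len(mel_frames, CONV2_KERNEL, CONV2_STRIDE, CONV2_PADDING)
    # Frames beyond the positional table are dropped, so lengths are capped.
    return mel_frames, min(enc_frames, MAX_POSITIONS)


for seconds in (1, 30, 60):
    print(seconds, encoder_lengths(seconds * SAMPLE_RATE))
# 1 (100, 50)
# 30 (3000, 1500)
# 60 (6000, 1500)
```

Under these assumptions each encoder frame covers 20 ms of audio, and inputs longer than 30 s are truncated rather than rejected, so long utterances should be segmented before they reach this frontend.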