espnet2.asr.decoder.whisper_decoder.OpenAIWhisperDecoder
class espnet2.asr.decoder.whisper_decoder.OpenAIWhisperDecoder(vocab_size: int, encoder_output_size: int, dropout_rate: float = 0.0, whisper_model: str = 'small', download_dir: str | None = None, load_origin_token_embedding=False)
Bases: AbsDecoder, BatchScorerInterface
OpenAIWhisperDecoder is a Transformer-based decoder for speech-to-text tasks built on OpenAI’s Whisper model. It inherits from AbsDecoder and implements BatchScorerInterface so that it can be used as a scorer during beam search.
The decoder takes encoded audio features and produces next-token scores over a vocabulary of size vocab_size. The dropout rate and the Whisper model size can be configured through the constructor parameters.
Attributes:
- decoders (torch.nn.Module) – The decoder network loaded from the Whisper model.
- load_origin_token_embedding (bool) – Whether to load the original token embeddings when the vocabulary is expanded.
Parameters:
- vocab_size (int) – The size of the vocabulary for the model.
- encoder_output_size (int) – The size of the encoder output features.
- dropout_rate (float, optional) – Dropout rate applied in the decoder. Defaults to 0.0.
- whisper_model (str, optional) – The Whisper model to use (e.g., “small”). Defaults to “small”.
- download_dir (Optional[str], optional) – Directory where the Whisper model is downloaded if it is not already present. Defaults to None.
- load_origin_token_embedding (bool, optional) – If True, load the original token embeddings when the vocabulary is expanded. Defaults to False.
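The vocab_size and load_origin_token_embedding parameters suggest that the pretrained Whisper token-embedding table is replaced when vocab_size differs from Whisper's own vocabulary. The plain-Python sketch below illustrates such a resize under stated assumptions: resize_token_embedding is a hypothetical helper, new rows are drawn from a Gaussian matching the mean and standard deviation of the pretrained weights (an assumed initialization), and when load_origin is True the pretrained rows are copied into the first positions of the enlarged table.

```python
import random
from typing import List

Table = List[List[float]]


def resize_token_embedding(orig: Table, vocab_size: int,
                           load_origin: bool = False, seed: int = 0) -> Table:
    """Hypothetical helper: resize a token-embedding table to `vocab_size` rows."""
    if vocab_size == len(orig):
        # Vocabulary already matches: keep the pretrained table unchanged.
        return [row[:] for row in orig]
    # Assumption: new rows are sampled from a Gaussian matching the mean and
    # standard deviation of the pretrained weights.
    flat = [v for row in orig for v in row]
    mean = sum(flat) / len(flat)
    std = (sum((v - mean) ** 2 for v in flat) / len(flat)) ** 0.5
    rng = random.Random(seed)
    dim = len(orig[0])
    table = [[rng.gauss(mean, std) for _ in range(dim)] for _ in range(vocab_size)]
    if load_origin:
        # Copying the pretrained rows only makes sense when the vocabulary grows.
        assert vocab_size >= len(orig), "load_origin requires an expanded vocabulary"
        for i, row in enumerate(orig):
            table[i] = row[:]
    return table
```

This sketch refuses to combine load_origin with a shrinking vocabulary, since there would be no room for all pretrained rows; without load_origin, every row of a resized table is freshly initialized.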
Returns (forward): A tuple containing:
- x (torch.Tensor): Decoded token scores before softmax, of shape (batch, maxlen_out, token).
- olens (torch.Tensor): Lengths of the output sequences, of shape (batch,).
Return type: Tuple[torch.Tensor, torch.Tensor]
Raises:
- AssertionError – If the specified whisper_model is not available.
- Exception – If the Whisper model fails to load due to installation issues.
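A frequent source of shape errors is a mismatch between the encoder feature size and the width of the chosen Whisper decoder. The sketch below pairs the model-name assertion documented above with a lookup of the decoder width (d_model) of the standard OpenAI Whisper checkpoints. The table and the helper name expected_feature_dim are illustrative and not part of ESPnet; the installed whisper package (whisper.available_models()) is the authoritative list of model names.

```python
# Decoder width (d_model) of the standard OpenAI Whisper checkpoints.
# Hypothetical lookup table for illustration; not part of ESPnet.
WHISPER_WIDTHS = {
    "tiny": 384, "tiny.en": 384,
    "base": 512, "base.en": 512,
    "small": 768, "small.en": 768,
    "medium": 1024, "medium.en": 1024,
    "large-v1": 1280, "large-v2": 1280, "large-v3": 1280, "large": 1280,
    "large-v3-turbo": 1280, "turbo": 1280,
}


def expected_feature_dim(whisper_model: str) -> int:
    """Return the feature size that hs_pad / memory must have for `whisper_model`."""
    # Unknown names fail an assertion, mirroring the documented AssertionError.
    assert whisper_model in WHISPER_WIDTHS, f"unknown Whisper model: {whisper_model!r}"
    return WHISPER_WIDTHS[whisper_model]
```

For the default “small” model this gives 768, which is the last dimension the encoder output should have.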
Examples:
>>> import torch
>>> from espnet2.asr.decoder.whisper_decoder import OpenAIWhisperDecoder
>>> # Initialize the decoder; the "small" model has a hidden size of 768
>>> decoder = OpenAIWhisperDecoder(vocab_size=50000, encoder_output_size=768)
>>> hs_pad = torch.rand(16, 100, 768)  # simulated encoder output
>>> hlens = torch.tensor([100] * 16)  # lengths of the encoder output
>>> ys_in_pad = torch.randint(0, 50000, (16, 50))  # simulated token ids
>>> ys_in_lens = torch.tensor([50] * 16)  # lengths of the token sequences
>>> x, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
NOTE: The Whisper model architecture does not use dropout by default. If a dropout rate is specified, it is applied during training.
batch_score(ys: Tensor, states: List[Any], xs: Tensor) → Tuple[Tensor, List[Any]]
Score new token batch using the decoder.
This method computes the scores for the next token predictions based on the provided prefix tokens and encoder features. It is designed for batch processing, allowing multiple sequences to be scored simultaneously.
- Parameters:
- ys (torch.Tensor) – A tensor of shape (n_batch, ylen) containing the prefix tokens in int64 format.
- states (List[Any]) – Scorer states for the prefix tokens, carried over between decoding steps.
- xs (torch.Tensor) – A tensor of shape (n_batch, xlen, n_feat) representing the encoder features that generate the prefix tokens.
- Returns: A tuple containing:
- A tensor of shape (n_batch, n_vocab) with the batched scores for the next token.
- A list of next states for the prefix tokens (ys).
- Return type: Tuple[torch.Tensor, List[Any]]
Examples:
>>> decoder = OpenAIWhisperDecoder(vocab_size=50000, encoder_output_size=768)
>>> ys = torch.tensor([[1, 2, 3], [1, 2, 4]])  # example prefix tokens
>>> states = [None, None]  # example states, one per hypothesis
>>> xs = torch.rand(2, 100, 768)  # example encoder features ("small": 768)
>>> logp, new_states = decoder.batch_score(ys, states, xs)
>>> print(logp.shape)  # Output: torch.Size([2, 50000])
NOTE: The method currently ignores the cached states for simplicity.
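Because batch_score returns one row of next-token scores per prefix and does not reuse cached states, a decoding loop has to pass the whole, growing prefix on every call. The sketch below shows a greedy loop over such a scorer in plain Python. greedy_decode and toy_scorer are hypothetical stand-ins (ESPnet normally drives batch_score from its beam search), and the scorer here takes only the prefixes, standing in for a call that would also pass states and xs.

```python
import math
from typing import Callable, List

# A batch scorer maps a batch of prefixes to one row of next-token
# log-probabilities per prefix, i.e. an (n_batch, n_vocab) result.
BatchScorer = Callable[[List[List[int]]], List[List[float]]]


def greedy_decode(scorer: BatchScorer, prefixes: List[List[int]],
                  eos: int, max_len: int) -> List[List[int]]:
    """Extend each prefix with its argmax token until EOS or max_len."""
    ys = [list(p) for p in prefixes]
    done = [False] * len(ys)
    while not all(done):
        # No cache: the full prefix is re-scored at every step.
        logp = scorer(ys)
        for i, row in enumerate(logp):
            if done[i]:
                continue
            best = max(range(len(row)), key=row.__getitem__)
            ys[i].append(best)
            if best == eos or len(ys[i]) >= max_len:
                done[i] = True
    return ys


def toy_scorer(batch: List[List[int]], n_vocab: int = 5) -> List[List[float]]:
    """Hypothetical scorer: favours last token + 1, capped at the EOS id 4."""
    rows = []
    for y in batch:
        nxt = min(y[-1] + 1, n_vocab - 1)
        rows.append([math.log(0.9) if v == nxt else math.log(0.025)
                     for v in range(n_vocab)])
    return rows
```

With this toy scorer, greedy_decode(toy_scorer, [[0], [2]], eos=4, max_len=10) extends the two prefixes to [0, 1, 2, 3, 4] and [2, 3, 4]. Re-scoring the full prefix keeps the loop simple at the cost of repeated computation, which is the trade-off the NOTE describes.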
forward(hs_pad: Tensor, hlens: Tensor, ys_in_pad: Tensor, ys_in_lens: Tensor) → Tuple[Tensor, Tensor]
Forward decoder.
This method takes the encoded memory and input token ids to produce the decoded token scores before applying softmax.
- Parameters:
- hs_pad (torch.Tensor) – Encoded memory with shape (batch, maxlen_in, feat) of type float32.
- hlens (torch.Tensor) – Lengths of the encoded memory, shape (batch).
- ys_in_pad (torch.Tensor) – Input token ids, of shape (batch, maxlen_out) and type int64.
- ys_in_lens (torch.Tensor) – Lengths of the input tokens, shape (batch).
- Returns: A tuple containing:
- x (torch.Tensor): Decoded token scores before softmax, of shape (batch, maxlen_out, token).
- olens (torch.Tensor): Lengths of the output tokens, of shape (batch,).
- Return type: Tuple[torch.Tensor, torch.Tensor]
Examples:
>>> hs_pad = torch.rand(32, 10, 768)  # example encoded memory ("small": 768)
>>> hlens = torch.tensor([10] * 32)  # example memory lengths
>>> ys_in_pad = torch.randint(0, 100, (32, 20))  # example token ids
>>> ys_in_lens = torch.tensor([20] * 32)  # example token lengths
>>> decoder = OpenAIWhisperDecoder(vocab_size=100, encoder_output_size=768)
>>> scores, output_lengths = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
NOTE: The method adds the token embedding and the positional embedding, applies dropout, and passes the result through the decoder blocks.
- Raises: ValueError – If the input tensor dimensions do not match the expected shapes.
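The order of operations in forward can be traced with a toy, pure-Python version of the pass. Two details are assumptions rather than statements from this page: dropout is applied between decoder blocks but not after the last one, and the output scores come from a projection tied to the token-embedding matrix (logits[t][v] is the dot product of the final hidden state with row v of the embedding table), as in upstream Whisper's text decoder. The blocks and final_norm callables are hypothetical stand-ins for the real attention blocks and layer norm.

```python
import random
from typing import Callable, List

Matrix = List[List[float]]
Block = Callable[[Matrix, Matrix], Matrix]  # (x, memory) -> x


def dropout(x: Matrix, p: float, training: bool, rng: random.Random) -> Matrix:
    """Inverted dropout: zero each entry with probability p, rescale the rest."""
    if not training or p == 0.0:
        return x
    return [[0.0 if rng.random() < p else v / (1.0 - p) for v in row] for row in x]


def decoder_forward(tokens: List[int], memory: Matrix, tok_emb: Matrix,
                    pos_emb: Matrix, blocks: List[Block],
                    final_norm: Callable[[Matrix], Matrix],
                    p: float = 0.0, training: bool = False, seed: int = 0) -> Matrix:
    """Return (maxlen_out, n_vocab) scores before softmax for one sequence."""
    rng = random.Random(seed)
    # Token embedding plus learned positional embedding for positions 0..T-1.
    x = [[e + q for e, q in zip(tok_emb[t], pos_emb[i])]
         for i, t in enumerate(tokens)]
    x = dropout(x, p, training, rng)
    # Decoder blocks attend to the encoder memory; dropout between blocks.
    for i, block in enumerate(blocks):
        x = block(x, memory)
        if i < len(blocks) - 1:
            x = dropout(x, p, training, rng)
    x = final_norm(x)
    # Tied output projection onto the token-embedding rows.
    return [[sum(a * b for a, b in zip(row, emb)) for emb in tok_emb] for row in x]
```

Tying the output projection to the embedding table means the score matrix always has one column per vocabulary entry, which is why resizing the embedding also changes the size of the token axis in x.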
forward_one_step(tgt: Tensor, tgt_mask: Tensor, memory: Tensor, *, cache: List[Tensor] | None = None) → Tuple[Tensor, List[Tensor]]
Forward one step in the decoding process of the OpenAI Whisper model.
This method computes the output for a single decoding step from the given target tokens and the encoded memory. It adds positional embeddings to the token embeddings and passes the result through the decoder blocks.
- Parameters:
- tgt (torch.Tensor) – Input token ids, of shape (batch, maxlen_out).
- tgt_mask (torch.Tensor) – Input token mask, of shape (batch, maxlen_out). The dtype should be torch.uint8 for PyTorch versions < 1.2 and torch.bool for PyTorch 1.2 and above.
- memory (torch.Tensor) – Encoded memory, of shape (batch, maxlen_in, feat).
- cache (List[torch.Tensor], optional) – Cached output list of shape (batch, max_time_out-1, size). Defaults to None.
- Returns:
- torch.Tensor: Log-probabilities of the next token, of shape (batch, token).
- List[torch.Tensor]: Updated cache; currently None because the cache is not implemented.
- Return type: Tuple[torch.Tensor, List[torch.Tensor]]
NOTE: The cache is not used in this version, for simplicity and correctness.
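A single decoding step only needs the distribution over the token that follows the prefix. The sketch below assumes the step keeps the last position of the decoder output and normalizes it with a log-softmax, which is consistent with the (n_batch, n_vocab) log-probabilities returned by batch_score and the (n_vocab,) log-probabilities returned by score. next_token_logp is a hypothetical name, and plain lists stand in for tensors.

```python
import math
from typing import List


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def next_token_logp(scores: List[List[List[float]]]) -> List[List[float]]:
    """Map (batch, maxlen_out, n_vocab) scores to (batch, n_vocab) log-probs.

    Only the last position of each sequence is kept, because that position
    predicts the token that follows the prefix.
    """
    return [log_softmax(seq[-1]) for seq in scores]
```

Subtracting the maximum before exponentiating keeps the computation finite even for large scores, for example log_softmax([1000.0, 0.0]) yields [0.0, -1000.0] instead of overflowing.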
Examples:
>>> decoder = OpenAIWhisperDecoder(vocab_size=1000,
...                                encoder_output_size=768)
>>> tgt = torch.randint(0, 1000, (32, 10))  # example input tokens
>>> tgt_mask = torch.ones((32, 10), dtype=torch.bool)
>>> memory = torch.rand((32, 20, 768))  # example memory ("small": 768)
>>> output, cache = decoder.forward_one_step(tgt, tgt_mask, memory)
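The tgt_mask parameter expresses causal masking: position i of the prefix may attend only to positions j ≤ i. Upstream Whisper's text decoder encodes this as an additive mask, with 0 on and below the diagonal and negative infinity above it, added to the attention scores before the softmax. The plain-Python sketch below (causal_mask and masked_softmax are illustrative names, not ESPnet functions) shows that construction and its effect on the attention weights.

```python
import math
from typing import List


def causal_mask(n: int) -> List[List[float]]:
    """Additive causal mask: 0.0 where j <= i, -inf above the diagonal."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def masked_softmax(scores: List[List[float]], mask: List[List[float]]) -> List[List[float]]:
    """Row-wise softmax of scores + mask; masked entries get weight 0."""
    weights = []
    for s_row, m_row in zip(scores, mask):
        vals = [s + m for s, m in zip(s_row, m_row)]
        top = max(vals)  # finite, since the diagonal is never masked
        exps = [math.exp(v - top) for v in vals]  # exp(-inf) == 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights
```

An additive mask avoids a separate masking branch: future positions receive a score of negative infinity, so their softmax weight is exactly zero while each row still sums to one.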
score(ys, state, x)
Score the token predictions based on the input sequence and current state.
This method computes the log probabilities of the next token given the previous tokens and the encoded memory from the encoder. It uses the forward_one_step method to perform a single decoding step and returns the resulting log probabilities along with the updated state.
- Parameters:
- ys (torch.Tensor) – A 1-D tensor of shape (ylen,) containing the prefix token IDs, where ylen is the length of the token sequence.
- state (Any) – The current state used for caching previous computations.
- x (torch.Tensor) – A tensor of shape (xlen, feat) representing the encoded memory, where xlen is the length of the encoder output and feat is the feature dimension.
- Returns: A tuple containing:
- logp (torch.Tensor): A tensor of shape (n_vocab,) with the log probabilities of the next token.
- state (Any): The updated state after processing the input.
- Return type: Tuple[torch.Tensor, Any]
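score handles a single hypothesis, while batch_score handles many. One common way to relate the two, assumed here rather than taken from this page, is to add a batch dimension of one, run the batched step, and remove that dimension from the result. In the sketch, score_one and toy_step are hypothetical; the state is returned as None on the assumption, following the NOTEs above about the unused cache, that no cache is kept.

```python
from typing import Any, Callable, List, Tuple

Matrix = List[List[float]]
# Hypothetical batched step: (prefixes, encoder outputs) -> (n_batch, n_vocab).
BatchStep = Callable[[List[List[int]], List[Matrix]], List[List[float]]]


def score_one(step: BatchStep, ys: List[int], state: Any,
              x: Matrix) -> Tuple[List[float], Any]:
    """Score one prefix by wrapping it (and its encoder output) as a batch of one."""
    logp_batch = step([ys], [x])  # add the batch dimension
    # Drop the batch dimension again; no cache is kept, so no new state.
    return logp_batch[0], None


def toy_step(prefixes: List[List[int]], xs: List[Matrix]) -> List[List[float]]:
    """Hypothetical step whose scores depend on prefix length and memory length."""
    return [[-float(len(y)), -float(len(x))] for y, x in zip(prefixes, xs)]
```

Because the single-hypothesis path reuses the batched step, both paths produce identical scores for the same prefix and encoder output.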
Examples:
>>> decoder = OpenAIWhisperDecoder(vocab_size=50000, encoder_output_size=768)
>>> ys = torch.tensor([1, 2, 3])  # example token sequence
>>> state = None  # initial state
>>> x = torch.randn(10, 768)  # example encoder output ("small": 768)
>>> logp, new_state = decoder.score(ys, state, x)
NOTE: ys must contain at least one token, and the last dimension of x must equal the hidden size of the selected Whisper model (for example, 768 for “small”).
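The requirements in the NOTE can be checked before calling score. check_score_inputs is a hypothetical helper that works on plain lists and takes the decoder width as an argument (768 for “small”, per the standard Whisper configurations); it raises ValueError with a readable message instead of letting a shape error surface deep inside the attention layers.

```python
from typing import List, Sequence


def check_score_inputs(ys: Sequence[int], x: List[List[float]], d_model: int) -> None:
    """Raise ValueError if (ys, x) violate the preconditions of score()."""
    if len(ys) == 0:
        # The next token is predicted from the last prefix position,
        # so an empty prefix has nothing to condition on.
        raise ValueError("ys must contain at least one token")
    if len(x) == 0:
        raise ValueError("x must contain at least one encoder frame")
    bad = [i for i, frame in enumerate(x) if len(frame) != d_model]
    if bad:
        raise ValueError(f"frames {bad} do not have {d_model} features")
```

Reporting the offending frame indices makes it easy to spot an encoder whose output size does not match the chosen Whisper model.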