espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt.CpuRNNT_metadata
class espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt.CpuRNNT_metadata(T: int, U: int, workspace: Tensor, bytes_used: int, blank: int, labels: Tensor, log_probs: Tensor, idx: CpuRNNT_index)
Bases: object
Metadata for CPU-based RNNT loss calculation.
This class holds the working space memory and initializes the log probabilities for the RNNT model during loss computation.
Args:
- T – Length of the acoustic sequence (without padding).
- U – Length of the target sequence (without padding).
- workspace – Working space memory for the CPU.
- bytes_used – Number of bytes currently used for indexing the working space memory. Generally starts at 0.
- blank – Index of the blank token in the vocabulary.
- labels – Ground truth padded labels matrix of shape [B, U].
- log_probs – Log probabilities / activation matrix of flattened shape [B, T, U, V+1].
- idx – An instance of CpuRNNT_index for indexing purposes.
Attributes:
- alphas – Memory for the forward variable (alpha) calculations.
- betas – Memory for the backward variable (beta) calculations.
- log_probs2 – Memory for storing log probabilities of blank and label tokens.
Examples:
```python
import torch

from espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt import (
    CpuRNNT_index,
    CpuRNNT_metadata,
)

T = 5  # Length of acoustic sequence
U = 3  # Length of target sequence
workspace = torch.zeros(100)  # Example workspace tensor
bytes_used = 0
blank = 0  # Index of blank token
labels = torch.tensor([[1, 2, 3]])  # Example labels
log_probs = torch.zeros((1, T, U, 4))  # Example log_probs
idx = CpuRNNT_index(U, U, 1, 4, True)
rnnt_metadata = CpuRNNT_metadata(
    T, U, workspace, bytes_used, blank, labels, log_probs, idx
)
```
Note: The memory allocation for alphas, betas, and log_probs2 is done using slices of the provided workspace tensor. Ensure that the workspace has sufficient size to accommodate these tensors.
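The slicing described in the note can be sketched in plain PyTorch. This is a minimal illustration, not the class's actual code: the helper name `split_workspace` is hypothetical, and the sizes (T*U elements each for alphas and betas, 2*T*U for log_probs2, so 4*T*U in total) are an assumption based on log_probs2 holding one blank and one label entry per (t, u) position. The offset is counted in elements here.

```python
import torch


def split_workspace(workspace: torch.Tensor, T: int, U: int, bytes_used: int = 0):
    """Carve alphas, betas and log_probs2 out of one flat workspace.

    Hypothetical helper; the region sizes are assumptions, not confirmed
    by the class documentation.
    """
    needed = 4 * T * U  # assumed: T*U + T*U + 2*T*U elements
    if workspace.numel() - bytes_used < needed:
        raise ValueError(f"workspace too small: need {needed} free elements")

    alphas = workspace[bytes_used : bytes_used + T * U]
    bytes_used += T * U
    betas = workspace[bytes_used : bytes_used + T * U]
    bytes_used += T * U
    log_probs2 = workspace[bytes_used : bytes_used + 2 * T * U]
    bytes_used += 2 * T * U
    return alphas, betas, log_probs2, bytes_used


workspace = torch.zeros(100)
alphas, betas, log_probs2, used = split_workspace(workspace, T=5, U=3)

# Basic slices are views, so writing to alphas writes into the workspace.
alphas[0] = 1.0
```

Because the slices share storage with the workspace, no extra memory is allocated, which is why an undersized workspace would silently yield truncated slices unless the size is checked up front.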
Todo: Consider adding error handling for invalid input shapes or types.
Metadata for CPU-based RNNT loss calculation. Holds the working space memory.
- Parameters:
- T – Length of the acoustic sequence (without padding).
- U – Length of the target sequence (without padding).
- workspace – Working space memory for the CPU.
- bytes_used – Number of bytes currently used for indexing the working space memory. Generally starts at 0.
- blank – Index of the blank token in the vocabulary.
- labels – Ground truth padded labels matrix of shape [B, U].
- log_probs – Log probs / activation matrix of flattened shape [B, T, U, V+1].
- idx – An instance of CpuRNNT_index for indexing purposes.
setup_probs(T: int, U: int, labels: Tensor, blank: int, log_probs: Tensor, idx: CpuRNNT_index)
Initializes the log probabilities for blank and label tokens.
This method sets up the log probabilities memory for the blank and label tokens by populating the log_probs2 tensor. The log probabilities are extracted from the log_probs tensor based on the provided indices and the dimensions of the target and acoustic sequences.
- Parameters:
- T – Length of the acoustic sequence (not padded).
- U – Length of the target sequence (not padded).
- labels – Tensor containing the ground truth labels, shape [B, U].
- blank – Index of the blank token in the vocabulary.
- log_probs – Log probabilities tensor of shape [B, T, U, V+1].
- idx – An instance of CpuRNNT_index for indexing purposes.
Examples
>>> labels = torch.tensor([[1, 2, 3], [1, 0, 2]])
>>> log_probs = torch.randn(2, 5, 4, 5).log_softmax(dim=-1)  # Normalized log probs
>>> idx = CpuRNNT_index(4, 5, 2, 5, True)
>>> rnnt_metadata = CpuRNNT_metadata(5, 4, torch.empty(100), 0, 1, labels, log_probs, idx)
>>> rnnt_metadata.setup_probs(5, 4, labels, 1, log_probs, idx)
NOTE
The first blank token does not have an associated label.
- Raises: IndexError – If indices are out of bounds.
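To make the population of log_probs2 concrete, here is a minimal sketch for a single sample. It is not the library's implementation. The function name `fill_blank_and_label_probs` is hypothetical, it takes an unflattened [T, U, V+1] tensor instead of going through CpuRNNT_index, and the layout is an assumption: two slots per (t, u) position, blank first and label second, with the label slot left empty at the one position per time step that has no label to emit (assumed here to be the last u).

```python
import torch


def fill_blank_and_label_probs(
    log_probs: torch.Tensor, labels: torch.Tensor, blank: int
) -> torch.Tensor:
    """Pack blank and label log probs into a flat buffer (hypothetical helper).

    log_probs: [T, U, V+1] for one sample; labels: U - 1 target ids.
    """
    T, U, _ = log_probs.shape
    packed = torch.zeros(T * U * 2)
    for t in range(T):
        for u in range(U):
            offset = (t * U + u) * 2  # assumed layout: [blank, label] pairs
            packed[offset] = log_probs[t, u, blank]
            if u < U - 1:  # the last position has no label to emit
                packed[offset + 1] = log_probs[t, u, int(labels[u])]
    return packed


# Tiny deterministic input: T=2, U=3, V+1=4, so entry [t, u, v] = (t*3 + u)*4 + v.
log_probs = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
labels = torch.tensor([1, 2])
packed = fill_blank_and_label_probs(log_probs, labels, blank=0)
```

Keeping only two entries per position shrinks the data touched by the forward and backward recursions from (V+1) values per cell to 2, since those recursions only ever read the blank and the next-label probabilities.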