espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt.CPURNNT
class espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt.CPURNNT(minibatch: int, maxT: int, maxU: int, alphabet_size: int, workspace: Tensor, blank: int, fastemit_lambda: float, clamp: float, num_threads: int, batch_first: bool)
Bases: object
Helper class to compute the Transducer Loss on CPU.
This class provides an implementation of the RNNT (Recurrent Neural Network Transducer) loss calculation on CPU. It manages the required workspace and computes the forward and backward variables for the RNNT.
minibatch_
Size of the minibatch b.
maxT_
The maximum possible acoustic sequence length (T).
maxU_
The maximum possible target sequence length (U).
alphabet_size_
The vocabulary dimension (V+1, inclusive of RNNT blank).
workspace
An allocated chunk of memory that will be sliced and reshaped into required blocks used as working memory.
blank_
Index of the RNNT blank token in the vocabulary.
fastemit_lambda_
Float scaling factor for FastEmit regularization.
clamp_
Float value for clamping gradients.
num_threads_
Number of OMP threads to launch.
batch_first
Bool that decides if batch dimension is first or third.
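The idea of slicing one flat workspace into per-example working blocks can be illustrated with a minimal pure-Python sketch. The layout below (one alpha block followed by one beta block per example), the helper names `carve_workspace` and `flat_index`, and the row-major indexing are assumptions made for illustration; they are not taken from the ESPnet implementation.

```python
def carve_workspace(workspace, minibatch, maxT, maxU):
    """Split a flat buffer into per-example (alphas, betas) index ranges.

    Hypothetical layout: [alphas_0 | betas_0 | alphas_1 | betas_1 | ...].
    """
    block = maxT * maxU
    needed = 2 * block * minibatch
    if len(workspace) < needed:
        raise ValueError(f"workspace too small: need {needed}, got {len(workspace)}")
    regions = []
    offset = 0
    for _ in range(minibatch):
        alphas = range(offset, offset + block)
        betas = range(offset + block, offset + 2 * block)
        regions.append((alphas, betas))
        offset += 2 * block
    return regions


def flat_index(t, u, maxU):
    """Row-major position of lattice node (t, u) inside one block."""
    return t * maxU + u


workspace = [0.0] * (2 * 2 * 3 * 4)  # minibatch=2, maxT=3, maxU=4
regions = carve_workspace(workspace, minibatch=2, maxT=3, maxU=4)
alphas_1, betas_1 = regions[1]
# Write alpha(t=2, u=1) of example 1 into the shared buffer.
workspace[alphas_1[flat_index(2, 1, maxU=4)]] = -1.5
```

Under this assumed layout, a buffer that is too small fails early with a `ValueError` rather than silently overlapping blocks.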
############### Examples
>>> import torch
>>> from espnet2.asr.transducer.rnnt_multi_blank.utils.cpu_utils.cpu_rnnt import CPURNNT
>>> cpurnnt = CPURNNT(
... minibatch=32,
... maxT=100,
... maxU=50,
... alphabet_size=30,
... workspace=torch.zeros(10000),
... blank=0,
... fastemit_lambda=0.1,
... clamp=1.0,
... num_threads=4,
... batch_first=True
... )
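>>> # Last dimension equals alphabet_size (V+1 = 30); log_softmax yields valid log probabilities
>>> log_probs = torch.randn(32, 100, 50, 30).log_softmax(dim=-1)
>>> grads = torch.zeros_like(log_probs)
>>> costs = torch.zeros(32)
>>> flat_labels = torch.randint(1, 30, (32, 50)) # Non-blank label indices (blank=0)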
>>> label_lengths = torch.randint(1, 50, (32,))
>>> input_lengths = torch.randint(1, 100, (32,))
>>> status = cpurnnt.cost_and_grad(log_probs, grads, costs, flat_labels,
... label_lengths, input_lengths)
- Returns: Status of the RNNT computation.
- Return type: global_constants.RNNTStatus
- Parameters:
- minibatch – Size of the minibatch b.
- maxT – The maximum possible acoustic sequence length. Represents T in the logprobs tensor.
- maxU – The maximum possible target sequence length. Represents U in the logprobs tensor.
- alphabet_size – The vocabulary dimension V+1 (inclusive of RNNT blank).
- workspace – An allocated chunk of memory that will be sliced off and reshaped into required blocks used as working memory.
- blank – Index of the RNNT blank token in the vocabulary. Generally the first or last token in the vocab.
- fastemit_lambda – Float scaling factor for FastEmit regularization. Refer to FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization.
- clamp – Float value. When set to value >= 0.0, will clamp the gradient to [-clamp, clamp].
- num_threads – Number of OMP threads to launch.
- batch_first – Bool that decides if batch dimension is first or third.
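The effect of `clamp` can be shown with a short pure-Python sketch. The helper name `clamp_gradients` is hypothetical and not part of ESPnet. The sketch also assumes that a value of exactly 0.0 leaves the gradient unchanged, which the parameter description above does not spell out.

```python
def clamp_gradients(grads, clamp):
    """Clip every gradient value to [-clamp, clamp].

    Hypothetical helper for illustration. Assumes clamp == 0.0 means
    "no clamping" and rejects negative values.
    """
    if clamp < 0.0:
        raise ValueError("clamp must be non-negative")
    if clamp == 0.0:
        return list(grads)
    return [max(-clamp, min(clamp, g)) for g in grads]


clamped = clamp_gradients([-3.5, -0.2, 0.0, 0.7, 2.0], clamp=1.0)
print(clamped)
```

Values already inside the interval pass through untouched; only outliers are pulled back to the boundary.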
compute_alphas(log_probs: Tensor, T: int, U: int, alphas: Tensor)
Compute the probability of the forward variable alpha.
This method calculates the forward probabilities (alphas) for the RNNT (Recurrent Neural Network Transducer) given the log probabilities of the outputs. It fills the alphas tensor with the computed values, which are used in the computation of the RNNT loss.
- Parameters:
- log_probs – A flattened tensor of shape [B, T, U, V+1] representing the log probabilities of the model outputs. B is the batch size, T is the length of the acoustic sequence (not padded), U is the length of the target sequence (not padded), and V is the size of the vocabulary.
- T – Length of the acoustic sequence T (not padded).
- U – Length of the target sequence U (not padded).
- alphas – A tensor of shape [B, T, U] that serves as the working space memory for the alpha values.
- Returns: Log-likelihood of the forward variable alpha.
############### Examples
>>> log_probs = torch.randn(1, 5, 3, 4).log_softmax(dim=-1) # Example log probabilities
>>> T = 5 # Length of acoustic sequence
>>> U = 3 # Length of target sequence
>>> alphas = torch.zeros(1, T, U) # Initialize alphas
>>> log_likelihood = cpurnnt.compute_alphas(log_probs, T, U, alphas)
>>> print(log_likelihood) # Loglikelihood of the forward variable
######## NOTE The computation is performed in-place on the alphas tensor, which should be allocated beforehand with the appropriate shape.
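The forward recursion described above can be sketched in plain Python for a single utterance. This is a minimal illustration of the standard transducer alpha recursion, not the ESPnet code: the function names, the nested-list layout `log_probs[t][u][k]`, and the assumption that `U = len(labels) + 1` are all choices made for this sketch.

```python
import math


def log_sum_exp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def forward_alphas(log_probs, labels, blank=0):
    """Fill the alpha lattice and return (alphas, log_likelihood).

    log_probs[t][u][k] is the log-probability of symbol k at lattice
    node (t, u); the lattice has U = len(labels) + 1 columns.
    """
    T, U = len(log_probs), len(labels) + 1
    alphas = [[-math.inf] * U for _ in range(T)]
    alphas[0][0] = 0.0
    for t in range(T):
        for u in range(U):
            if t == 0 and u == 0:
                continue
            no_emit = emit = -math.inf
            if t > 0:  # arrive by emitting blank at (t - 1, u)
                no_emit = alphas[t - 1][u] + log_probs[t - 1][u][blank]
            if u > 0:  # arrive by emitting label u - 1 at (t, u - 1)
                emit = alphas[t][u - 1] + log_probs[t][u - 1][labels[u - 1]]
            alphas[t][u] = log_sum_exp(emit, no_emit)
    # A final blank from the last node terminates every alignment.
    log_likelihood = alphas[T - 1][U - 1] + log_probs[T - 1][U - 1][blank]
    return alphas, log_likelihood


# Uniform distribution over 3 symbols, T = 2 frames, one target label.
uniform = math.log(1.0 / 3.0)
log_probs = [[[uniform] * 3 for _ in range(2)] for _ in range(2)]
alphas, ll = forward_alphas(log_probs, labels=[1])
print(ll)
```

In this toy case there are two alignments of three emissions each, so the likelihood is 2 · (1/3)³ = 2/27.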
compute_betas_and_grads(grad: Tensor, log_probs: Tensor, T: int, U: int, alphas: Tensor, betas: Tensor, labels: Tensor, logll: Tensor)
Compute the backward variable beta and gradients of the activation matrix.
This function calculates the backward variable (beta) and the gradients of the activation matrix with respect to the log-likelihood of the forward variable. The results are used in training the RNNT model by updating the gradients based on the computed values.
- Parameters:
- grad – Working space memory of flattened shape [B, T, U, V+1], which will be updated in place with gradients.
- log_probs – Activation tensor of flattened shape [B, T, U, V+1], which contains the log probabilities for each time step and target label.
- T – Length of the acoustic sequence T (not padded).
- U – Length of the target sequence U (not padded).
- alphas – Working space memory for alpha of shape [B, T, U], which stores the forward variable probabilities.
- betas – Working space memory for beta of shape [B, T, U], which will store the computed backward variable probabilities.
- labels – Ground truth label tensor of shape [B, U] representing the target sequences.
- logll – Log-likelihood of the forward variable, which is used to normalize the gradients.
- Returns: Log-likelihood of the forward variable. The grad tensor is updated in place.
############### Examples
>>> batch_size, max_T, max_U, vocab_size = 1, 5, 3, 4 # Illustrative sizes
>>> T, U = max_T, max_U # Unpadded lengths
>>> grad = torch.zeros((batch_size, max_T, max_U, vocab_size))
>>> log_probs = torch.randn((batch_size, max_T, max_U, vocab_size)).log_softmax(dim=-1)
>>> alphas = torch.zeros((batch_size, max_T, max_U))
>>> betas = torch.zeros((batch_size, max_T, max_U))
>>> labels = torch.randint(1, vocab_size, (batch_size, max_U))
>>> logll = cpurnnt.compute_alphas(log_probs, T, U, alphas) # Forward log-likelihood
>>> loglikelihood = cpurnnt.compute_betas_and_grads(grad, log_probs, T, U, alphas, betas, labels, logll)
######## NOTE This method is intended for internal use within the RNNT training process and should not be called directly in user code.
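The backward pass and the gradient computation can be sketched for a single utterance in plain Python. This is an illustration of the standard transducer backward recursion under stated assumptions, not the ESPnet implementation: the function names are hypothetical, labels are assumed never to equal the blank index, and the gradient is taken as the derivative of the negative log-likelihood with respect to each log-probability entry.

```python
import math


def lse(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf or b == -math.inf:
        return max(a, b)
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def log_softmax(row):
    m = max(row)
    z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - z for x in row]


def alphas_of(lp, labels, blank=0):
    """Forward lattice: alpha[t][u] is the log-probability of reaching (t, u)."""
    T, U = len(lp), len(labels) + 1
    a = [[-math.inf] * U for _ in range(T)]
    a[0][0] = 0.0
    for t in range(T):
        for u in range(U):
            if t > 0:
                a[t][u] = lse(a[t][u], a[t - 1][u] + lp[t - 1][u][blank])
            if u > 0:
                a[t][u] = lse(a[t][u], a[t][u - 1] + lp[t][u - 1][labels[u - 1]])
    return a


def betas_and_grads(lp, labels, blank=0):
    """Return (betas, grads, log_likelihood) for one utterance.

    grads[t][u][k] is d(-log_likelihood) / d(lp[t][u][k]).
    """
    T, U, V = len(lp), len(labels) + 1, len(lp[0][0])
    a = alphas_of(lp, labels, blank)
    b = [[-math.inf] * U for _ in range(T)]
    b[T - 1][U - 1] = lp[T - 1][U - 1][blank]  # final blank ends the path
    for t in reversed(range(T)):
        for u in reversed(range(U)):
            if t < T - 1:
                b[t][u] = lse(b[t][u], b[t + 1][u] + lp[t][u][blank])
            if u < U - 1:
                b[t][u] = lse(b[t][u], b[t][u + 1] + lp[t][u][labels[u]])
    ll = b[0][0]
    grads = [[[0.0] * V for _ in range(U)] for _ in range(T)]
    for t in range(T):
        for u in range(U):
            # Blank moves to (t + 1, u), or terminates at the last node.
            if t < T - 1:
                nxt = b[t + 1][u]
            else:
                nxt = 0.0 if u == U - 1 else -math.inf
            grads[t][u][blank] = -math.exp(a[t][u] + lp[t][u][blank] + nxt - ll)
            if u < U - 1:  # label u moves to (t, u + 1)
                k = labels[u]
                grads[t][u][k] = -math.exp(a[t][u] + lp[t][u][k] + b[t][u + 1] - ll)
    return b, grads, ll


lp = [[log_softmax([0.1 * t, 0.2 * u, 0.3 * (t + u)]) for u in range(3)]
      for t in range(3)]
labels = [1, 2]
betas, grads, ll = betas_and_grads(lp, labels)
a = alphas_of(lp, labels)
forward_ll = a[2][2] + lp[2][2][0]
print(abs(forward_ll - ll) < 1e-9)  # forward and backward likelihoods agree
```

Comparing the forward and backward log-likelihoods, as the last line does, is a cheap consistency check on both recursions.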
cost_and_grad(log_probs: Tensor, grads: Tensor, costs: Tensor, flat_labels: Tensor, label_lengths: Tensor, input_lengths: Tensor) → RNNTStatus
Computes the cost and gradients for the RNNT loss on CPU.
This function calculates the RNNT loss for a batch of sequences, storing the computed gradients in the provided gradient tensor. The loss is computed using the forward and backward probabilities (alphas and betas) and is scaled according to the FastEmit regularization factor. The function iterates over each example in the minibatch, calling the cost_and_grad_kernel to perform the necessary calculations.
- Parameters:
- log_probs (torch.Tensor) – Log probabilities tensor of shape [B, T, U, V+1], where B is the batch size, T is the length of the acoustic sequence, U is the length of the target sequence, and V is the vocabulary size (excluding blank).
- grads (torch.Tensor) – Tensor to store the computed gradients of shape [B, T, U, V+1].
- costs (torch.Tensor) – Tensor to store the computed costs for each example in the minibatch of shape [B].
- flat_labels (torch.Tensor) – Ground truth labels in a flattened tensor of shape [B, U].
- label_lengths (torch.Tensor) – Lengths of each label sequence in the minibatch of shape [B].
- input_lengths (torch.Tensor) – Lengths of each input sequence in the minibatch of shape [B].
- Returns: The status of the RNNT computation, indicating success or failure.
- Return type: global_constants.RNNTStatus
############### Examples
>>> log_probs = torch.randn((2, 5, 4, 10)).log_softmax(dim=-1) # Example log probs
>>> grads = torch.zeros_like(log_probs) # Initialize grads
>>> costs = torch.zeros(2) # Initialize costs
>>> flat_labels = torch.randint(0, 9, (2, 3)) # Random labels
>>> label_lengths = torch.tensor([3, 3]) # Lengths of labels
>>> input_lengths = torch.tensor([5, 5]) # Lengths of inputs
>>> rnnt = CPURNNT(...) # Initialize CPURNNT
>>> status = rnnt.cost_and_grad(log_probs, grads, costs,
... flat_labels, label_lengths,
... input_lengths)
######## NOTE Ensure that the input tensors are properly shaped and contain valid data before calling this function.
- Raises:
- ValueError – If the dimensions of the input tensors do not match the expected shapes or if any tensor contains invalid values.
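The per-example loop over a padded minibatch can be sketched as follows. This is a pure-Python illustration, not the ESPnet code. It assumes that each example uses a lattice of `input_lengths[mb]` by `label_lengths[mb] + 1` nodes and that FastEmit scales the log-likelihood by `(1 + fastemit_lambda)`; both are assumptions of this sketch, and the function names are hypothetical.

```python
import math


def lse(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf or b == -math.inf:
        return max(a, b)
    return max(a, b) + math.log1p(math.exp(-abs(a - b)))


def example_ll(lp, labels, T, U, blank):
    """Forward log-likelihood using only the unpadded T x U region."""
    alpha = [[-math.inf] * U for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U):
            if t > 0:
                alpha[t][u] = lse(alpha[t][u], alpha[t - 1][u] + lp[t - 1][u][blank])
            if u > 0:
                alpha[t][u] = lse(alpha[t][u], alpha[t][u - 1] + lp[t][u - 1][labels[u - 1]])
    return alpha[T - 1][U - 1] + lp[T - 1][U - 1][blank]


def batch_costs(log_probs, labels, label_lengths, input_lengths,
                blank=0, fastemit_lambda=0.0):
    """Return one cost per example, ignoring padded frames and labels."""
    costs = []
    for mb in range(len(log_probs)):
        T = input_lengths[mb]
        U = label_lengths[mb] + 1  # assumed: one more column than labels
        ll = example_ll(log_probs[mb], labels[mb], T, U, blank)
        # Assumption: FastEmit scales the log-likelihood by (1 + lambda).
        costs.append(-(1.0 + fastemit_lambda) * ll)
    return costs


# Two examples padded to maxT = 3, maxU = 3, uniform over 4 symbols.
quarter = math.log(0.25)
log_probs = [[[[quarter] * 4 for _ in range(3)] for _ in range(3)]
             for _ in range(2)]
labels = [[1, 2], [3, 0]]  # the second row is padded with 0
costs = batch_costs(log_probs, labels, label_lengths=[2, 1], input_lengths=[3, 2])
print(costs)
```

Because only the unpadded region is visited, the padding values in `labels` and `log_probs` never influence the result.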
cost_and_grad_kernel(log_probs: Tensor, grad: Tensor, labels: Tensor, mb: int, T: int, U: int, bytes_used: int)
Computes the cost and gradients for the RNNT loss on the CPU.
This method utilizes the log probabilities of the model outputs to compute the forward and backward variables (alphas and betas) and subsequently calculates the gradients for the RNNT loss. It also manages the necessary workspace memory for the computations.
- Parameters:
- log_probs (torch.Tensor) – A tensor containing the log probabilities of shape [B, T, U, V+1], where B is the batch size, T is the length of the acoustic sequence, U is the length of the target sequence, and V is the vocabulary size.
- grad (torch.Tensor) – A tensor to store the computed gradients of shape [B, T, U, V+1].
- labels (torch.Tensor) – A tensor containing the ground truth labels of shape [B, U].
- mb (int) – The current minibatch index.
- T (int) – The length of the acoustic sequence (not padded).
- U (int) – The length of the target sequence (not padded).
- bytes_used (int) – The number of bytes currently used in the workspace memory.
- Returns: The negative log likelihood of the RNNT loss for the given minibatch.
- Return type: float
############### Examples
>>> log_probs = torch.randn(2, 10, 5, 6).log_softmax(dim=-1) # Example log probabilities
>>> grad = torch.zeros_like(log_probs)
>>> labels = torch.randint(1, 6, (2, 5)) # Example non-blank labels (blank=0)
>>> cpurnnt = CPURNNT(minibatch=2, maxT=10, maxU=5, alphabet_size=6,
... workspace=torch.zeros(1024), blank=0,
... fastemit_lambda=0.5, clamp=1.0,
... num_threads=1, batch_first=True)
>>> loss = cpurnnt.cost_and_grad_kernel(log_probs, grad, labels, 0, 10, 5, 0)
>>> print(loss) # Output the loss value
######## NOTE This method is designed for use in the context of the RNNT loss calculation and assumes that the input tensors are correctly shaped and pre-allocated.
- Raises:
- ValueError – If the input tensors have incorrect dimensions or incompatible shapes.
score_forward(log_probs: Tensor, costs: Tensor, flat_labels: Tensor, label_lengths: Tensor, input_lengths: Tensor)
Computes the forward score for a batch of input sequences using the RNNT loss.
This method calculates the log likelihood of the forward variable (alpha) for each sequence in the minibatch based on the provided log probabilities and labels. The results are stored in the costs tensor.
- Parameters:
- log_probs (torch.Tensor) – A tensor of shape [B, T, U, V+1] containing the log probabilities for each token in the vocabulary at each time step, where B is the batch size, T is the length of the acoustic sequence, U is the length of the target sequence, and V is the size of the vocabulary (excluding the blank token).
- costs (torch.Tensor) – A tensor of shape [B] where the computed negative log likelihoods will be stored for each sequence in the minibatch.
- flat_labels (torch.Tensor) – A tensor containing the ground truth labels for each sequence in the minibatch, padded to the maximum target length.
- label_lengths (torch.Tensor) – A tensor containing the actual lengths of the labels for each sequence in the minibatch.
- input_lengths (torch.Tensor) – A tensor containing the lengths of the input sequences for each sequence in the minibatch.
- Returns: A status value indicating the success or failure of the computation.
- Return type: global_constants.RNNTStatus
############### Examples
>>> log_probs = torch.randn(2, 5, 4, 3).log_softmax(dim=-1) # Example log probabilities
>>> costs = torch.zeros(2) # Initialize costs tensor
>>> flat_labels = torch.tensor([[1, 2], [2, 1]]) # Example labels, each below the vocabulary dimension of 3
>>> label_lengths = torch.tensor([2, 2]) # Lengths of labels
>>> input_lengths = torch.tensor([5, 5]) # Lengths of inputs
>>> rnnt.score_forward(log_probs, costs, flat_labels,
... label_lengths, input_lengths)
>>> print(costs) # Should contain the negative log likelihoods