espnet2.asr.pit_espnet_model.PITLossWrapper
class espnet2.asr.pit_espnet_model.PITLossWrapper(criterion_fn: Callable, num_ref: int)
Bases: AbsLossWrapper
Wrapper for Permutation Invariant Training (PIT) loss.
This class wraps a given loss function to compute the permutation invariant loss in multi-reference scenarios. It takes multiple inferences and references, evaluates the loss for every permutation that pairs inferences with references, and returns the minimum loss along with the permutation that achieves it.
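The core idea can be sketched for a single utterance in plain Python. This illustrates PIT in general and is not ESPnet's implementation. `pit_single` and `mse` are hypothetical names, and the convention that `perm[i] = j` pairs reference i with inference j is an assumption made for this sketch.

```python
from itertools import permutations
from typing import Callable, Optional, Sequence, Tuple


def pit_single(
    infs: Sequence[Sequence[float]],
    refs: Sequence[Sequence[float]],
    pair_loss: Callable[[Sequence[float], Sequence[float]], float],
) -> Tuple[float, Optional[Tuple[int, ...]]]:
    """Hypothetical helper: minimum average pair loss and the permutation achieving it.

    Convention assumed here: perm[i] = j pairs reference i with inference j.
    """
    assert len(infs) == len(refs), "PIT needs as many inferences as references"
    best_loss, best_perm = float("inf"), None
    for perm in permutations(range(len(refs))):
        # Average the loss over all (reference i, inference perm[i]) pairs
        loss = sum(pair_loss(infs[j], refs[i]) for i, j in enumerate(perm)) / len(perm)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Hypothetical pair loss: mean squared error between two equal-length signals."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# The inferences come out in swapped order relative to the references
refs = [[1.0, 1.0], [5.0, 5.0]]
infs = [[5.0, 5.0], [1.0, 1.0]]
loss, perm = pit_single(infs, refs, mse)
print(loss, perm)  # 0.0 (1, 0)
```

With the identity pairing every pair would cost 16.0, so the search settles on the swapped permutation `(1, 0)` with zero loss. The number of candidates grows as `num_ref!`, which is acceptable for the small speaker counts PIT is usually applied to.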
Attributes:
- criterion_fn (Callable) – The loss function used to compute the loss for each inference-reference pair.
- num_ref (int) – The number of reference signals.
Parameters:
- criterion_fn (Callable) – A callable loss function that takes inference and reference tensors along with their lengths.
- num_ref (int) – The number of reference signals for the PIT loss computation.
Returns: A tuple containing the mean of the per-utterance minimum losses over the batch and the optimal permutation indices, of shape (batch_size, num_ref).
Return type: Tuple[torch.Tensor, torch.Tensor]
Raises:
- AssertionError – If the dimensions of the input tensors do not match the expected shapes.
####### Examples
>>> import torch
>>> from espnet2.asr.pit_espnet_model import PITLossWrapper
>>> # Hypothetical per-utterance criterion (ignores lengths); returns shape (batch,)
>>> def criterion(inf, inf_lens, ref, ref_lens):
...     return ((inf - ref) ** 2).mean(dim=-1)
>>> pit_loss_wrapper = PITLossWrapper(criterion_fn=criterion, num_ref=2)
>>> inf = torch.randn(5, 2, 10)  # (batch, num_inf, features)
>>> inf_lens = torch.full((5, 2), 10)  # (batch, num_inf)
>>> ref = torch.randn(5, 2, 10)  # (batch, num_ref, features)
>>> ref_lens = torch.full((5, 2), 10)  # (batch, num_ref)
>>> loss, opt_perm = pit_loss_wrapper(inf, inf_lens, ref, ref_lens)
>>> print(loss, opt_perm)
NOTE
The criterion function must accept the inference and reference tensors together with their lengths, as passed by the forward method. The number of inferences must equal the number of references; otherwise the shape assertion fails.
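The NOTE leaves the criterion's contract implicit. Because the wrapper returns one optimal permutation per utterance, a plausible contract, assumed here rather than taken from the ESPnet source, is that the criterion receives one inference slot and one reference slot for the whole batch, together with their lengths, and returns one loss value per utterance. `masked_mse` is a hypothetical example written with plain Python lists instead of tensors, and it uses the lengths to skip padding frames.

```python
from typing import List, Sequence


def masked_mse(
    inf: Sequence[Sequence[float]],  # (batch, frames) for one inference slot
    inf_lens: Sequence[int],         # (batch,)
    ref: Sequence[Sequence[float]],  # (batch, frames) for one reference slot
    ref_lens: Sequence[int],         # (batch,)
) -> List[float]:
    """Hypothetical criterion: one loss value per utterance, ignoring padding."""
    losses = []
    for x, xl, y, yl in zip(inf, inf_lens, ref, ref_lens):
        n = min(xl, yl)  # only compare frames that are valid in both signals
        losses.append(sum((a - b) ** 2 for a, b in zip(x[:n], y[:n])) / n)
    return losses


# The trailing 99.0 in the first utterance is padding and is excluded by the lengths
inf = [[1.0, 2.0, 99.0], [0.0, 0.0, 0.0]]
ref = [[1.0, 4.0, 0.0], [1.0, 1.0, 1.0]]
print(masked_mse(inf, [2, 3], ref, [2, 3]))  # [2.0, 1.0]
```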
forward(inf: Tensor, inf_lens: Tensor, ref: Tensor, ref_lens: Tensor, others: Dict | None = None)
Computes the Permutation Invariant Training (PIT) loss using a provided criterion function for multiple references.
This method takes the inference and reference tensors, along with their lengths, and computes for each utterance the permutation of the references that minimizes the loss. PIT is useful when there is no fixed correspondence between model outputs and references, as in speech separation or multi-speaker ASR, where the order in which the speakers are output is arbitrary.
- Parameters:
- inf (torch.Tensor) – Inference tensor of shape (batch, num_inf, …).
- inf_lens (torch.Tensor) – Lengths of the inference tensors, shape (batch, num_inf).
- ref (torch.Tensor) – Reference tensor of shape (batch, num_ref, …).
- ref_lens (torch.Tensor) – Lengths of the reference tensors, shape (batch, num_ref).
- others (Dict, optional) – Additional keyword arguments that may be required by the criterion function.
- Returns: A tuple containing:
- The mean of the minimum losses across the batch.
- The optimal permutation of the references as a tensor of shape (batch_size, num_ref).
- Return type: Tuple[torch.Tensor, torch.Tensor]
- Raises:
- AssertionError – If the number of references does not match the shapes of the input tensors.
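The two return values can be illustrated independently of the criterion. Assuming the per-permutation losses have already been collected into a (batch, num_perm) table, with permutations enumerated in `itertools.permutations` order (an assumption of this sketch), the plain-Python helper below shows how the mean of the per-utterance minima and a (batch_size, num_ref) permutation table follow from it. `reduce_pit_losses` is a hypothetical name, and its assertion only mirrors the spirit of the shape check listed under Raises, not its exact condition.

```python
from itertools import permutations
from typing import List, Sequence, Tuple


def reduce_pit_losses(
    losses: Sequence[Sequence[float]],  # (batch, num_perm), one row per utterance
    num_ref: int,
) -> Tuple[float, List[Tuple[int, ...]]]:
    """Hypothetical helper: best permutation per utterance, mean of the minima over the batch."""
    all_perms = list(permutations(range(num_ref)))
    assert all(len(row) == len(all_perms) for row in losses), "expected num_ref! losses per utterance"
    min_losses, opt_perm = [], []
    for row in losses:
        best = min(range(len(row)), key=lambda k: row[k])  # index of the smallest loss
        min_losses.append(row[best])
        opt_perm.append(all_perms[best])
    return sum(min_losses) / len(min_losses), opt_perm


# Two utterances with num_ref = 2, so the candidates are (0, 1) and (1, 0)
loss, opt_perm = reduce_pit_losses([[0.75, 0.25], [0.5, 1.0]], num_ref=2)
print(loss, opt_perm)  # 0.375 [(1, 0), (0, 1)]
```

Averaging the minima rather than summing them keeps the loss scale independent of the batch size.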
####### Examples
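>>> import torch
>>> from espnet2.asr.pit_espnet_model import PITLossWrapper
>>> def some_loss_function(inf, inf_lens, ref, ref_lens):
...     # Hypothetical per-utterance L1 loss; returns shape (batch,)
...     return (inf - ref).abs().mean(dim=-1)
>>> inf = torch.rand(2, 2, 10)  # (batch, num_inf, features); num_inf == num_ref
>>> inf_lens = torch.tensor([[10, 9], [10, 10]])  # (batch, num_inf)
>>> ref = torch.rand(2, 2, 10)  # (batch, num_ref, features)
>>> ref_lens = torch.tensor([[10, 9], [10, 10]])  # (batch, num_ref)
>>> pit_loss_wrapper = PITLossWrapper(criterion_fn=some_loss_function,
...                                   num_ref=2)
>>> loss, optimal_perm = pit_loss_wrapper(inf, inf_lens, ref, ref_lens)
>>> print(loss, optimal_perm)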
NOTE
Ensure that the number of references (num_ref) specified during initialization matches the second dimension of the inference and reference tensors and of their length tensors.
classmethod permutate(perm, *args)
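Only the signature of `permutate` is listed here. One plausible use of a method with this signature is to reorder per-utterance data with the optimal permutation returned by `forward`, so that slot i of every utterance lines up with reference i. The plain-Python sketch below illustrates that gather under this assumption. `reorder_by_perm` is a hypothetical name, and the exact semantics of `permutate` (including how it handles several `*args` at once) should be checked against the ESPnet source.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def reorder_by_perm(perm: Sequence[Sequence[int]], items: Sequence[Sequence[T]]) -> List[List[T]]:
    """Hypothetical helper: out[b][i] = items[b][perm[b][i]].

    perm:  (batch, num_ref) permutation table
    items: (batch, num_ref) per-utterance items, e.g. inferences or their lengths
    """
    assert len(perm) == len(items), "batch sizes must match"
    out = []
    for p, row in zip(perm, items):
        assert len(p) == len(row), "each row must hold num_ref items"
        out.append([row[j] for j in p])  # slot i receives item p[i]
    return out


perm = [[1, 0], [0, 1]]
infs = [["spk_b", "spk_a"], ["spk_a", "spk_b"]]
print(reorder_by_perm(perm, infs))  # [['spk_a', 'spk_b'], ['spk_a', 'spk_b']]
```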