espnet2.enh.loss.wrappers.mixit_solver.MixITSolver
class espnet2.enh.loss.wrappers.mixit_solver.MixITSolver(criterion: AbsEnhLoss, weight: float = 1.0)
Bases: AbsLossWrapper
MixITSolver is a Mixture Invariant Training (MixIT) solver. It extends the AbsLossWrapper class and can be weighted against other losses for multi-task learning in speech enhancement.
This solver computes the loss by searching over all candidate assignments of the estimated sources to the reference sources (called permutations on this page) and returning the minimum loss. It works with both real and complex tensors.
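The search space described above can be illustrated with a small standard-library sketch. It assumes the usual MixIT formulation, in which every estimate is assigned to exactly one reference; the helper name `enumerate_mixing_matrices` is hypothetical and is not part of ESPnet.

```python
import itertools


def enumerate_mixing_matrices(num_ref, num_est):
    """Return every binary matrix that assigns each estimate to one reference.

    Each matrix has num_ref rows and num_est columns. Column j holds a
    single 1, in the row of the reference that estimate j is assigned to.
    """
    matrices = []
    for assignment in itertools.product(range(num_ref), repeat=num_est):
        matrix = [
            [1 if assignment[j] == r else 0 for j in range(num_est)]
            for r in range(num_ref)
        ]
        matrices.append(matrix)
    return matrices


matrices = enumerate_mixing_matrices(num_ref=2, num_est=4)
print(len(matrices))  # 2 ** 4 = 16 candidate assignments
print(matrices[1])    # assignment (0, 0, 0, 1)
```

The number of candidates grows as num_ref ** num_est, so an exhaustive search of this kind is only practical for a small number of estimates.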
criterion
An instance of AbsEnhLoss used to compute the loss.
Type: AbsEnhLoss
weight
Weight (between 0 and 1) of the current loss for multi-task learning.
Type: float
Parameters:
- criterion (AbsEnhLoss) – An instance of AbsEnhLoss.
- weight (float) – Weight (between 0 and 1) of the current loss for multi-task learning.
Returns:
- loss (torch.Tensor): Minimum loss with the best permutation.
- stats (Dict[str, torch.Tensor]): A dictionary for collecting training status.
- others (Dict[str, torch.Tensor]): The permutation order selected by this solver, stored under the key perm.
Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
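The weight attribute described above is meant for multi-task setups. A minimal sketch of one way such weights could be applied, assuming the training code combines the individual losses as a weighted sum (an assumption; this page does not show how the weight is consumed). The function `combine_weighted_losses` is hypothetical.

```python
def combine_weighted_losses(losses_and_weights):
    """Weighted sum of (loss, weight) pairs; an assumed combination rule."""
    return sum(weight * loss for loss, weight in losses_and_weights)


# Two hypothetical loss values with weights 0.5 and 0.25.
total = combine_weighted_losses([(2.0, 0.5), (4.0, 0.25)])
print(total)
```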
####### Examples
>>> import torch
>>> from espnet2.enh.loss.wrappers.mixit_solver import MixITSolver
>>> criterion = SomeEnhLoss()  # placeholder: replace with an actual AbsEnhLoss instance
>>> mixit_solver = MixITSolver(criterion, weight=0.5)
>>> ref = [torch.randn(2, 10), torch.randn(2, 10)]  # two references
>>> inf = [torch.randn(2, 10), torch.randn(2, 10),
...        torch.randn(2, 10), torch.randn(2, 10)]  # four estimates
>>> loss, stats, others = mixit_solver(ref, inf)
NOTE
This class requires that the input tensors are either all real or all complex. The solver selects the permutation of the estimated sources that minimizes the loss against the reference sources.
forward(ref: List[Tensor] | List[ComplexTensor], inf: List[Tensor] | List[ComplexTensor], others: Dict = {})
MixIT solver for calculating the minimum loss with the best permutation.
This method computes the loss for a mixture of inputs using the MixIT approach. It evaluates all possible permutations of the reference and estimated tensors to find the one that yields the minimum loss. The function also returns statistics for training status and the permutation order used in the loss calculation.
- Parameters:
- ref (Union[List[torch.Tensor], List[ComplexTensor]]) – A list of n_spk reference tensors, each of shape (batch, …).
- inf (Union[List[torch.Tensor], List[ComplexTensor]]) – A list of n_est estimated tensors, each of shape (batch, …).
- others (Dict, optional) – A dictionary for any additional parameters. Defaults to an empty dictionary.
- Returns:
- loss (torch.Tensor): The minimum loss calculated with the best permutation.
- stats (Dict): A dictionary containing training status metrics.
- others (Dict): A dictionary containing the permutation order selected by this solver, under the key perm.
- Return type: Tuple[torch.Tensor, Dict, Dict]
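The minimum-loss search described for forward can be sketched on plain Python lists. This is an illustration under stated assumptions, not the ESPnet implementation: it handles a single example rather than a batch, uses a mean squared error as a stand-in criterion, and the names `mse` and `toy_mixit_loss` are hypothetical.

```python
import itertools


def mse(a, b):
    """Mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def toy_mixit_loss(ref, inf, criterion=mse):
    """Return (minimum loss, best assignment) over all estimate-to-reference assignments."""
    num_ref, num_est = len(ref), len(inf)
    length = len(ref[0])
    best_loss, best_assignment = None, None
    for assignment in itertools.product(range(num_ref), repeat=num_est):
        # Remix: sum the estimates assigned to each reference.
        remixed = [[0.0] * length for _ in range(num_ref)]
        for est_idx, ref_idx in enumerate(assignment):
            for t in range(length):
                remixed[ref_idx][t] += inf[est_idx][t]
        # Average the criterion over the references.
        loss = sum(criterion(r, m) for r, m in zip(ref, remixed)) / num_ref
        if best_loss is None or loss < best_loss:
            best_loss, best_assignment = loss, assignment
    return best_loss, best_assignment


# Four estimates; ref[0] = inf[0] + inf[3] and ref[1] = inf[1] + inf[2].
inf = [[1.0, 0.0], [0.0, 2.0], [3.0, 0.0], [0.0, 4.0]]
ref = [[1.0, 4.0], [3.0, 2.0]]
loss, assignment = toy_mixit_loss(ref, inf)
print(loss, assignment)
```

In this toy case the references are exact sums of the estimates, so the best assignment reproduces them and the loss is zero. The tuple `assignment` gives, for each estimate, the index of the reference it was added to.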
####### Examples
>>> ref = [torch.rand(2, 3), torch.rand(2, 3)] # Two reference tensors
>>> inf = [torch.rand(2, 3), torch.rand(2, 3),
... torch.rand(2, 3), torch.rand(2, 3)] # Four estimates
>>> loss, stats, others = mixit_solver.forward(ref, inf)
>>> print(loss) # Outputs the minimum loss
>>> print(others['perm']) # Outputs the permutation order
NOTE
Ensure that the input tensors are either all complex or all real tensors for proper computation. The function asserts that the types of ref and inf match.
- Raises:
- AssertionError – If the number of estimated tensors is not double the number of reference tensors, or if the input types do not match.
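The two conditions listed under Raises can be checked before calling the solver. The sketch below is a hypothetical helper, `check_mixit_inputs`, which is not part of ESPnet. It compares Python types only, so Python floats and complex numbers are used here as stand-ins for real and complex tensors.

```python
def check_mixit_inputs(ref, inf):
    """Hypothetical pre-check mirroring the conditions listed under Raises."""
    # Condition 1: twice as many estimates as references.
    assert len(inf) == 2 * len(ref), (
        f"expected {2 * len(ref)} estimates for {len(ref)} references, got {len(inf)}"
    )
    # Condition 2: every input has the same Python type.
    kinds = {type(x) for x in list(ref) + list(inf)}
    assert len(kinds) == 1, f"mixed input types: {sorted(k.__name__ for k in kinds)}"


# Stand-in values: floats play the role of real tensors.
check_mixit_inputs([1.0, 2.0], [0.1, 0.2, 0.3, 0.4])  # passes silently

try:
    check_mixit_inputs([1.0, 2.0], [0.1, 0.2, 0.3])  # only three estimates
except AssertionError as err:
    print("rejected:", err)
```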
property name