espnet2.enh.loss.wrappers.pit_solver.PITSolver
class espnet2.enh.loss.wrappers.pit_solver.PITSolver(criterion: AbsEnhLoss, weight=1.0, independent_perm=True, flexible_numspk=False)
Bases: AbsLossWrapper
Permutation Invariant Training Solver.
This class implements a solver for permutation invariant training (PIT), which resolves the ambiguity in how estimated sources should be paired with reference speakers. It computes the loss under every possible speaker permutation and keeps the permutation with the minimum loss, which makes it suitable for multi-speaker scenarios.
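The selection rule described above can be sketched in plain Python for a single utterance. This is an illustrative brute-force version, not ESPnet's implementation: `mse` is a hypothetical stand-in for an `AbsEnhLoss` criterion, and signals are plain lists rather than tensors. In this sketch, `perm[s]` is the index of the estimate paired with reference `s`.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]


def mse(ref: Signal, est: Signal) -> float:
    """Hypothetical pairwise criterion: mean squared error of two signals."""
    return sum((r - e) ** 2 for r, e in zip(ref, est)) / len(ref)


def pit_min_loss(
    refs: List[Signal],
    ests: List[Signal],
    criterion: Callable[[Signal, Signal], float] = mse,
) -> Tuple[float, Tuple[int, ...]]:
    """Brute-force PIT for one utterance.

    Scores every assignment of estimates to references and returns the lowest
    mean pairwise loss together with its permutation, where perm[s] is the
    index of the estimate paired with reference s.
    """
    best_loss, best_perm = float("inf"), tuple(range(len(ests)))
    for perm in permutations(range(len(ests))):
        loss = sum(criterion(refs[s], ests[t]) for s, t in enumerate(perm)) / len(perm)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm
```

Because every assignment is tried, the number of candidates grows as num_spk! (6 permutations for 3 speakers, 24 for 4).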
criterion
An instance of AbsEnhLoss that defines the loss computation.
- Type: AbsEnhLoss
weight
Weight (between 0 and 1) of the current loss for multi-task learning.
- Type: float
independent_perm
If True, PIT will be performed in forward to find the best permutation; if False, the permutation from the last LossWrapper output will be inherited.
- Type: bool
flexible_numspk
If True, num_spk is taken from inf (its length) to handle a varying number of speakers.
- Type: bool
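The `weight` attribute is meant for multi-task learning, where several loss wrappers contribute to one training objective. A minimal sketch, assuming the per-wrapper losses are combined as a weighted sum; `WeightedLoss` and `total_loss` are hypothetical names, not ESPnet APIs, and the loss names are illustrative.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class WeightedLoss:
    """One loss wrapper's contribution (hypothetical container)."""

    name: str      # illustrative label, e.g. the criterion name
    loss: float    # loss value returned by the wrapper
    weight: float  # the wrapper's `weight`, documented as between 0 and 1


def total_loss(terms: Iterable[WeightedLoss]) -> float:
    """Multi-task objective, assumed here to be a weighted sum of wrapper losses."""
    return sum(t.weight * t.loss for t in terms)
```

For example, a wrapper with loss -10.0 and weight 1.0 combined with one of loss 2.0 and weight 0.5 gives -9.0.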
Parameters:
- criterion (AbsEnhLoss) – An instance of AbsEnhLoss.
- weight (float) – Weight (between 0 and 1) of the current loss for multi-task learning.
- independent_perm (bool) – If True, PIT will be performed in forward to find the best permutation; if False, the permutation from the last LossWrapper output will be inherited. NOTE (wangyou): Be cautious about the ordering of loss wrappers defined in the yaml config if this argument is False.
- flexible_numspk (bool) – If True, num_spk will be taken from inf to handle flexible numbers of speakers, as ref may include dummy data in this case.
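How `flexible_numspk` decides the speaker count, and when the AssertionError listed for `forward` is raised, can be sketched as follows. `resolve_num_spk` is a hypothetical helper written for illustration. It assumes that, with `flexible_numspk=True`, permutations are enumerated over the first `len(inf)` indices, so any dummy entries placed after them in `ref` are never scored.

```python
from typing import Sequence


def resolve_num_spk(ref: Sequence, inf: Sequence, flexible_numspk: bool) -> int:
    """Hypothetical helper: number of speakers that PIT should consider."""
    if not flexible_numspk:
        # Fixed speaker count: references and estimates must pair up one-to-one,
        # otherwise an AssertionError is raised.
        assert len(ref) == len(inf), (len(ref), len(inf))
        return len(ref)
    # Flexible speaker count: ref may contain dummy entries, so the number of
    # estimates decides how many speakers are scored.
    return len(inf)
```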
forward(ref, inf, others={})
Computes the forward pass for the PIT solver, returning the minimum loss, the collected statistics, and the corresponding permutation.
####### Examples
>>> pit_solver = PITSolver(criterion=my_criterion, weight=0.5)  # my_criterion: any AbsEnhLoss
>>> loss, stats, others = pit_solver.forward(reference, inference)
- Raises: AssertionError – If flexible_numspk is False and the number of reference tensors does not match the number of inference tensors.
forward(ref, inf, others={})
PITSolver forward method.
This method computes the loss for the Permutation Invariant Training (PIT) using the provided reference and inferred tensors. It evaluates different permutations to find the one that minimizes the loss.
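For batched inputs, a natural reading of the returned per-item permutation is that the search runs independently for each batch item and the per-item minima are then averaged. The sketch below assumes exactly that; `batch_mse` and `batched_pit` are hypothetical, and nested lists stand in for tensors of shape (batch, samples).

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Batch = Sequence[Sequence[float]]  # one speaker's signals, shape (batch, samples)


def batch_mse(ref: Batch, est: Batch) -> List[float]:
    """Hypothetical criterion returning one loss per batch item."""
    return [
        sum((r - e) ** 2 for r, e in zip(rb, eb)) / len(rb)
        for rb, eb in zip(ref, est)
    ]


def batched_pit(
    ref: List[Batch],
    inf: List[Batch],
    criterion: Callable[[Batch, Batch], List[float]] = batch_mse,
) -> Tuple[float, List[Tuple[int, ...]]]:
    """Pick the best permutation per batch item, then average the minima."""
    num_spk, batch = len(inf), len(inf[0])
    all_perms = list(permutations(range(num_spk)))
    # losses[i][b]: mean pairwise loss of permutation all_perms[i] for item b
    losses = []
    for perm in all_perms:
        pair = [criterion(ref[s], inf[t]) for s, t in enumerate(perm)]
        losses.append([sum(p[b] for p in pair) / num_spk for b in range(batch)])
    best_perms, best_losses = [], []
    for b in range(batch):
        i = min(range(len(all_perms)), key=lambda k: losses[k][b])
        best_perms.append(all_perms[i])
        best_losses.append(losses[i][b])
    return sum(best_losses) / batch, best_perms
```

If the estimates are swapped in one batch item but correctly ordered in another, per-item selection can still reach zero loss, whereas any single batch-wide permutation cannot.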
- Parameters:
- ref (List[torch.Tensor]) – A list of tensors representing the reference signals, each shaped (batch, …). The length of the list should correspond to the number of speakers (n_spk).
- inf (List[torch.Tensor]) – A list of tensors representing the inferred signals, each shaped (batch, …). If flexible_numspk is False, its length must equal len(ref); if True, its length sets the number of speakers considered.
- others (dict , optional) – Additional parameters, which may include:
- “perm”: A precomputed permutation order (e.g. from a previous loss wrapper), reused when independent_perm is False.
- Returns: A tuple containing:
  - loss (torch.Tensor): The minimum loss, computed with the best permutation.
  - stats (dict): A dictionary containing the collected training statistics.
  - others (dict): A dictionary containing the permutation order used in this PIT solver, under the key “perm”.
- Return type: tuple
- Raises: AssertionError – If flexible_numspk is False and the lengths of ref and inf do not match.
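The permutation returned in `others` can be used to reorder the estimates so that they line up with the references. This sketch assumes `perm[b][s]` is the index of the estimate matched to reference `s` in batch item `b`; `align_estimates` is a hypothetical helper. If `perm` is a tensor, `perm.tolist()` produces the nested list it expects.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def align_estimates(inf: List[Sequence[T]], perm: Sequence[Sequence[int]]) -> List[List[T]]:
    """Reorder estimates so aligned[s][b] is the estimate paired with reference s.

    inf[t][b] is estimate t for batch item b, and perm[b][s] is the index of
    the estimate matched to reference s in item b (convention assumed here).
    """
    num_spk = len(perm[0])
    return [[inf[perm[b][s]][b] for b in range(len(perm))] for s in range(num_spk)]
```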
####### Examples
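>>> import torch
>>> from espnet2.enh.loss.criterions.time_domain import SISNRLoss
>>> from espnet2.enh.loss.wrappers.pit_solver import PITSolver
>>> ref = [torch.randn(2, 16000) for _ in range(3)]  # 3 speakers, batch of 2
>>> inf = [torch.randn(2, 16000) for _ in range(3)]  # 3 estimates
>>> solver = PITSolver(criterion=SISNRLoss())
>>> loss, stats, others = solver.forward(ref, inf)  # others["perm"]: best permutation per item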
NOTE
The independent_perm argument controls whether to compute permutations independently or to use the last permutation from the LossWrapper. If set to False, be cautious about the ordering of loss wrappers defined in the YAML configuration.
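The effect of `independent_perm=False` can be illustrated with two chained toy wrappers, where the second reuses the permutation found by the first instead of searching on its own. `ToyPIT` is a hypothetical stand-in operating on scalar "signals", and falling back to a fresh search when no `perm` is passed in is an assumption of this sketch.

```python
from itertools import permutations
from typing import Callable, Dict, List, Optional, Tuple

PairLoss = Callable[[float, float], float]


class ToyPIT:
    """Hypothetical stand-in for a PIT loss wrapper on scalar "signals"."""

    def __init__(self, pair_loss: PairLoss, independent_perm: bool = True):
        self.pair_loss = pair_loss
        self.independent_perm = independent_perm

    def forward(
        self, ref: List[float], inf: List[float], others: Optional[Dict] = None
    ) -> Tuple[float, Dict]:
        perm = (others or {}).get("perm")

        def cost(p: Tuple[int, ...]) -> float:
            # Mean pairwise loss when reference s is paired with estimate p[s].
            return sum(self.pair_loss(ref[s], inf[t]) for s, t in enumerate(p)) / len(p)

        # Search all permutations unless an inherited one should be reused;
        # the fresh search when none was passed in is an assumption of this sketch.
        if self.independent_perm or perm is None:
            perm = min(permutations(range(len(inf))), key=cost)
        return cost(perm), {"perm": perm}
```

With `ref = inf = [0.0, 1.0]`, a first wrapper using `abs(r - e)` picks `(0, 1)`; a second wrapper with the contrived criterion `-abs(r - e)` and `independent_perm=False` keeps `(0, 1)`, even though on its own it would pick `(1, 0)`. Placing the `independent_perm=False` wrapper first would leave it nothing to inherit, which is one reason the ordering in the YAML configuration matters.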