espnet2.enh.loss.wrappers.multilayer_pit_solver.MultiLayerPITSolver
class espnet2.enh.loss.wrappers.multilayer_pit_solver.MultiLayerPITSolver(criterion: AbsEnhLoss, weight=1.0, independent_perm=True, layer_weights=None)
Bases: AbsLossWrapper
Multi-Layer Permutation Invariant Training Solver.
This class computes the Permutation Invariant Training (PIT) loss using inferences from multiple layers and a single reference. It also supports single inference and single reference during the evaluation stage.
criterion
An instance of AbsEnhLoss used for computing loss.
- Type: AbsEnhLoss
weight
Weight (between 0 and 1) of the current loss for multi-task learning.
- Type: float
independent_perm
If True, performs PIT in forward to find the best permutation; if False, inherits the permutation from the last LossWrapper output.
- Type: bool
layer_weights
Weights for each layer; if not None, the loss of each layer will be weighted-summed using the specified weights.
- Type: Optional[List[float]]
solver
Instance of PITSolver to handle the PIT logic.
- Type: PITSolver
Parameters:
- criterion (AbsEnhLoss) – An instance of AbsEnhLoss.
- weight (float) – Weight (between 0 and 1) of the current loss for multi-task learning. Defaults to 1.0.
- independent_perm (bool) – If True, PIT will be performed in forward to find the best permutation; if False, inherits permutation from the last LossWrapper output. Defaults to True.
- layer_weights (Optional[List[float]]) – Weights for each layer. If not None, the loss of each layer will be weighted-summed using the specified weights. Defaults to None.
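The weighted sum described for layer_weights can be sketched as follows. This is a minimal illustration on plain floats, not the library's code: `combine_layer_losses` is a hypothetical helper, and the unweighted fallback used when layer_weights is None is an assumption, since the text above only specifies the weighted case. Any further scaling or normalization the library may apply is not modeled here.

```python
from typing import List, Optional


def combine_layer_losses(
    layer_losses: List[float], layer_weights: Optional[List[float]] = None
) -> float:
    """Combine per-layer PIT losses into one scalar (hypothetical helper)."""
    if layer_weights is None:
        # Assumption: plain unweighted sum; the library's default may differ.
        return sum(layer_losses)
    if len(layer_weights) != len(layer_losses):
        raise ValueError("need exactly one weight per layer")
    # Weighted sum: each layer's loss is scaled by its own weight.
    return sum(w * l for w, l in zip(layer_weights, layer_losses))


# Three layers, with the last (deepest) layer weighted most heavily.
combined = combine_layer_losses([0.9, 0.6, 0.3], layer_weights=[0.2, 0.3, 0.5])
```

Giving later layers larger weights, as in the usage line, is one plausible choice when the final layer's output is the one used at inference time.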
forward(ref, infs, others={})
Computes the minimum PIT loss with the best permutation and returns the loss, statistics, and permutation order.
####### Examples
>>> criterion = SomeCriterion()
>>> solver = MultiLayerPITSolver(criterion, weight=0.5)
>>> ref = [torch.randn(10, 2), torch.randn(10, 2)]
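>>> # Multi-layer case: one inner list per layer, each holding one estimate per speaker
>>> infs = [[torch.randn(10, 2), torch.randn(10, 2)], [torch.randn(10, 2), torch.randn(10, 2)]]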
>>> loss, stats, others = solver.forward(ref, infs)
NOTE
Be cautious about the ordering of loss wrappers defined in the YAML config when setting independent_perm to False.
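Why the ordering matters can be illustrated with a toy chain of wrappers. The sketch below is an assumption-laden model, not ESPnet code: `toy_wrapper` and `run_chain` are hypothetical, the `"perm"` dictionary key is assumed for illustration, and the fixed permutation `(1, 0)` stands in for the result of a real PIT search. It only shows that a wrapper with independent_perm set to False depends on an earlier wrapper having produced a permutation.

```python
from typing import Callable, Dict, List, Tuple


def toy_wrapper(independent_perm: bool) -> Callable:
    """Build a stand-in loss wrapper that either finds or inherits a permutation."""

    def forward(infs: List[str], others: Dict) -> Tuple[List[str], Dict]:
        if independent_perm:
            # Pretend a PIT search ran and decided the speakers are swapped.
            perm = (1, 0)
        else:
            # Inherit the permutation left behind by an earlier wrapper.
            if "perm" not in others:
                raise KeyError("no earlier wrapper provided a permutation")
            perm = others["perm"]
        return [infs[p] for p in perm], {"perm": perm}

    return forward


def run_chain(wrappers: List[Callable], infs: List[str]) -> List[str]:
    """Call the wrappers in order, passing `others` from one to the next."""
    others: Dict = {}
    ordered = list(infs)
    for wrapper in wrappers:
        ordered, others = wrapper(infs, others)
    return ordered


# Works: the wrapper that searches for a permutation runs first.
good = run_chain([toy_wrapper(True), toy_wrapper(False)], ["a", "b"])
```

Reversing the list passed to `run_chain` makes the inheriting wrapper run first, and in this toy model it fails with a KeyError because no permutation exists yet.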
forward(ref, infs, others={})
Permutation Invariant Training Solver.
This method computes the minimum loss with the best permutation based on the provided references and inferences. It supports both single-layer and multi-layer cases, allowing for flexible handling of audio source separation tasks.
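The idea of "minimum loss with the best permutation" can be sketched as an exhaustive search over speaker orderings. This is a simplified illustration under stated assumptions: it works on plain Python lists instead of tensors, ignores the batch dimension, uses mean squared error as a stand-in criterion, and `pit_min_loss` is a hypothetical name rather than part of the library.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple

Signal = Sequence[float]


def mse(ref: Signal, inf: Signal) -> float:
    """Mean squared error between two equal-length signals (stand-in criterion)."""
    return sum((r - i) ** 2 for r, i in zip(ref, inf)) / len(ref)


def pit_min_loss(
    refs: List[Signal],
    infs: List[Signal],
    criterion: Callable[[Signal, Signal], float] = mse,
) -> Tuple[float, Tuple[int, ...]]:
    """Return the lowest mean loss over all speaker permutations, and that permutation."""
    if len(refs) != len(infs):
        raise ValueError("refs and infs must have the same number of speakers")
    best_loss: float = float("inf")
    best_perm: Tuple[int, ...] = ()
    for perm in permutations(range(len(infs))):
        # perm[s] is the index of the inference paired with reference s.
        loss = sum(criterion(refs[s], infs[p]) for s, p in enumerate(perm)) / len(refs)
        if loss < best_loss:
            best_loss, best_perm = loss, perm
    return best_loss, best_perm


refs = [[1.0, 1.0], [0.0, 0.0]]
infs = [[0.0, 0.1], [1.0, 0.9]]  # estimates come out in swapped speaker order
loss, perm = pit_min_loss(refs, infs)
```

The search visits every permutation, so its cost grows factorially with the number of speakers; that is acceptable for the two- or three-speaker setups typical of separation tasks.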
- Parameters:
- ref (List[torch.Tensor]) – Reference signals: a list of n_spk tensors, each of shape (batch, …).
- infs (Union[List[torch.Tensor], List[List[torch.Tensor]]]) – Inferences. In the single-layer case, a list of tensors structured like ref, [(batch, …), …]. In the multi-layer case, a List[List[torch.Tensor]] with one inner list per layer.
- others (dict, optional) – Additional arguments for training status collection or any other necessary data. Defaults to an empty dictionary.
- Returns: A tuple (loss, stats, others):
- loss (torch.Tensor): The computed minimum loss with the best permutation.
- stats (dict): A dictionary for collecting training status.
- others (dict): A dictionary that returns the permutation order used in this PIT solver.
- Return type: Tuple[torch.Tensor, dict, dict]
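The two accepted shapes of infs can be told apart by looking at the first element. The sketch below shows one way to normalize both shapes to a list of layers so that later code can loop uniformly. `split_into_layers` is a hypothetical helper, strings stand in for tensors, and the check on the first element is an assumption about how such a dispatch could be written, not a statement of how the library does it.

```python
from typing import List, Sequence


def split_into_layers(infs: Sequence) -> List[Sequence]:
    """Normalize `infs` to a list of layers, each a sequence of per-speaker items."""
    if len(infs) == 0:
        raise ValueError("infs must not be empty")
    if isinstance(infs[0], (list, tuple)):
        # Multi-layer case: the outer list already enumerates layers.
        return list(infs)
    # Single-layer case: wrap the speaker list so callers can loop uniformly.
    return [infs]


single = ["spk1_est", "spk2_est"]
multi = [["l1_spk1", "l1_spk2"], ["l2_spk1", "l2_spk2"]]

single_layers = split_into_layers(single)
multi_layers = split_into_layers(multi)
```

After normalization, every element of the result has one entry per speaker, so the same per-layer loss computation can be applied to each.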
####### Examples
>>> ref = [torch.randn(10, 2), torch.randn(10, 2)]
>>> infs = [torch.randn(10, 2), torch.randn(10, 2)]
>>> criterion = SomeCriterion()
>>> solver = MultiLayerPITSolver(criterion)
>>> loss, stats, others = solver.forward(ref, infs)
NOTE
This function is designed to work with the MultiLayerPITSolver class, which handles the complexity of permutation invariant training.
- Raises:
- ValueError – If the shapes of ref and infs do not match in the single-layer case.