espnet2.enh.loss.wrappers.dpcl_solver.DPCLSolver
class espnet2.enh.loss.wrappers.dpcl_solver.DPCLSolver(criterion: AbsEnhLoss, weight=1.0)
Bases: AbsLossWrapper
DPCLSolver is a loss wrapper for the Deep Clustering (DPCL) loss. It inherits from AbsLossWrapper and delegates the loss computation to the given enhancement loss criterion, which compares the reference signals with a learned embedding of the time-frequency (T-F) bins. Because the deep clustering objective does not depend on the order of the speakers, no explicit permutation search is performed.
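In the original deep clustering formulation, the objective is the squared Frobenius distance between the embedding affinity matrix V Vᵀ and the target affinity matrix Y Yᵀ. The sketch below is plain PyTorch; `deep_clustering_loss` is a hypothetical helper, not an ESPnet function, and ESPnet's criterion may normalize the value differently. It uses the low-rank expansion ||VᵀV||² - 2||VᵀY||² + ||YᵀY||², which avoids building the N x N affinity matrices, and checks that relabelling the speakers leaves the loss unchanged.

```python
import torch


def deep_clustering_loss(V: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unnormalized deep clustering loss per batch item.

    V: (B, N, D) embeddings of the N = T * F time-frequency bins.
    Y: (B, N, C) one-hot speaker assignment of each bin (float).
    Returns a (B,) tensor equal to ||V V^T - Y Y^T||_F^2.
    """
    Vt = V.transpose(1, 2)  # (B, D, N)
    Yt = Y.transpose(1, 2)  # (B, C, N)
    # Low-rank expansion: only D x D, D x C and C x C matrices are formed.
    vv = (Vt @ V).pow(2).sum(dim=(1, 2))  # ||V^T V||_F^2
    vy = (Vt @ Y).pow(2).sum(dim=(1, 2))  # ||V^T Y||_F^2
    yy = (Yt @ Y).pow(2).sum(dim=(1, 2))  # ||Y^T Y||_F^2
    return vv - 2 * vy + yy


if __name__ == "__main__":
    torch.manual_seed(0)
    V = torch.randn(2, 12, 4, dtype=torch.float64)
    labels = torch.randint(0, 3, (2, 12))
    Y = torch.nn.functional.one_hot(labels, num_classes=3).to(V.dtype)
    loss = deep_clustering_loss(V, Y)
    # Permuting the speaker columns of Y gives the same loss.
    permuted = deep_clustering_loss(V, Y[..., [2, 0, 1]])
    print(loss, torch.allclose(loss, permuted))
```

The expansion holds because ||A - B||² = ||A||² - 2 tr(AᵀB) + ||B||², together with tr(V Vᵀ Y Yᵀ) = ||VᵀY||². Permuting the columns of Y leaves every term unchanged, which is why a DPCL "solver" needs no permutation search.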
criterion
The loss criterion used for computing the loss.
- Type: AbsEnhLoss
weight
A scaling factor for the loss.
- Type: float
Parameters:
- criterion (AbsEnhLoss) – The enhancement loss criterion to be used.
- weight (float, optional) – The weight for the loss. Defaults to 1.0.
forward(ref, inf, others={})
Computes the DPCL loss from the reference signals and the T-F embedding passed in the others argument.
- Returns:
- loss (torch.Tensor): The DPCL loss, averaged over the batch.
- stats (dict): A dictionary containing training status statistics.
- others (dict): Reserved for future use.
- Return type: Tuple[torch.Tensor, dict, dict]
- Raises: AssertionError – If “tf_embedding” is not included in the others argument.
Examples
>>> import torch
>>> from espnet2.enh.loss.criterions.tf_domain import FrequencyDomainDPCL
>>> dpcl_solver = DPCLSolver(FrequencyDomainDPCL())
>>> ref = [torch.randn(2, 100, 129), torch.randn(2, 100, 129)]  # two speakers, (B, T, F)
>>> inf = [torch.randn(2, 100, 129), torch.randn(2, 100, 129)]  # estimates (unused by DPCL)
>>> others = {"tf_embedding": torch.randn(2, 100 * 129, 20)}  # (B, T * F, D)
>>> loss, stats, _ = dpcl_solver(ref, inf, others)
NOTE
The “tf_embedding” should be included in the others argument for proper functioning of the forward method.
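The forward contract summarized above can be sketched as follows: assert that the embedding is present, call the criterion, and return the loss, a stats dict, and an empty dict. `MiniDPCLSolver` and `ToyCriterion` are hypothetical names. The toy criterion's value is a placeholder rather than a DPCL loss, and the sketch mirrors the documented behavior rather than reproducing ESPnet's code.

```python
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn


class ToyCriterion(nn.Module):
    """Hypothetical stand-in for an AbsEnhLoss criterion.

    It exposes a `name` and returns one loss value per batch item; the value
    (mean squared embedding) is a placeholder, not a real DPCL loss.
    """

    name = "toy_dpcl"

    def forward(self, ref: List[torch.Tensor], emb: torch.Tensor) -> torch.Tensor:
        return emb.pow(2).mean(dim=(1, 2))


class MiniDPCLSolver(nn.Module):
    """Sketch of the documented forward contract (not the ESPnet source)."""

    def __init__(self, criterion: nn.Module, weight: float = 1.0):
        super().__init__()
        self.criterion = criterion
        self.weight = weight  # only stored here; this sketch never applies it

    def forward(
        self,
        ref: List[torch.Tensor],
        inf: List[torch.Tensor],
        others: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, dict, dict]:
        others = {} if others is None else others
        assert "tf_embedding" in others, "DPCL needs others['tf_embedding']"
        # The loss is computed from the references and the T-F embedding only.
        loss = self.criterion(ref, others["tf_embedding"]).mean()
        stats = {self.criterion.name: loss.detach()}
        return loss, stats, {}
```

The sketch uses `others=None` instead of a mutable `{}` default. That is an idiomatic choice for the sketch, not a claim about the real signature.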
forward(ref, inf, others={})
A naive DPCL solver that computes the deep clustering loss.
This method computes the loss from the reference signals and the learned embedding of the time-frequency (T-F) bins supplied in others[“tf_embedding”]. Because the deep clustering objective compares pairwise affinities between T-F bins instead of matching each output to a speaker, it does not depend on the speaker order, so no explicit permutation search is needed.
Parameters:
- ref (List[torch.Tensor]) – A list of tensors representing the reference signals, structured as [(batch, …), …] for n_spk speakers.
- inf (List[torch.Tensor]) – A list of tensors representing the inferred signals, structured as [(batch, …), …]. The DPCL loss is computed from ref and the embedding in others, so inf does not enter the loss itself.
- others (dict) – Additional data required by this solver, must include:
- “tf_embedding”: A learned embedding of all T-F bins with shape (B, T * F, D).
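These shapes imply that the (B, T, F) references must be flattened into the same T * F bin order as the (B, T * F, D) embedding. Below is a minimal sketch of building one-hot targets by assigning each bin to its dominant speaker, which is the ideal-binary-mask labelling of the original deep clustering recipe. `dominant_speaker_targets` is a hypothetical helper, and row-major flattening is an assumption about how the embedding is laid out.

```python
from typing import List

import torch
import torch.nn.functional as F


def dominant_speaker_targets(ref: List[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: one-hot targets of shape (B, T * F, C).

    ref: C tensors of shape (B, T, F), real magnitudes or complex spectra.
    Each T-F bin is labelled with the speaker whose magnitude is largest there.
    """
    mags = torch.stack([r.abs() for r in ref], dim=-1)  # (B, T, F, C)
    labels = mags.argmax(dim=-1)  # (B, T, F)
    batch = labels.shape[0]
    # Row-major flattening of (T, F) to T * F, assumed to match the embedding.
    onehot = F.one_hot(labels.reshape(batch, -1), num_classes=len(ref))
    return onehot.float()  # (B, T * F, C)


B, T, n_freq, D = 2, 50, 65, 20
ref = [torch.randn(B, T, n_freq), torch.randn(B, T, n_freq)]
tf_embedding = torch.randn(B, T * n_freq, D)
Y = dominant_speaker_targets(ref)
print(Y.shape, tf_embedding.shape)  # torch.Size([2, 3250, 2]) torch.Size([2, 3250, 20])
```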
Returns: A tuple containing:
- loss (torch.Tensor): The DPCL loss, averaged over the batch.
- stats (dict): A dictionary for collecting training status, keyed by the criterion name with the corresponding loss value.
- others (dict): Reserved for future use.
Return type: Tuple[torch.Tensor, dict, dict]
Raises: AssertionError – If “tf_embedding” is not present in the others argument.
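Examples
>>> # reference_signals: list of (B, T, F) spectra, one per speaker
>>> # tf_embeddings: (B, T * F, D) learned T-F embedding
>>> solver = DPCLSolver(criterion=my_criterion)
>>> loss, stats, _ = solver(reference_signals, inferred_signals,
...                         others={"tf_embedding": tf_embeddings})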
NOTE
Ensure that the “tf_embedding” is correctly provided in the others dictionary for the forward method to function properly.