espnet2.enh.loss.wrappers.abs_wrapper.AbsLossWrapper
class espnet2.enh.loss.wrappers.abs_wrapper.AbsLossWrapper(*args, **kwargs)
Bases: Module, ABC
Base class for all Enhancement loss wrapper modules.
This abstract class serves as a blueprint for creating various loss wrapper modules used in enhancement tasks. It inherits from torch.nn.Module and requires implementing the forward method.
weight
The weight for the current loss in multi-task learning. The overall training target will be combined as
loss = weight_1 * loss_1 + ... + weight_N * loss_N.
- Type: float
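As an illustration of the combination formula above (the wrapper weights and per-task loss values here are made up for the sketch):

```python
import torch

# Hypothetical per-task losses, as produced by individual loss wrappers
losses = [torch.tensor(0.8), torch.tensor(1.2)]
# Each wrapper's `weight` attribute scales its task's contribution
weights = [0.5, 1.0]

# loss = weight_1 * loss_1 + ... + weight_N * loss_N
total_loss = sum(w * l for w, l in zip(weights, losses))
```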
forward(ref: List, inf: List, others: Dict) -> Tuple[torch.Tensor, Dict, Dict]
Abstract method that must be implemented by subclasses to compute the loss.
- Raises:
  - NotImplementedError – If the forward method is not implemented in a subclass.
####### Examples
To create a custom loss wrapper, subclass AbsLossWrapper and implement the forward method:
```python
class CustomLossWrapper(AbsLossWrapper):
    def forward(self, ref: List, inf: List, others: Dict) -> Tuple[torch.Tensor, Dict, Dict]:
        # Implement loss computation here
        ...
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
abstract forward(ref: List, inf: List, others: Dict) → Tuple[Tensor, Dict, Dict]
Compute the forward pass for the loss calculation.
This method is an abstract method that must be implemented by any subclass of AbsLossWrapper. It takes reference and inferred values along with additional parameters to compute the loss.
- Parameters:
- ref (List) – A list of reference values, which serve as the ground truth for comparison.
- inf (List) – A list of inferred values, which are the predictions generated by the model.
- others (Dict) – A dictionary containing additional parameters that may be required for loss computation.
- Returns: A tuple containing:
  - A tensor representing the computed loss.
- A dictionary with any additional outputs from the loss calculation.
- A dictionary containing metrics or other information relevant to the loss computation.
- Return type: Tuple[torch.Tensor, Dict, Dict]
- Raises: NotImplementedError – If this method is not overridden in a subclass.
####### Examples
>>> loss_wrapper = MyLossWrapper()
>>> ref = [torch.tensor(1.0), torch.tensor(2.0)]
>>> inf = [torch.tensor(1.5), torch.tensor(1.8)]
>>> others = {'some_param': 0.5}
>>> loss, outputs, metrics = loss_wrapper.forward(ref, inf, others)
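The MyLossWrapper used above is not part of the library. A minimal self-contained sketch of such a subclass, using a local stand-in for the base class and an assumed mean-squared-error loss, might look like:

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

import torch


# Stand-in for AbsLossWrapper so this sketch runs without espnet2 installed;
# in practice, import it from espnet2.enh.loss.wrappers.abs_wrapper.
class AbsLossWrapper(torch.nn.Module, ABC):
    weight = 1.0

    @abstractmethod
    def forward(
        self, ref: List, inf: List, others: Dict
    ) -> Tuple[torch.Tensor, Dict, Dict]:
        raise NotImplementedError


# Hypothetical concrete wrapper computing a mean squared error over the lists
class MyLossWrapper(AbsLossWrapper):
    def forward(
        self, ref: List, inf: List, others: Dict
    ) -> Tuple[torch.Tensor, Dict, Dict]:
        loss = torch.stack([(r - i) ** 2 for r, i in zip(ref, inf)]).mean()
        outputs = {"mse": loss.item()}
        metrics = {"num_pairs": len(ref)}
        return loss, outputs, metrics


loss_wrapper = MyLossWrapper()
ref = [torch.tensor(1.0), torch.tensor(2.0)]
inf = [torch.tensor(1.5), torch.tensor(1.8)]
loss, outputs, metrics = loss_wrapper(ref, inf, {"some_param": 0.5})
```

Calling the wrapper instance (rather than forward directly) goes through nn.Module.__call__, which is the usual idiom for torch modules.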