espnet2.enh.loss.criterions.abs_loss.AbsEnhLoss
class espnet2.enh.loss.criterions.abs_loss.AbsEnhLoss(*args, **kwargs)
Bases: Module, ABC
Base class for all Enhancement loss modules.
This abstract base class defines the common interface for the loss functions used to train and evaluate enhancement models in ESPnet. Concrete loss modules subclass it, provide a name, and implement forward.
name
The name of the loss module, which will be used as a key in the reporter. Must be implemented in derived classes.
- Type: str
only_for_test
A boolean flag indicating whether the criterion will only be evaluated during the inference stage. Defaults to False.
- Type: bool
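Both attributes are meant to be read by the code that drives training and evaluation. The following is a minimal sketch, assuming PyTorch is installed. So that it runs without ESPnet, it defines a local stand-in for `AbsEnhLoss` that mirrors the interface on this page. The helper `compute_stats` and the classes `L1Loss` and `EvalOnlyMSE` are hypothetical and not part of ESPnet. The sketch shows how a loop might skip `only_for_test` criteria during training and key each result by `name`, as a reporter would.

```python
from abc import ABC, abstractmethod

import torch


class AbsEnhLoss(torch.nn.Module, ABC):
    # Local stand-in mirroring the documented interface; in ESPnet use
    # `from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss` instead.
    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def only_for_test(self) -> bool:
        return False  # documented default

    @abstractmethod
    def forward(self, ref, inf) -> torch.Tensor:
        raise NotImplementedError


class L1Loss(AbsEnhLoss):  # hypothetical training criterion
    @property
    def name(self) -> str:
        return "l1"

    def forward(self, ref, inf) -> torch.Tensor:
        return (ref - inf).abs().mean()


class EvalOnlyMSE(AbsEnhLoss):  # hypothetical inference-only criterion
    @property
    def name(self) -> str:
        return "mse_eval"

    @property
    def only_for_test(self) -> bool:
        return True  # evaluated only during the inference stage

    def forward(self, ref, inf) -> torch.Tensor:
        return ((ref - inf) ** 2).mean()


def compute_stats(criteria, ref, inf, training: bool) -> dict:
    """Hypothetical helper: skip test-only criteria while training and
    store each result under the criterion's `name`."""
    stats = {}
    for criterion in criteria:
        if training and criterion.only_for_test:
            continue
        stats[criterion.name] = criterion(ref, inf).item()
    return stats


ref = torch.zeros(2, 16000)
inf = torch.full((2, 16000), 0.5)
criteria = [L1Loss(), EvalOnlyMSE()]
print(compute_stats(criteria, ref, inf, training=True))   # {'l1': 0.5}
print(compute_stats(criteria, ref, inf, training=False))  # {'l1': 0.5, 'mse_eval': 0.25}
```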
forward(ref, inf) → torch.Tensor
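Computes the enhancement loss between the reference signal ref and the inferred (estimated) signal inf. Must be implemented in derived classes.
- Raises: NotImplementedError – If the base-class implementation is called directly (for example, through super().forward()). A subclass that does not override forward cannot be instantiated at all, because forward is declared abstract.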
Examples
To create a custom enhancement loss, subclass AbsEnhLoss, implement the name property, and implement the forward method:
```python
import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss


class CustomLoss(AbsEnhLoss):
    @property
    def name(self) -> str:
        return "custom_loss"

    def forward(self, ref, inf) -> torch.Tensor:
        # Custom loss computation logic: mean squared error
        return torch.mean((ref - inf) ** 2)
```
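Once defined, a loss module is called like any other `torch.nn.Module`. This page only states that `forward` returns a `torch.Tensor`; the sketch below assumes, as a hypothetical design choice, a loss that returns one value per utterance and reduces it with `.mean()` before backpropagating. It assumes PyTorch is installed and uses a local stand-in for the base class so it runs without ESPnet; `PerUtteranceMSE` is a hypothetical name.

```python
from abc import ABC, abstractmethod

import torch


class AbsEnhLoss(torch.nn.Module, ABC):
    # Local stand-in mirroring the interface documented on this page.
    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def only_for_test(self) -> bool:
        return False

    @abstractmethod
    def forward(self, ref, inf) -> torch.Tensor:
        raise NotImplementedError


class PerUtteranceMSE(AbsEnhLoss):
    """Hypothetical loss returning one MSE value per utterance."""

    @property
    def name(self) -> str:
        return "per_utt_mse"

    def forward(self, ref, inf) -> torch.Tensor:
        # ref, inf: (batch, samples) -> loss: (batch,)
        return ((ref - inf) ** 2).mean(dim=-1)


criterion = PerUtteranceMSE()
ref = torch.randn(4, 8000)
inf = torch.zeros(4, 8000, requires_grad=True)  # stands in for a model output
loss = criterion(ref, inf)   # call the module instance, not .forward()
print(loss.shape)            # torch.Size([4])
loss.mean().backward()       # reduce to a scalar before backpropagation
print(inf.grad is not None)  # True
```

Returning per-utterance values leaves the final reduction (plain mean, weighted mean, and so on) to the caller; returning a scalar directly, as in the example above, is equally valid if no such flexibility is needed.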
NOTE
This class is intended to be subclassed. It should not be instantiated directly.
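Because forward is declared abstract, Python's `abc` machinery refuses to instantiate the base class, or any subclass that does not override forward, and raises `TypeError`. A minimal sketch, assuming PyTorch is installed and using a local stand-in for the base class; `MissingForward`, `Complete`, and `can_instantiate` are hypothetical names.

```python
from abc import ABC, abstractmethod

import torch


class AbsEnhLoss(torch.nn.Module, ABC):
    # Local stand-in mirroring the interface documented on this page.
    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def only_for_test(self) -> bool:
        return False

    @abstractmethod
    def forward(self, ref, inf) -> torch.Tensor:
        raise NotImplementedError


class MissingForward(AbsEnhLoss):
    # Defines `name` but forgets to override the abstract `forward`.
    @property
    def name(self) -> str:
        return "missing_forward"


class Complete(MissingForward):
    # Overriding `forward` removes the last abstract method.
    def forward(self, ref, inf) -> torch.Tensor:
        return (ref - inf).abs().mean()


def can_instantiate(cls) -> bool:
    """Return True if `cls()` succeeds, False if abc rejects it."""
    try:
        cls()
    except TypeError:
        # Raised because the class still has an abstract `forward`.
        return False
    return True


print(can_instantiate(AbsEnhLoss))      # False
print(can_instantiate(MissingForward))  # False
print(can_instantiate(Complete))        # True
```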
The constructor is inherited from torch.nn.Module: it initializes internal Module state, shared by both nn.Module and ScriptModule.
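A subclass that takes its own constructor arguments should call `super().__init__()` first, so that the internal `nn.Module` state described above exists before buffers, parameters, or submodules are registered. A minimal sketch, assuming PyTorch is installed and using a local stand-in for the base class; `WeightedL1` and its `weight` argument are hypothetical.

```python
from abc import ABC, abstractmethod

import torch


class AbsEnhLoss(torch.nn.Module, ABC):
    # Local stand-in mirroring the interface documented on this page.
    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def only_for_test(self) -> bool:
        return False

    @abstractmethod
    def forward(self, ref, inf) -> torch.Tensor:
        raise NotImplementedError


class WeightedL1(AbsEnhLoss):
    """Hypothetical L1 loss scaled by a constant weight."""

    def __init__(self, weight: float = 1.0):
        # Initialize nn.Module state first; register_buffer relies on it
        # and raises AttributeError if called before Module.__init__().
        super().__init__()
        # A buffer is saved in state_dict() and follows .to(device),
        # but it is not a trainable parameter.
        self.register_buffer("weight", torch.tensor(weight))

    @property
    def name(self) -> str:
        return "weighted_l1"

    def forward(self, ref, inf) -> torch.Tensor:
        return self.weight * (ref - inf).abs().mean()


criterion = WeightedL1(weight=2.0)
print(list(criterion.state_dict()))                           # ['weight']
print(criterion(torch.zeros(1, 4), torch.ones(1, 4)).item())  # 2.0
print(len(list(criterion.parameters())))                      # 0
```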
abstract forward(ref, inf) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
NOTE
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
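The difference between calling the module instance and calling forward directly matters once hooks are attached. A minimal sketch, assuming PyTorch is installed and using a local stand-in for the base class (`ToyMSE` and the `calls` list are hypothetical): a forward hook registered with `register_forward_hook` fires only for the instance call, even though both calls return the same value.

```python
from abc import ABC, abstractmethod

import torch


class AbsEnhLoss(torch.nn.Module, ABC):
    # Local stand-in mirroring the interface documented on this page.
    @property
    def name(self) -> str:
        raise NotImplementedError

    @property
    def only_for_test(self) -> bool:
        return False

    @abstractmethod
    def forward(self, ref, inf) -> torch.Tensor:
        raise NotImplementedError


class ToyMSE(AbsEnhLoss):
    @property
    def name(self) -> str:
        return "toy_mse"

    def forward(self, ref, inf) -> torch.Tensor:
        return ((ref - inf) ** 2).mean()


calls = []
criterion = ToyMSE()
# The hook records the criterion's name every time it fires.
criterion.register_forward_hook(
    lambda module, args, output: calls.append(module.name)
)

ref = torch.ones(1, 4)
inf = torch.zeros(1, 4)
via_call = criterion(ref, inf)             # runs forward() and then the hook
via_forward = criterion.forward(ref, inf)  # runs forward() only; hook skipped
print(calls)                               # ['toy_mse']
print(torch.equal(via_call, via_forward))  # True
```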
property name : str
property only_for_test : bool