espnet2.asr.pit_espnet_model.PITLossWrapper
class espnet2.asr.pit_espnet_model.PITLossWrapper(criterion_fn: Callable, num_ref: int)
Bases: AbsLossWrapper
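Initialize the wrapper with criterion_fn, the loss applied to each (inference, reference) pair, and num_ref, the number of reference streams to match. As with any nn.Module, this also initializes the internal Module state.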
forward(inf: Tensor, inf_lens: Tensor, ref: Tensor, ref_lens: Tensor, others: Dict = None)
Compute the permutation-invariant training (PIT) loss. Works similarly to espnet2/enh/loss/wrapper/pit_solver.py.
- Parameters:
- inf: Iterable[torch.Tensor], (batch, num_inf, …)
- inf_lens: Iterable[torch.Tensor], (batch, num_inf, …)
- ref: Iterable[torch.Tensor], (batch, num_ref, …)
- ref_lens: Iterable[torch.Tensor], (batch, num_ref, …)
- permute_inf: If True, permute inf and inf_lens according to the optimal permutation. (This entry is not part of the forward signature above.)
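The page does not spell out the computation, so the following is a minimal pure-Python sketch of the idea behind permutation-invariant training. It assumes that the criterion is averaged over the num_ref pairs of one candidate permutation, that the lowest average is kept for each batch item, and that the batch mean of those minima is the final loss. Plain nested lists stand in for torch tensors, the length arguments are dropped, and the names pit_loss and mse are hypothetical. This is not ESPnet's implementation.

```python
from itertools import permutations
from typing import Callable, List, Sequence, Tuple


def pit_loss(
    criterion_fn: Callable[[Sequence[float], Sequence[float]], float],
    inf: List[List[Sequence[float]]],  # [batch][num_inf] -> one output stream
    ref: List[List[Sequence[float]]],  # [batch][num_ref] -> one reference stream
) -> Tuple[float, List[Tuple[int, ...]]]:
    """Hypothetical PIT sketch: per batch item, keep the output-to-reference
    assignment with the lowest average criterion value."""
    num_ref = len(ref[0])
    all_perms = list(permutations(range(num_ref)))  # num_ref! candidates
    min_losses, opt_perms = [], []
    for inf_b, ref_b in zip(inf, ref):
        assert len(inf_b) == len(ref_b) == num_ref
        best_loss, best_perm = None, None
        for perm in all_perms:
            # Reference i is paired with output perm[i]; average over the pairs.
            loss = sum(criterion_fn(inf_b[j], ref_b[i]) for i, j in enumerate(perm)) / num_ref
            if best_loss is None or loss < best_loss:
                best_loss, best_perm = loss, perm
        min_losses.append(best_loss)
        opt_perms.append(best_perm)
    # Average the per-item minima over the batch.
    return sum(min_losses) / len(min_losses), opt_perms


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Toy criterion: mean squared error between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# One batch item whose two outputs come out in swapped order.
inf = [[[0.0, 0.0], [1.0, 1.0]]]
ref = [[[1.0, 1.0], [0.0, 0.0]]]
loss, perms = pit_loss(mse, inf, ref)
print(loss, perms)  # 0.0 [(1, 0)]
```

Exhaustive search over all num_ref! permutations is simple and exact, but its cost grows factorially with num_ref.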
classmethod permutate(perm, *args)
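permutate has no description on this page. Judging only from its name and the permute_inf note above, it presumably reorders per-batch tensors along their second axis to follow an optimal permutation. The following pure-Python sketch rests on that assumption: slot i of batch item b receives arg[b][perm[b][i]]. Nested lists stand in for torch tensors, and the function and example values are hypothetical.

```python
from typing import Any, List, Sequence


def permutate(perm: Sequence[Sequence[int]], *args: List[List[Any]]) -> List[List[List[Any]]]:
    """Hypothetical sketch: reorder the num_ref axis of each [batch][num_ref]
    nested list so that out[b][i] == arg[b][perm[b][i]]."""
    ret = []
    for arg in args:
        assert len(arg) == len(perm), "each arg must share the batch size of perm"
        ret.append([[item[j] for j in p] for item, p in zip(arg, perm)])
    return ret


# Hypothetical outputs and lengths for one batch item, plus its optimal permutation.
inf = [["spk_b_hyp", "spk_a_hyp"]]
inf_lens = [[7, 5]]
new_inf, new_lens = permutate([(1, 0)], inf, inf_lens)
print(new_inf, new_lens)  # [['spk_a_hyp', 'spk_b_hyp']] [[5, 7]]
```

Passing several tensors through *args lets outputs and their lengths be reordered by the same permutation in a single call.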
