espnet2.legacy.nets.e2e_mt_common.ErrorCalculator
class espnet2.legacy.nets.e2e_mt_common.ErrorCalculator(char_list, sym_space, sym_pad, report_bleu=False)
Bases: object
Calculate BLEU for ST and MT models during training.
- Parameters:
  - y_hats – numpy array with predicted text
  - y_pads – numpy array with true (target) text
  - char_list – vocabulary list
  - sym_space – space symbol
  - sym_pad – pad symbol
  - report_bleu – report BLEU score if True
Construct an ErrorCalculator object.
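As a rough illustration of how `char_list`, `sym_space`, and `sym_pad` fit together, the sketch below converts one sequence of token IDs into a string. `ids_to_text` is a hypothetical helper and not part of ESPnet, and the convention that `-1` marks padded positions in the ID sequence is an assumption made for this example.

```python
def ids_to_text(ids, char_list, sym_space, sym_pad, ignore_id=-1):
    """Map token IDs to a plain-text string (illustrative helper, not ESPnet API)."""
    # Drop padded positions (assumed here to be marked with ignore_id).
    tokens = [char_list[int(i)] for i in ids if int(i) != ignore_id]
    text = "".join(tokens)
    # Turn the space symbol into a real space and drop any pad symbols.
    return text.replace(sym_space, " ").replace(sym_pad, "")


# Hypothetical vocabulary: index 0 is the pad symbol, index 1 the space symbol.
char_list = ["<blank>", "<space>", "a", "b", "c"]
print(ids_to_text([2, 3, 1, 4, -1, -1], char_list, "<space>", "<blank>"))  # prints "ab c"
```

Strings produced this way, one per utterance, are the kind of input a BLEU implementation can then compare against the references.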
calculate_bleu_ctc(ys_hat, ys_pad)
Calculate sentence-level BLEU score for CTC.
- Parameters:
  - ys_hat (torch.Tensor) – prediction (batch, seqlen)
  - ys_pad (torch.Tensor) – reference (batch, seqlen)
- Returns: corpus-level BLEU score
- Return type: float
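Frame-level CTC predictions usually contain blank symbols and repeated labels, which have to be collapsed before they can be compared with a reference. A minimal sketch of that standard collapsing rule follows. The name `ctc_collapse` and the choice `blank_id=0` are assumptions for illustration, and this page does not state whether `calculate_bleu_ctc` applies this step internally.

```python
from itertools import groupby


def ctc_collapse(frame_ids, blank_id=0):
    """Merge consecutive repeated IDs, then remove blanks (standard CTC rule)."""
    # groupby yields one key per run of identical consecutive IDs.
    merged = [token for token, _ in groupby(frame_ids)]
    return [token for token in merged if token != blank_id]


print(ctc_collapse([0, 3, 3, 0, 3, 4, 4, 0]))  # prints [3, 3, 4]
```

Note that the blank between the two runs of `3` keeps them as separate output tokens, which is why blanks are removed only after repeats are merged.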
calculate_corpus_bleu(ys_hat, ys_pad)
Calculate corpus-level BLEU score in a mini-batch.
- Parameters:
  - ys_hat (torch.Tensor) – prediction (batch, seqlen)
  - ys_pad (torch.Tensor) – reference (batch, seqlen)
- Returns: corpus-level BLEU score
- Return type: float
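To make "corpus-level" concrete: the n-gram match counts and lengths are summed over every sentence pair in the mini-batch first, and a single score is computed from those totals, instead of averaging per-sentence scores. The following is a plain-Python sketch of that idea, not the ESPnet implementation. `corpus_bleu_sketch`, whitespace tokenization, a single reference per hypothesis, no smoothing, and the 0 to 100 scaling are all assumptions of this example.

```python
import math
from collections import Counter


def ngram_counts(tokens, n):
    """Count the n-grams of order n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu_sketch(references, hypotheses, max_n=4):
    """Corpus-level BLEU (0 to 100), one reference per hypothesis, no smoothing."""
    matched = [0] * max_n  # clipped n-gram matches, summed over the corpus
    total = [0] * max_n    # hypothesis n-grams, summed over the corpus
    ref_len = hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_tok, hyp_tok = ref.split(), hyp.split()
        ref_len += len(ref_tok)
        hyp_len += len(hyp_tok)
        for n in range(1, max_n + 1):
            hyp_ngrams = ngram_counts(hyp_tok, n)
            ref_ngrams = ngram_counts(ref_tok, n)
            # Counter intersection clips each count to the reference count.
            matched[n - 1] += sum((hyp_ngrams & ref_ngrams).values())
            total[n - 1] += sum(hyp_ngrams.values())
    # Without smoothing, any order with zero matches gives a score of zero.
    if hyp_len == 0 or min(matched) == 0:
        return 0.0
    log_precision = sum(math.log(m / t) for m, t in zip(matched, total)) / max_n
    # Brevity penalty applies when hypotheses are shorter than references.
    brevity = 1.0 if hyp_len > ref_len else math.exp(1.0 - ref_len / hyp_len)
    return 100.0 * brevity * math.exp(log_precision)


refs = ["the cat sat on the mat", "hello world again today"]
print(corpus_bleu_sketch(refs, refs))  # prints 100.0
```

Because the counts are pooled, a short sentence with no 4-gram match does not force the whole batch score to zero, which is the main practical difference from averaging sentence-level scores.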
