espnet2.asr.transducer.error_calculator.ErrorCalculatorTransducer
class espnet2.asr.transducer.error_calculator.ErrorCalculatorTransducer(decoder: AbsDecoder, joint_network: Module, token_list: List[int], sym_space: str, sym_blank: str, report_cer: bool = False, report_wer: bool = False)
Bases: object
Error Calculator module for Transducer.
This class calculates Character Error Rate (CER) and Word Error Rate (WER) for transducer models. It utilizes a decoder module and a joint network to perform beam search for predictions, and then compares those predictions against target sequences to compute error rates.
decoder
The decoder module used for generating predictions.
- Type: AbsDecoder
token_list
List of tokens used for converting IDs to characters.
- Type: List[int]
sym_space
The symbol representing space in the token list.
- Type: str
sym_blank
The symbol representing blank in the token list.
- Type: str
report_cer
Flag indicating whether to compute CER.
- Type: bool
report_wer
Flag indicating whether to compute WER.
- Type: bool
Parameters:
- decoder – An instance of AbsDecoder, the decoder module.
- joint_network – A torch.nn.Module representing the joint network.
- token_list – The token inventory indexed by label ID; entry i is the token string for ID i (annotated as List[int] in the signature).
- sym_space – A string representing the space symbol.
- sym_blank – A string representing the blank symbol.
- report_cer – A boolean indicating whether to report CER (default: False).
- report_wer – A boolean indicating whether to report WER (default: False).
########### Examples
>>> error_calculator = ErrorCalculatorTransducer(
... decoder=my_decoder,
... joint_network=my_joint_network,
... token_list=my_token_list,
... sym_space='<space>',
... sym_blank='<blank>',
... report_cer=True,
... report_wer=True
... )
>>> cer, wer = error_calculator(encoder_output, target_labels)
>>> print(f"CER: {cer}, WER: {wer}")
- Raises: ValueError – If the input shapes of encoder_out and target are not compatible.
######## NOTE
The predictions and targets are processed as character sequences: blank symbols are removed and space symbols are converted to literal spaces.
Construct an ErrorCalculatorTransducer.
calculate_cer(char_pred: Tensor, char_target: Tensor) → float
Calculate sentence-level Character Error Rate (CER) score.
This method computes the CER by comparing predicted character sequences against target character sequences. Spaces are removed from both sequences first; the CER is then the character-level edit distance between prediction and target divided by the number of target characters.
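Written as a formula, and assuming (consistent with the ZeroDivisionError entry, which refers to the total target length) that edit distances and target lengths are summed over the batch before dividing, the score would be:

```latex
\mathrm{CER} = \frac{\sum_{b=1}^{B} \mathrm{ED}\left(\tilde{p}_b, \tilde{t}_b\right)}{\sum_{b=1}^{B} \left\lvert \tilde{t}_b \right\rvert}
```

Here \(\tilde{p}_b\) and \(\tilde{t}_b\) denote the b-th predicted and target character sequences with spaces removed, and \(\mathrm{ED}\) is the Levenshtein (edit) distance.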
- Parameters:
- char_pred – Prediction character sequences. Shape: (B, ?), where B is the batch size and ? is the variable length of the predicted sequences.
- char_target – Target character sequences. Shape: (B, ?), where B is the batch size and ? is the variable length of the target sequences.
- Returns: Average sentence-level CER score across the batch.
- Return type: float
- Raises: ZeroDivisionError – If the total length of the target sequences is zero.
########### Examples
>>> char_pred = ["hello", "world"]
>>> char_target = ["hallo", "word"]
>>> cer = error_calculator.calculate_cer(char_pred, char_target)
>>> print(cer)
0.2222222222222222
######## NOTE
The method uses the editdistance library to compute the edit distance between character sequences. Make sure it is installed in your environment.
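The CER computation described above can be sketched in plain Python. This is an illustrative approximation, not ESPnet's source: a hand-written Levenshtein function stands in for `editdistance.eval` so the sketch has no dependencies, and it assumes distances and target lengths are pooled over the batch before dividing.

```python
from typing import List, Sequence


def edit_distance(a: Sequence, b: Sequence) -> int:
    """Levenshtein distance with unit-cost insert, delete, and substitute."""
    # prev[j] is the distance between the prefix of a seen so far and b[:j]
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,             # delete x
                curr[j - 1] + 1,         # insert y
                prev[j - 1] + (x != y),  # substitute (free when equal)
            ))
        prev = curr
    return prev[-1]


def calculate_cer(char_pred: List[str], char_target: List[str]) -> float:
    """Pooled CER: total character edits / total target characters."""
    distances, lens = [], []
    for pred_i, target_i in zip(char_pred, char_target):
        # Spaces are ignored at the character level
        pred = pred_i.replace(" ", "")
        target = target_i.replace(" ", "")
        distances.append(edit_distance(pred, target))
        lens.append(len(target))
    # Raises ZeroDivisionError if every target is empty
    return float(sum(distances)) / sum(lens)


print(calculate_cer(["hello", "world"], ["hallo", "word"]))  # 2 edits / 9 chars
```

Pooling weights each sentence by its length, so a long reference counts for more than a short one; a per-sentence mean would instead give 0.225 on the same inputs.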
calculate_wer(char_pred: Tensor, char_target: Tensor) → float
Calculate sentence-level WER score.
This method computes the Word Error Rate (WER) based on the predicted character sequences and the target character sequences. WER is defined as the number of word-level errors divided by the total number of words in the reference (target) text. The errors can be substitutions, deletions, or insertions.
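In terms of those three error types, and again assuming the counts and reference lengths are summed over the batch before dividing, the definition can be written as:

```latex
\mathrm{WER} = \frac{\sum_{b=1}^{B} \left(S_b + D_b + I_b\right)}{\sum_{b=1}^{B} N_b}
```

Here \(S_b\), \(D_b\), and \(I_b\) are the substitutions, deletions, and insertions in a minimum-cost alignment of the b-th predicted and target word sequences (their sum is the word-level edit distance), and \(N_b\) is the number of target words.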
- Parameters:
- char_pred – Prediction character sequences. (B, ?)
- char_target – Target character sequences. (B, ?)
- Returns: Average sentence-level WER score.
- Return type: float
########### Examples
>>> char_pred = ["this is a test", "hello world"]
>>> char_target = ["this is test", "hello there world"]
>>> wer_score = error_calculator.calculate_wer(char_pred, char_target)
>>> print(wer_score)
0.3333333333333333
######## NOTE
The WER calculation uses the edit distance algorithm to evaluate the differences between predicted and target sequences. It is recommended to preprocess the input sequences to ensure consistent formatting (e.g., removing extra spaces).
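A minimal plain-Python sketch of the word-level computation, under the same assumptions as a CER sketch would use: words come from whitespace splitting, a hand-written Levenshtein function over word lists replaces `editdistance.eval`, and counts are pooled over the batch. It is an approximation for illustration, not ESPnet's code.

```python
from typing import List, Sequence


def edit_distance(a: Sequence, b: Sequence) -> int:
    """Levenshtein distance over arbitrary token sequences (here: words)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            # deletion, insertion, substitution (cost 0 when words match)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]


def calculate_wer(char_pred: List[str], char_target: List[str]) -> float:
    """Pooled WER: total word edits / total reference words."""
    distances, lens = [], []
    for pred_i, target_i in zip(char_pred, char_target):
        # str.split() with no argument also collapses repeated whitespace
        pred_words = pred_i.split()
        target_words = target_i.split()
        distances.append(edit_distance(pred_words, target_words))
        lens.append(len(target_words))
    return float(sum(distances)) / sum(lens)


# One insertion ("a") plus one deletion ("there") over 6 reference words
print(calculate_wer(["this is a test", "hello world"],
                    ["this is test", "hello there world"]))
```

Because whitespace splitting already ignores runs of spaces, extra spaces alone do not inflate the score in this sketch.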
convert_to_char(pred: Tensor, target: Tensor) → Tuple[List, List]
Convert label ID sequences to character sequences.
This method takes the predicted and target label ID sequences and converts them into their corresponding character representations based on the provided token list. It handles special symbols, replacing space symbols with spaces and removing blank symbols.
- Parameters:
- pred – Prediction label ID sequences. Shape: (B, U), where B is the batch size and U is the maximum length of predicted sequences.
- target – Target label ID sequences. Shape: (B, L), where L is the maximum length of target sequences.
- Returns: char_pred: List of prediction character sequences. Shape: (B, ?). char_target: List of target character sequences. Shape: (B, ?).
- Return type: Tuple[List[str], List[str]]
########### Examples
>>> import torch
>>> token_list = ['a', 'b', ' ', '_']
>>> pred = torch.tensor([[0, 1, 2], [1, 0, 3]])
>>> target = torch.tensor([[0, 1], [1, 2]])
>>> ec_transducer = ErrorCalculatorTransducer(decoder, joint_network,
... token_list, ' ', '_')
>>> char_pred, char_target = ec_transducer.convert_to_char(pred, target)
>>> print(char_pred)
['ab ', 'ba']
>>> print(char_target)
['ab', 'b ']
######## NOTE
The output character sequences may vary in length depending on the predictions and targets. The '?' in the shape notation indicates that the length may differ for each sequence.
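The conversion described above can be sketched as a standalone function. This is a hypothetical helper, not the method itself: it takes `token_list`, `sym_space`, and `sym_blank` as explicit arguments instead of reading them from the instance, and accepts plain Python lists in place of torch tensors (the `int(...)` call would also accept 0-d tensor elements).

```python
from typing import List, Sequence, Tuple


def convert_to_char(
    pred: Sequence[Sequence[int]],
    target: Sequence[Sequence[int]],
    token_list: List[str],
    sym_space: str,
    sym_blank: str,
) -> Tuple[List[str], List[str]]:
    """Map label ID sequences to strings: space symbol -> " ", blanks dropped."""

    def ids_to_text(ids: Sequence[int]) -> str:
        # Look up each ID, join into one string, then handle special symbols
        text = "".join(token_list[int(i)] for i in ids)
        return text.replace(sym_space, " ").replace(sym_blank, "")

    char_pred = [ids_to_text(p) for p in pred]
    char_target = [ids_to_text(t) for t in target]
    return char_pred, char_target


# Multi-character symbols, as in the class-level example
tokens = ["h", "i", "<space>", "<blank>"]
print(convert_to_char([[0, 1, 2, 0, 3]], [[0, 1]], tokens, "<space>", "<blank>"))
```

Replacing on the joined string (rather than per token) is what lets multi-character symbols such as `<space>` work the same way as single-character ones.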