espnet2.asr_transducer.error_calculator.ErrorCalculator
class espnet2.asr_transducer.error_calculator.ErrorCalculator(decoder: AbsDecoder, joint_network: JointNetwork, token_list: List[int], sym_space: str, sym_blank: str, nstep: int = 2, report_cer: bool = False, report_wer: bool = False)
Bases: object
Error Calculator module for Transducer.
This module provides the ErrorCalculator class which is responsible for calculating the Character Error Rate (CER) and Word Error Rate (WER) for transducer models in Automatic Speech Recognition (ASR).
Attributes:
decoder
The decoder module used for generating predictions.
- Type: AbsDecoder
joint_network
The joint network module used in the model.
- Type: JointNetwork
token_list
List of token units for mapping predictions to characters.
- Type: List[int]
space
Symbol representing space in the token list.
- Type: str
blank
Symbol representing the blank in the token list.
- Type: str
report_cer
Flag indicating whether to compute CER.
- Type: bool
report_wer
Flag indicating whether to compute WER.
- Type: bool
Parameters:
- decoder (AbsDecoder) – Decoder module.
- joint_network (JointNetwork) – Joint Network module.
- token_list (List[int]) – List of token units.
- sym_space (str) – Space symbol.
- sym_blank (str) – Blank symbol.
- nstep (int , optional) – Maximum number of symbol expansions at each time step with mAES. Defaults to 2.
- report_cer (bool , optional) – Whether to compute CER. Defaults to False.
- report_wer (bool , optional) – Whether to compute WER. Defaults to False.
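The role of nstep can be illustrated with a simplified greedy analogue. This is not mAES, which is a beam search; the function greedy_decode_capped and the toy joint callable below are hypothetical stand-ins for the decoder and joint network, used only to show how a per-frame cap bounds the number of symbols emitted at each time step.

```python
from typing import Callable, List, Sequence


def greedy_decode_capped(
    frames: Sequence[int],
    joint: Callable[[int, int], int],
    blank_id: int = 0,
    nstep: int = 2,
) -> List[int]:
    """Greedy transducer-style decoding that emits at most `nstep` symbols per frame.

    `joint(frame, last_token)` is a hypothetical stand-in for the decoder plus
    joint network: it returns the best token ID for `frame` given the last
    emitted token (`blank_id` before anything has been emitted).
    """
    hyp: List[int] = []
    for frame in frames:
        # Expand at most `nstep` times at this time step.
        for _ in range(nstep):
            last = hyp[-1] if hyp else blank_id
            token = joint(frame, last)
            if token == blank_id:
                # Blank means "advance to the next frame".
                break
            hyp.append(token)
    return hyp


def always_emit(frame: int, last: int) -> int:
    return 7  # never blank, so only the nstep cap limits the output


print(greedy_decode_capped([0, 1, 2], always_emit, nstep=2))  # [7, 7, 7, 7, 7, 7]
```

Without the cap, a joint network that never predicts blank would keep emitting symbols at a single frame; with nstep=2, three frames yield at most six symbols.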
########### Examples
>>> # Initialize the ErrorCalculator
>>> error_calculator = ErrorCalculator(
...     decoder=decoder_instance,
...     joint_network=joint_network_instance,
...     token_list=["<blank>", "a", "b", " "],  # Example token list
...     sym_space=" ",
...     sym_blank="<blank>",
...     nstep=2,
...     report_cer=True,
...     report_wer=True,
... )
>>> # Calculate CER and WER
>>> cer, wer = error_calculator(encoder_out, target, encoder_out_lens)
######## NOTE
The ErrorCalculator decodes with the mAES algorithm during validation instead of the default algorithm. This avoids performance degradation during training and bounds the number of symbols emitted at each time step (see nstep).
- Raises:ValueError – If the lengths of the predictions and targets do not match.
Construct an ErrorCalculator object.
calculate_cer(char_pred: Tensor, char_target: Tensor) → float
Calculate the sentence-level Character Error Rate (CER).
Parameters:
- char_pred – Prediction character sequences (e.g., as returned by convert_to_char). Shape: (B, ?).
- char_target – Target character sequences. Shape: (B, ?).
Returns: Average sentence-level CER score.
Return type: float
######## NOTE
The CER and WER calculations are based on the edit distance between predicted and target sequences, computed with the editdistance library.
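A CER computation of this kind can be sketched as follows. To stay self-contained, the sketch uses a pure-Python Levenshtein distance instead of editdistance.eval. Two details are assumptions, not confirmed by this page: spaces are ignored for CER, and the score is averaged over the batch as total edits divided by total reference characters. The names edit_distance and sketch_cer are hypothetical.

```python
from typing import List, Sequence


def edit_distance(a: Sequence, b: Sequence) -> int:
    """Levenshtein distance (insertions, deletions, substitutions all cost 1)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (x != y),  # substitution (free if equal)
            ))
        prev = curr
    return prev[-1]


def sketch_cer(char_pred: List[str], char_target: List[str]) -> float:
    """Batch CER: total character edits / total reference characters."""
    distances, lens = [], []
    for pred, target in zip(char_pred, char_target):
        # Assumption: spaces do not count as characters for CER.
        pred, target = pred.replace(" ", ""), target.replace(" ", "")
        distances.append(edit_distance(pred, target))
        lens.append(len(target))
    return sum(distances) / sum(lens)


print(sketch_cer(["abc", "de"], ["ac", "df"]))  # 0.5
```

Here "abc" vs "ac" costs one deletion and "de" vs "df" one substitution, so 2 edits over 4 reference characters gives 0.5.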
calculate_wer(char_pred: Tensor, char_target: Tensor) → float
Calculate the sentence-level Word Error Rate (WER).
Parameters:
- char_pred – Prediction character sequences (e.g., as returned by convert_to_char). Shape: (B, ?).
- char_target – Target character sequences. Shape: (B, ?).
Returns: Average sentence-level WER score.
Return type: float
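WER follows the same edit-distance idea at the word level. In this sketch, words are assumed to be obtained by splitting each character sequence on whitespace, and the batch score is assumed to be total word edits divided by total reference words; sketch_wer and edit_distance are hypothetical names.

```python
from typing import List, Sequence


def edit_distance(a: Sequence, b: Sequence) -> int:
    """Levenshtein distance over arbitrary sequences (here: lists of words)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, start=1):
        curr = [i]
        for j, y in enumerate(b, start=1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]


def sketch_wer(char_pred: List[str], char_target: List[str]) -> float:
    """Batch WER: total word edits / total reference words."""
    distances, lens = [], []
    for pred, target in zip(char_pred, char_target):
        # Assumption: words are whitespace-separated.
        pred_words, target_words = pred.split(), target.split()
        distances.append(edit_distance(pred_words, target_words))
        lens.append(len(target_words))
    return sum(distances) / sum(lens)


print(sketch_wer(["the cat sat", "hello"], ["the cat sat down", "hello"]))  # 0.2
```

The missing word "down" is one deletion out of five reference words, giving 0.2.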
convert_to_char(pred: Tensor, target: Tensor) → Tuple[List, List]
Convert label ID sequences to character sequences.
This method converts prediction and target label ID sequences into character sequences, replacing the space symbol with a regular space and removing the blank symbol.
Parameters:
- pred – Prediction label ID sequences. Shape: (B, U), where B is the batch size and U is the number of predicted symbols.
- target – Target label ID sequences. Shape: (B, L), where L is the number of target symbols.
Returns:
- char_pred: Prediction character sequences. Shape: (B, ?), where ? is the variable length of each character sequence.
- char_target: Target character sequences. Shape: (B, ?).
Return type: Tuple[List, List]
########### Examples
>>> pred = torch.tensor([[1, 2, 3], [4, 5, 0]])
>>> target = torch.tensor([[1, 3], [4, 6]])
>>> token_list = ['<blank>', 'a', 'b', 'c', 'd', 'e', 'f', ' ']
>>> calculator = ErrorCalculator(decoder, joint_network, token_list,
... ' ', '<blank>')
>>> char_pred, char_target = calculator.convert_to_char(pred, target)
>>> print(char_pred) # Output: ['abc', 'de']
>>> print(char_target) # Output: ['ac', 'df']
######## NOTE
- The space symbol is replaced with a regular space character (' ').
- The blank symbol is removed from the output character sequences.
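The conversion can be sketched on plain Python lists in place of tensors. The function sketch_convert_to_char is hypothetical; it assumes each sequence is mapped through token_list, joined into one string, and then has the space symbol replaced before the blank symbol is stripped. A "<space>" token is used here so the space replacement is visible.

```python
from typing import List, Sequence, Tuple


def sketch_convert_to_char(
    pred: Sequence[Sequence[int]],
    target: Sequence[Sequence[int]],
    token_list: List[str],
    sym_space: str,
    sym_blank: str,
) -> Tuple[List[str], List[str]]:
    """Map label IDs to tokens, join them, normalize the space, drop the blank."""

    def to_text(ids: Sequence[int]) -> str:
        # int() also accepts 0-d tensor elements, so tensor rows work too.
        text = "".join(token_list[int(i)] for i in ids)
        # Replace the space symbol with " " first, then strip the blank symbol.
        return text.replace(sym_space, " ").replace(sym_blank, "")

    return [to_text(p) for p in pred], [to_text(t) for t in target]


token_list = ["<blank>", "a", "b", "c", "d", "e", "f", "<space>"]
char_pred, char_target = sketch_convert_to_char(
    [[1, 2, 3], [4, 5, 0]], [[1, 3], [4, 6]], token_list, "<space>", "<blank>"
)
print(char_pred, char_target)  # ['abc', 'de'] ['ac', 'df']
```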