espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper.log_sum_exp
espnet2.asr.transducer.rnnt_multi_blank.utils.rnnt_helper.log_sum_exp(a: float, b: float)
Calculate the log of the sum of exponentials of two input values.
This function computes log(exp(a) + exp(b)) for two floating-point numbers, a and b. If either input equals the negative-infinity constant defined in the global_constants module (FP32_NEG_INF), the other input is returned unchanged, because exp(-inf) contributes nothing to the sum.
The function is intended for use inside CUDA kernels, so it is compiled with Numba's @cuda.jit decorator (Just-In-Time compilation) and inlined at its call sites for performance.
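For checking expected values on the CPU, here is a minimal plain-Python sketch of the same computation. It assumes the common max-shift formulation with `math.log1p` and uses `float("-inf")` as a stand-in for `global_constants.FP32_NEG_INF`; `log_sum_exp_ref` is a hypothetical helper for illustration, not the ESPnet source.

```python
import math

# Assumed stand-in for global_constants.FP32_NEG_INF.
NEG_INF = float("-inf")


def log_sum_exp_ref(a: float, b: float) -> float:
    """Hypothetical CPU sketch of log(exp(a) + exp(b)) with -inf short-circuits."""
    # exp(-inf) == 0, so a -inf operand contributes nothing to the sum.
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    # Factor out the larger value so exp() only sees a non-positive argument:
    # log(e^a + e^b) = max(a, b) + log1p(e^(-|a - b|)).
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))


print(round(log_sum_exp_ref(1.0, 2.0), 4))  # 2.3133
print(log_sum_exp_ref(NEG_INF, 3.0))        # 3.0
print(log_sum_exp_ref(4.0, NEG_INF))        # 4.0
```

Shifting by max(a, b) keeps the sketch finite for inputs such as (1000.0, 1000.0), where a direct `math.exp(1000.0)` would raise `OverflowError` in Python.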
- Parameters:
  - a (float) – The first input value.
  - b (float) – The second input value.
- Returns: The logarithm of the sum of exponentials of a and b.
- Return type: float
Examples
>>> result = log_sum_exp(1.0, 2.0)
>>> print(result) # Output will be approximately 2.3133
>>> result = log_sum_exp(global_constants.FP32_NEG_INF, 3.0)
>>> print(result) # Output will be 3.0, as -inf is ignored.
>>> result = log_sum_exp(4.0, global_constants.FP32_NEG_INF)
>>> print(result) # Output will be 4.0, as -inf is ignored.
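The values quoted in these examples can be checked by hand. Factoring the larger input out of the sum gives a standard algebraic identity (shown here only for verification, not as a description of how the CUDA code is written):

```latex
\log\left(e^{a} + e^{b}\right) = \max(a, b) + \log\left(1 + e^{-\lvert a - b \rvert}\right)

\log\left(e^{1} + e^{2}\right) = 2 + \log\left(1 + e^{-1}\right) \approx 2 + \log(1.3679) \approx 2 + 0.3133 = 2.3133

\log\left(e^{-\infty} + e^{3}\right) = \log\left(0 + e^{3}\right) = 3
```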
NOTE
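The function assumes that its inputs are valid floating-point numbers and relies on the constants in global_constants (such as FP32_NEG_INF) to detect the negative-infinity edge case. Because it is compiled for use inside CUDA kernels, the calls in the examples above illustrate the values it returns; in practice it is invoked from kernel code rather than directly from the Python interpreter.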
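One reason explicit negative-infinity checks are useful: in IEEE 754 arithmetic, -inf minus -inf is NaN, so a max-shift formula without special cases returns NaN when both inputs are -inf. The sketch below contrasts the two behaviors; `lse_unguarded` and `lse_guarded` are hypothetical helpers, and `float("-inf")` is an assumed stand-in for `global_constants.FP32_NEG_INF`.

```python
import math

NEG_INF = float("-inf")  # assumed stand-in for global_constants.FP32_NEG_INF


def lse_unguarded(a: float, b: float) -> float:
    # Hypothetical helper: max-shift formula with no special cases.
    # For a == b == -inf, a - b is NaN and the NaN propagates to the result.
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))


def lse_guarded(a: float, b: float) -> float:
    # Hypothetical helper: return the other operand when one is -inf,
    # matching the behavior shown in the examples above.
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    return lse_unguarded(a, b)


print(lse_unguarded(NEG_INF, NEG_INF))  # nan
print(lse_guarded(NEG_INF, NEG_INF))    # -inf
```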