espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.reduce.ReduceHelper
espnet2.asr.transducer.rnnt_multi_blank.utils.cuda_utils.reduce.ReduceHelper(I_opid: int, R_opid: int, acts: Tensor, output: Tensor, num_rows: int, num_cols: int, minus: bool, stream)
ReduceHelper is a CUDA warp-reduction kernel helper that reduces the input activation matrix according to the specified input and reduction operator IDs, writing the results to the output tensor.
Depending on the selected operators, it performs either a maximum or an additive reduction, and it is most efficient when the reduced dimension is a power of two.
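As a rough mental model (not the CUDA kernel itself), the reduction can be pictured with the plain-PyTorch sketch below; the helper name and the mapping of the operator IDs to identity/exponentiation and max/add are illustrative assumptions based on the I_Op / R_Op enumerations.

>>> import torch
>>> def reduce_rows_reference(acts, num_rows, num_cols, use_exp=False, use_max=True):
...     # Each of the num_cols positions (B * T * U of them) holds num_rows (= V+1)
...     # scores; the input operator is applied elementwise, then the reduction
...     # operator collapses them to a single value per position.
...     x = acts.reshape(num_cols, num_rows)
...     x = torch.exp(x) if use_exp else x                       # input operator (I_Op)
...     return x.max(dim=1).values if use_max else x.sum(dim=1)  # reduction operator (R_Op)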
NOTE
Warp-level reduction is most efficient when the input shape is a power of two (2 ^ K).
References
- Warp Primitives
[https://developer.nvidia.com/blog/using-cuda-warp-level-primitives/]
- Parameters:
- I_opid (int) – Operator ID for input, defined in I_Op enumeration.
- R_opid (int) – Operator ID for reduction, defined in R_Op enumeration.
- acts (torch.Tensor) – Flattened activation matrix of shape [B * T * U * (V+1)].
- output (torch.Tensor) – Flattened output matrix of shape [B * T * U * (V+1)]. Data will be overwritten.
- num_rows (int) – Vocabulary size, including the blank token (V+1). Represents the number of threads per block.
- num_cols (int) – Flattened size of the activation matrix without the vocabulary dimension (B * T * U); see the shape sketch after this list. Represents the number of blocks per grid.
- minus (bool) – Flag indicating whether to perform subtraction. If set to True, calls the _reduce_minus kernel; otherwise, calls the _reduce_rows kernel.
- stream – CUDA Stream to manage asynchronous execution.
- Returns: Returns True upon successful execution of the reduction operation.
- Return type: bool
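For concreteness, here is a small sketch of how the shape arguments can be derived from hypothetical RNN-T dimensions (B batch, T time frames, U target positions, and a vocabulary of V labels plus one blank); all sizes are made up for illustration.

>>> B, T, U, V = 2, 3, 4, 7    # hypothetical sizes; V excludes the blank token
>>> num_rows = V + 1           # 8: the reduced (vocabulary) dimension, a power of two
>>> num_cols = B * T * U       # 24: one reduction per (batch, time, label) position
>>> num_rows * num_cols        # length of the flattened acts tensor
192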
Examples
>>> import torch
>>> num_rows, num_cols = 8, 24                 # V + 1 and B * T * U, as in the sketch above
>>> acts = torch.randn(num_cols * num_rows, device="cuda")   # flattened activations on the GPU
>>> output = torch.zeros_like(acts)            # overwritten by the kernel
>>> ReduceHelper(I_opid=0, R_opid=0, acts=acts, output=output,
...              num_rows=num_rows, num_cols=num_cols,
...              minus=False, stream=0)        # 0 selects the default CUDA stream
True