[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce (#7233)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Luka Govedič
2024-08-21 20:18:00 -04:00
committed by GitHub
parent 9984605412
commit 7937009a7e
8 changed files with 237 additions and 116 deletions

View File

@@ -3,7 +3,14 @@
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
absmax_val = val > absmax_val ? val : absmax_val;
}
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStorage;
float const block_absmax_val_maybe =
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;