[Kernel] Replaced blockReduce[...] functions with cub::BlockReduce (#7233)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -3,7 +3,14 @@
|
||||
#include <cmath>
|
||||
|
||||
#include "../../dispatch_utils.h"
|
||||
#include "../../reduction_utils.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/util_type.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include <hipcub/util_type.hpp>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#endif
|
||||
|
||||
static inline __device__ int8_t float_to_int8_rn(float x) {
|
||||
#ifdef USE_ROCM
|
||||
@@ -55,7 +62,10 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
absmax_val = val > absmax_val ? val : absmax_val;
|
||||
}
|
||||
|
||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
||||
using BlockReduce = cub::BlockReduce<float, 1024>;
|
||||
__shared__ typename BlockReduce::TempStorage reduceStorage;
|
||||
float const block_absmax_val_maybe =
|
||||
BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
|
||||
__shared__ float block_absmax_val;
|
||||
if (tid == 0) {
|
||||
block_absmax_val = block_absmax_val_maybe;
|
||||
|
||||
Reference in New Issue
Block a user