Apply fixes for CUDA 13 (#24599)

Signed-off-by: Aidyn-A <aidyn.b.aitzhan@gmail.com>
This commit is contained in:
Aidyn-A
2025-09-17 17:15:42 +04:00
committed by GitHub
parent 9fccd04e30
commit bfe9380161
8 changed files with 47 additions and 56 deletions

View File

@@ -8,11 +8,7 @@
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
#include "../../cub_helpers.h"
namespace vllm {
@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);
__shared__ float s_rms;
if (threadIdx.x == 0) {
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
ss = BlockReduce(reduceStore).Reduce(ss, CubAddOp{}, blockDim.x);
__shared__ float s_rms;
if (threadIdx.x == 0) {
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
.Reduce(block_absmax_val_maybe, CubMaxOp{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {