[Bugfix][Kernel]: Fix AllSpark kernel compilation errors and enable for CUDA < 12.0 (#14430)

Signed-off-by: wyj371990 <wyj371990@alibaba-inc.com>
This commit is contained in:
Yajie Wang
2025-03-15 00:55:14 +08:00
committed by GitHub
parent 73deea2fdb
commit 977a16772c
3 changed files with 15 additions and 10 deletions

View File

@@ -7,6 +7,8 @@
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
 #include <iostream>
+#include "../gptq_marlin/marlin_dtypes.cuh"
+using marlin::ScalarType;
 namespace allspark {
@@ -66,14 +68,14 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
 return;
 }
-FType sum(0);
+float sum = 0.f;
 int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
 for (int i = 0; i < n_mat; ++i) {
-sum += C_split[idx + i * matrix_size];
+sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
 }
-C[idx] = sum;
+C[idx] = ScalarType<FType>::float2num(sum);
 }
 template <typename FType>