[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
#include <cuda_bf16.h>
|
||||
#include <iostream>
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
using marlin::ScalarType;
|
||||
using marlin::MarlinScalarType2;
|
||||
|
||||
namespace allspark {
|
||||
|
||||
@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
|
||||
|
||||
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
|
||||
for (int i = 0; i < n_mat; ++i) {
|
||||
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
|
||||
sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
|
||||
}
|
||||
|
||||
C[idx] = ScalarType<FType>::float2num(sum);
|
||||
C[idx] = MarlinScalarType2<FType>::float2num(sum);
|
||||
}
|
||||
|
||||
template <typename FType>
|
||||
|
||||
Reference in New Issue
Block a user