[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
Jinzhen Lin
2025-11-29 23:19:33 +08:00
committed by GitHub
parent fa59fe417f
commit 1656ad3704
46 changed files with 4371 additions and 2240 deletions

View File

@@ -437,10 +437,10 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK {
for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) {
#pragma unroll
for (int k_idx = 0; k_idx < 2; ++k_idx) {
FType low16 =
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 =
ScalarType<FType>::float2num(C_frag[m_idx][n_idx][k_idx * 2 + 1]);
FType low16 = MarlinScalarType2<FType>::float2num(
C_frag[m_idx][n_idx][k_idx * 2]);
FType high16 = MarlinScalarType2<FType>::float2num(
C_frag[m_idx][n_idx][k_idx * 2 + 1]);
uint32_t tmp = (reinterpret_cast<uint32_t&>(low16) & 0xffff) |
(reinterpret_cast<uint32_t&>(high16) << 16);
int sts_offset =

View File

@@ -8,7 +8,7 @@
#include <cuda_bf16.h>
#include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::ScalarType;
using marlin::MarlinScalarType2;
namespace allspark {
@@ -72,10 +72,10 @@ __global__ void f16_gemm_splitk_reduce_kernel(const FType* C_split, FType* C,
int n_mat = N_MATRIX > 0 ? N_MATRIX : (int)n_matrix;
for (int i = 0; i < n_mat; ++i) {
sum += ScalarType<FType>::num2float(C_split[idx + i * matrix_size]);
sum += MarlinScalarType2<FType>::num2float(C_split[idx + i * matrix_size]);
}
C[idx] = ScalarType<FType>::float2num(sum);
C[idx] = MarlinScalarType2<FType>::float2num(sum);
}
template <typename FType>