[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -470,6 +470,50 @@ __device__ inline void dequant<nv_bfloat162, vllm::kFE2M1f.id(), false>(
|
||||
frag_b[0] = __hmul2(frag_b[0], bias_reg);
|
||||
}
|
||||
|
||||
// FP4 (E2M1) -> FP8 (E4M3), eight values at a time.
//
// `q` packs eight FP4 codes, two per byte. For every code the sign bit is
// kept in bit 7 of the output byte and the remaining three bits (selected by
// 0x70) are shifted right by RIGHT_SHIFT, the difference between the FP8 and
// FP4 exponent widths. No exponent-bias adjustment is performed in this
// function.
//   frag_b[1] <- codes held in the high nibble of each byte of `q`
//   frag_b[0] <- codes held in the low nibble of each byte of `q`
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  // Constants for the FP4 (E2M1) and FP8 (E4M3) formats
  constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4;
  constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT;
  constexpr uint32_t SIGN_MASK = 0x80808080u;
  constexpr uint32_t MASK = 0x70707070u;

  // Do the bit manipulation on an unsigned copy: `q` may have its top bit
  // set, and left-shifting a negative signed value is undefined before C++20.
  uint32_t bits = static_cast<uint32_t>(q);

  // Extract the FP4 values and shift them into FP8 position
  const uint32_t Out1 = (bits & SIGN_MASK) | ((bits & MASK) >> RIGHT_SHIFT);
  bits <<= 4;  // bring the low nibbles into the high-nibble position
  const uint32_t Out2 = (bits & SIGN_MASK) | ((bits & MASK) >> RIGHT_SHIFT);

  // Note1: reverse indexing is intentional because weights are permuted
  // Note2 (review): the previous comment here spoke of writing to
  // `frag_b[2]` instead of `frag_b[1]` to fit the tensor-core layout, but
  // this specialization only writes `frag_b[1]` and `frag_b[0]` - confirm the
  // intended layout against the caller.
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out1);
  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&Out2);
}
|
||||
|
||||
// INT4 (unsigned code with an implicit zero point of 8) -> INT8.
//
// `q` packs eight 4-bit codes, two per byte. Every code `n` is turned into
// the signed byte `n - 8`:
//   frag_b[0] <- low nibble of each byte of `q`
//   frag_b[1] <- high nibble of each byte of `q`
//
// OR-ing 0x80 into every byte before the subtraction keeps each byte at or
// above 0x78 afterwards, so no borrow crosses a byte boundary; XOR-ing 0x80
// afterwards clears that guard bit and leaves the two's-complement result.
template <>
__device__ inline void dequant<int32_t, vllm::kU4B8.id(), true>(
    int q, int32_t* frag_b) {
  constexpr uint32_t NIBBLE_MASK = 0x0F0F0F0Fu;
  constexpr uint32_t repeated_zp = 0x08080808u;
  constexpr uint32_t MASK = 0x80808080u;

  // Unsigned arithmetic on purpose: in `int` the subtraction below overflows
  // (e.g. q == 0 gives 0x80808080 - 0x08080808), which is undefined.
  uint32_t bits = static_cast<uint32_t>(q);

  frag_b[0] = static_cast<int32_t>(
      (((bits & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
  bits >>= 4;  // expose the high nibbles
  frag_b[1] = static_cast<int32_t>(
      (((bits & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
}
|
||||
|
||||
// uint4b8 codes -> packed FP8 (E4M3) bytes, eight values at a time.
//
// For each 4-bit code: the low three bits are copied to the output byte,
// bit 3 of the code is moved to bit 7 of the byte, and 1 is added to the
// byte whenever bit 3 was set. How the resulting bytes map to real values
// depends on how the weights were packed (not visible here).
//   frag_b[0] <- low nibble of each byte of `q`
//   frag_b[1] <- high nibble of each byte of `q`
template <>
__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>(
    int q, __nv_fp8x4_e4m3* frag_b) {
  // Low nibbles.
  const int lo_top = q & 0x08080808;
  const int lo_rest = q & 0x07070707;
  const int lo_bytes = (lo_rest | (lo_top << 4)) + (lo_top >> 3);

  // High nibbles, moved down into the low-nibble position first.
  const int shifted = q >> 4;
  const int hi_top = shifted & 0x08080808;
  const int hi_rest = shifted & 0x07070707;
  const int hi_bytes = (hi_rest | (hi_top << 4)) + (hi_top >> 3);

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&lo_bytes);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&hi_bytes);
}
|
||||
|
||||
template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
|
||||
__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b);
|
||||
|
||||
@@ -515,6 +559,49 @@ __device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
|
||||
// Note: reverse indexing is intentional because weights are permuted
|
||||
frag_b[1] = *reinterpret_cast<const nv_bfloat162*>(&Out1);
|
||||
frag_b[0] = *reinterpret_cast<const nv_bfloat162*>(&Out2);
|
||||
};
|
||||
|
||||
// Subtract the zero point while still in the quantized format, then dequantize
|
||||
template <typename scalar_t2, vllm::ScalarTypeId w_type_id,
|
||||
bool skip_flop = false>
|
||||
__device__ inline void sub_zp_and_dequant(int q, scalar_t2* frag_b, int zp);
|
||||
|
||||
// INT4 with zero point -> INT8.
//
// `q` packs eight unsigned 4-bit codes, two per byte; `zp` is broadcast to
// all four byte lanes and subtracted from every code:
//   frag_b[0] <- (low nibble  - zp) for each byte of `q`
//   frag_b[1] <- (high nibble - zp) for each byte of `q`
//
// OR-ing 0x80 into every byte before subtracting prevents borrows from
// crossing byte boundaries; XOR-ing 0x80 afterwards removes the guard bit.
// Assumes `zp` is a small non-negative value (no larger than 0x80) so that a
// byte lane never borrows - TODO confirm against the caller.
template <>
__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
    int q, int32_t* frag_b, int zp) {
  constexpr uint32_t NIBBLE_MASK = 0x0F0F0F0Fu;
  constexpr uint32_t MASK = 0x80808080u;

  // Unsigned arithmetic on purpose: the signed form of the subtraction below
  // overflows `int` whenever the top byte lane ends up below 0x80, which is
  // undefined behaviour.
  const uint32_t repeated_zp = 0x01010101u * static_cast<uint32_t>(zp);
  uint32_t bits = static_cast<uint32_t>(q);

  frag_b[0] = static_cast<int32_t>(
      (((bits & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
  bits >>= 4;  // expose the high nibbles
  frag_b[1] = static_cast<int32_t>(
      (((bits & NIBBLE_MASK) | MASK) - repeated_zp) ^ MASK);
}
|
||||
|
||||
// INT4 with zero point -> packed FP8 (E4M3) bytes.
//
// Per 4-bit code `n` (one byte lane each):
//   1. base = n | 0x70
//   2. sign = bit 7 of (base + zp); with a 4-bit `zp` this is set exactly
//      when n + zp >= 16 (presumably zp fits in 4 bits - confirm with caller)
//   3. out  = low nibble of (base + (zp + 1) if sign else base), with the
//      sign placed in bit 7
//   frag_b[0] <- low nibble of each byte of `q`
//   frag_b[1] <- high nibble of each byte of `q`
template <>
__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
                                          true>(int q, __nv_fp8x4_e4m3* frag_b,
                                                int zp) {
  constexpr uint32_t kNibbles = 0x0F0F0F0F;
  constexpr uint32_t kFiller = 0x70707070;
  constexpr uint32_t kSignBits = 0x80808080;

  const uint32_t zero_point = static_cast<uint32_t>(zp);
  const uint32_t zp_plus_one = zero_point + 1;
  const uint32_t zp_all_lanes = 0x01010101 * zero_point;

  uint32_t packed = static_cast<uint32_t>(q);

  // Low nibbles.
  const uint32_t lo_base = (packed & kNibbles) | kFiller;
  const uint32_t lo_sign = (lo_base + zp_all_lanes) & kSignBits;
  const uint32_t lo_bytes =
      ((lo_base + (lo_sign >> 7) * zp_plus_one) & kNibbles) | lo_sign;

  // High nibbles.
  packed >>= 4;
  const uint32_t hi_base = (packed & kNibbles) | kFiller;
  const uint32_t hi_sign = (hi_base + zp_all_lanes) & kSignBits;
  const uint32_t hi_bytes =
      ((hi_base + (hi_sign >> 7) * zp_plus_one) & kNibbles) | hi_sign;

  frag_b[0] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&lo_bytes);
  frag_b[1] = *reinterpret_cast<const __nv_fp8x4_e4m3*>(&hi_bytes);
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user