[Kernel]Support W4A8 Grouped GEMM on Hopper (#29691)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-12-08 22:29:06 -05:00
committed by GitHub
parent ea657f2078
commit f6227c22ab
22 changed files with 2045 additions and 101 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple
@@ -691,3 +691,51 @@ def cutlass_fp4_supported() -> bool:
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
def convert_bf16_scales_to_fp8(
quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
expected by W4A8 GEMM kernels.
"""
assert scales.is_contiguous(), (
f"scale tensor must be contiguous, got {scales.stride()=}"
)
assert scales.is_cuda, "scales must be on gpu"
orig_shape = scales.shape
k_groups = orig_shape[-1]
flat_scales = scales.view(-1, k_groups)
fp8_scales, chan_scales = quant_fp8(flat_scales)
fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn)
chan_scales *= 8.0
# restore original shape
fp8_scales = fp8_scales.view(orig_shape)
chan_scales = chan_scales.view(orig_shape[:-1], -1)
return fp8_scales, chan_scales
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
"""
Convert int4b8 (packed to int32) to signed int4
"""
assert t.is_cuda, "tensor must be on gpu"
assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"
# loop through the 8 4-bit nibbles in each int32 entry
for i in range(8):
shift = 4 * i
# extract the i-th 4-bit nibble
nib = (t >> shift) & 0xF
# clear the original nibble by masking out
t &= ~(0xF << shift)
# convert int4b8 [0..15] to signed int4 [-8..7] by subtracting 8
# and update in-place
t |= ((nib - 8) & 0xF) << shift
return t