[Kernel] Support W4A8 Grouped GEMM on Hopper (#29691)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
Author: czhu-cohere
Date: 2025-12-08 22:29:06 -05:00
Committed by: GitHub
Parent: ea657f2078
Commit: f6227c22ab
22 changed files with 2045 additions and 101 deletions


@@ -6,7 +6,11 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    convert_bf16_scales_to_fp8,
+    convert_packed_uint4b8_to_signed_int4_inplace,
+)
 from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
                 "CUTLASS W4A8, only supported int4",
             )
 
-        # TODO(czhu): support -1 (column-wise)
         if c.group_size != 128:
             return False, "Only group_size 128 is supported"
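
Editor's note (not part of the commit): the group_size == 128 check means the kernel only accepts group-wise quantization with one scale per 128 input channels per output channel. A minimal sketch of the implied shapes, assuming symmetric int4 quantization and a weight laid out as [K, N] (input_dim=0, output_dim=1, matching the parameter comments in the next hunk); the names here are illustrative:

# Editorial sketch, not vLLM code: shape implications of group_size == 128.
import torch

K, N, group_size = 4096, 8192, 128          # hypothetical layer dims
weight = torch.randn(K, N, dtype=torch.bfloat16)

# One scale per (group of 128 input channels, output channel); for
# symmetric int4 the per-group max maps to the int4 extreme value 7.
group_scales = (
    weight.view(K // group_size, group_size, N).abs().amax(dim=1) / 7.0
)
assert group_scales.shape == (K // group_size, N)   # [32, 8192]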
@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
     # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     # `weight_scale` is: {input_dim = 0, output_dim = 1}
     def process_weights_after_loading(self, layer: torch.nn.Module):
-        # TODO(czhu): optimize speed/mem usage
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
+            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
             x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
             return x
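
Editor's note (not part of the commit): the sketch below is one reading of what convert_packed_uint4b8_to_signed_int4_inplace implies, not vLLM's implementation. In the "uint4b8" encoding a signed value s in [-8, 7] is stored as s + 8; converting a nibble to two's-complement signed int4 is subtracting 8 mod 16, which is just flipping the nibble's high bit, so a byte holding two packed nibbles converts with a single XOR against 0x88:

# Editorial sketch of a u4b8 -> signed-int4 nibble conversion; the real
# helper lives in vLLM's quant_utils and may differ in details.
import torch

def u4b8_to_int4_packed_inplace(packed: torch.Tensor) -> None:
    # (s + 8) XOR 8 == s mod 16 for s in [-8, 7], so flipping the high
    # bit of both nibbles converts the whole byte in one operation.
    packed.view(torch.uint8).bitwise_xor_(0x88)

packed = torch.tensor([0x8F], dtype=torch.uint8)  # nibbles 8 (-> 0) and 15 (-> +7)
u4b8_to_int4_packed_inplace(packed)
assert packed.item() == 0x07                      # high nibble 0x0, low nibble 0x7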
@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
             x.data = ops.cutlass_pack_scale_fp8(x.data)
             return x
 
+        w_s = getattr(layer, self.w_s_name)
+        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
+        w_s.data = fp8_scales
+        # register per-channel scales
+        layer.register_parameter(
+            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
+        )
         # Encode/reorder weights and pack scales
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
+        self._transform_param(layer, "weight_chan_scale", lambda x: x)
 
     def apply_weights(
         self,
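
Editor's note (not part of the commit): the sketch below is one plausible reading of the scale split performed by convert_bf16_scales_to_fp8; the exact numerics are assumptions, and the real helper additionally routes the fp8 quantization through the QuantFP8 module passed as its first argument. The idea: fp8 e4m3 cannot hold arbitrary bf16 group scales directly, so each output channel's group scales are renormalized by a per-channel factor kept in higher precision, with the residual stored in fp8, i.e. scale_bf16 ~= chan_scale * scale_fp8:

# Editorial sketch only -- an assumed decomposition, not vLLM's implementation.
import torch

def split_scales(bf16_scales: torch.Tensor):
    # bf16_scales: [K // group_size, N] group-wise scales
    s = bf16_scales.to(torch.float32)
    # Per-output-channel normalizer chosen so the largest group scale in
    # each column maps to the float8 e4m3 maximum (448); clamp avoids
    # division by zero for all-zero columns.
    chan_scales = (s.amax(dim=0, keepdim=True) / 448.0).clamp_min(1e-30)
    fp8_scales = (s / chan_scales).to(torch.float8_e4m3fn)
    # Dequantization then uses: w ~= int4_value * fp8_scale * chan_scale
    return fp8_scales, chan_scales.squeeze(0)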