[Kernel] Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
@@ -6,7 +6,11 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
-from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape,
+    convert_bf16_scales_to_fp8,
+    convert_packed_uint4b8_to_signed_int4_inplace,
+)
 from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types

@@ -48,7 +52,6 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
                "CUTLASS W4A8, only supported int4",
            )

        # TODO(czhu): support -1 (column-wise)
        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

@@ -71,9 +74,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        # TODO(czhu): optimize speed/mem usage
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

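Note on the weight transform: uint4b8 is vLLM's biased 4-bit scalar type (stored nibble = logical value + 8), while the CUTLASS mixed-input kernel expects plain two's-complement int4 nibbles, so the helper subtracts the bias from every packed nibble before the encode/reorder step. A minimal out-of-place sketch of that nibble-wise conversion (the real helper rewrites the tensor in place; the uint8 storage and low-nibble-first packing order here are assumptions, and packed_uint4b8_to_int4 is a hypothetical name):

    import torch

    def packed_uint4b8_to_int4(packed: torch.Tensor) -> torch.Tensor:
        # packed: uint8 tensor holding two 4-bit values per byte
        # (low nibble first, assumed order); each nibble stores value + 8.
        lo = (packed & 0x0F).to(torch.int16) - 8          # low nibble  -> [-8, 7]
        hi = ((packed >> 4) & 0x0F).to(torch.int16) - 8   # high nibble -> [-8, 7]
        # Re-pack the signed values as two's-complement nibbles.
        repacked = ((hi & 0x0F) << 4) | (lo & 0x0F)
        return repacked.to(torch.uint8)
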
@@ -85,10 +88,18 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        w_s = getattr(layer, self.w_s_name)
        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
        w_s.data = fp8_scales

        # register per-channel scales
        layer.register_parameter(
            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
        )

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        self._transform_param(layer, "weight_chan_scale", lambda x: x)

    def apply_weights(
        self,
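Note on the scale transform: convert_bf16_scales_to_fp8 splits the bf16 per-group scales into fp8 group scales plus a per-output-channel scale (registered above as weight_chan_scale), since a single fp8 value cannot carry the full dynamic range of the original scales. A minimal sketch of that two-level decomposition, assuming weight_scale has shape [in_features // group_size, out_features] and using a plain cast instead of the QuantFP8 path the real helper takes:

    import torch

    def split_group_scales(w_s_bf16: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # w_s_bf16: [num_groups, out_features] bf16 group scales.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3
        # One scale per output channel, chosen so every group scale fits in fp8 range.
        chan_scales = w_s_bf16.float().abs().amax(dim=0, keepdim=True) / fp8_max
        chan_scales = chan_scales.clamp_min(torch.finfo(torch.float32).tiny)
        # Group scales stored relative to the per-channel scale, cast down to fp8.
        fp8_scales = (w_s_bf16.float() / chan_scales).to(torch.float8_e4m3fn)
        return fp8_scales, chan_scales

In this scheme the product fp8_scales * chan_scales reconstructs the original group scale; the per-channel factor is presumably folded back in when the GEMM output is scaled, while the fp8 group scales feed the quantized mainloop.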