[Kernel] Support W4A8 Grouped GEMM on Hopper (#29691)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
@@ -12,8 +12,11 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    convert_packed_uint4b8_to_signed_int4_inplace,
+    pack_cols,
     pack_rows,
     quantize_weights,
+    unpack_quantized_values_into_int32,
 )
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
@@ -167,8 +170,7 @@ def create_test_tensors(
 
     # for the practical use case we need per-tok scales for fp8 activations
     w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
-    # weights are already per-group quantized, use placeholder here
-    w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
+    w_ch_s = torch.randn((n,), device="cuda", dtype=types.channel_scale_type)
 
     return Tensors(
         w_ref=w_ref,
@@ -211,7 +213,7 @@ def mm_test_helper(
     print(output_ref)
 
     torch.testing.assert_close(
-        output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
+        output, output_ref.to(output.dtype), rtol=1e-2, atol=1e-2
     )
 
 
@@ -257,7 +259,7 @@ def test_w4a8_cuda_graph():
     )
 
     w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
-    w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)
+    w_ch_s = torch.randn((n,), device="cuda", dtype=torch.float32)
 
     # Construct a trivial model with a single layer that calls the kernel
     model = W4A8Layer(
@@ -287,4 +289,38 @@ def test_w4a8_cuda_graph():
     output.zero_()
     g.replay()
 
-    torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(output, output_ref, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(
+    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
+)
+@pytest.mark.parametrize("shape", MNK_SHAPES)
+def test_convert_packed_uint4b8_to_signed_int4_inplace(shape):
+    """
+    The W4A16 checkpoints encode the weights as int4b8 packed to int32.
+    The CUTLASS kernels expect signed int4 packed to int32.
+    This test checks that the runtime int4b8 -> signed int4 conversion
+    matches the offline conversion step exactly.
+    """
+    _, N, K = shape
+    # random weights packed to int32
+    t = torch.randint(
+        low=torch.iinfo(torch.int32).min,
+        high=torch.iinfo(torch.int32).max + 1,
+        size=(N, K // 8),
+        dtype=torch.int32,
+        device="cuda",
+    )
+
+    # compute reference
+    unpacked = unpack_quantized_values_into_int32(
+        t.clone(), scalar_types.uint4b8, packed_dim=1
+    )
+    unpacked = unpacked - 8  # int4b8 -> signed int4
+    ref = pack_cols(unpacked & 0x0F, 4, *unpacked.shape)
+
+    out = convert_packed_uint4b8_to_signed_int4_inplace(t.clone())
+
+    assert torch.equal(ref, out)
+    assert not torch.equal(ref, t)
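
A quick way to see why the int4b8 -> signed int4 conversion verified by the new test can be done in place on the packed int32 words (a standalone sketch of the nibble arithmetic, not vLLM's helper):

import torch

# Standalone sketch (not vLLM's code): a uint4b8 nibble s encodes the value
# s - 8, and its two's-complement signed int4 encoding is (s - 8) mod 16,
# which is just s with its top bit flipped. Flipping the top bit of all eight
# nibbles in a 32-bit word is a single XOR with 0x88888888 (written below as
# the equivalent signed 32-bit constant), which is why the conversion can run
# in place over the packed weights.
def xor_convert_uint4b8_to_int4(words: torch.Tensor) -> torch.Tensor:
    assert words.dtype == torch.int32
    return words ^ -0x77777778  # 0x88888888 reinterpreted as signed int32

# Worked example: the nibbles 0x0, 0xF, 0x8 encode -8, +7 and 0 under uint4b8
# and map to 0x8, 0x7 and 0x0, their signed int4 encodings.
w = torch.tensor([0x0F800000], dtype=torch.int32)
print(hex(xor_convert_uint4b8_to_int4(w).item() & 0xFFFFFFFF))  # 0x87088888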
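The test_w4a8_cuda_graph path above checks that the kernel still produces correct results when captured into a CUDA graph and replayed into pre-allocated buffers, which is why the test zeroes output before g.replay(). For reference, the general PyTorch capture/replay pattern looks roughly like this (a generic sketch with made-up names, not the test's code; requires a CUDA device):

import torch

def capture_and_replay(fn, *static_inputs):
    # Warm up on a side stream so one-time allocations happen outside capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out = fn(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: every tensor touched here must stay at a fixed address.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = fn(*static_inputs)

    # Replay re-runs the captured kernels into the same static buffers, so
    # clearing static_out and replaying should reproduce the captured result.
    static_out.zero_()
    g.replay()
    torch.cuda.synchronize()
    return static_out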
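The create_test_tensors change above swaps the per-channel scale from a ones placeholder to random values, so the tests exercise the channel-scale path together with the per-token scales used for fp8 activations. As a rough illustration of how such scales typically enter a dequantized reference matmul (an assumed layout for illustration only, not the kernel's exact contract):

import torch

m, n, k = 4, 8, 16
a = torch.randn(m, k)   # stands in for dequantized fp8 activations
w = torch.randn(k, n)   # stands in for dequantized int4 weights
tok_s = torch.rand(m)   # per-token activation scales
ch_s = torch.rand(n)    # per-output-channel weight scales

# Assumed application order: token scales multiply rows of the activation,
# channel scales multiply columns of the output.
out_ref = (a * tok_s[:, None]) @ w * ch_s[None, :]
print(out_ref.shape)    # torch.Size([4, 8])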