Categorize tests/kernels/ based on kernel type (#16799)
Signed-off-by: mgoin <mgoin64@gmail.com>
100  tests/kernels/quantization/test_allspark_gemm.py  Normal file
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types


def is_gptq_allspark_supported(min_capability: int,
                               max_capability: int) -> bool:
    if not current_platform.is_cuda():
        return False

    capability = current_platform.get_device_capability()
    assert capability is not None

    return capability.to_int() >= min_capability \
        and capability.to_int() <= max_capability


MNK_FACTORS = [
    (1, 4, 8),
    (13, 17, 67),
    (26, 37, 13),
    (48, 16, 24),
    (67, 13, 88),
    (257, 13, 11),
    (658, 13, 11),
    (1033, 9, 17),
]

DTYPES = [torch.float16, torch.bfloat16]
HAS_ZP_OPTS = [False, True]


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(
    not is_gptq_allspark_supported(80, 89),
    reason="AllSpark Ampere kernel is not supported on this GPU type.")
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
    m_factor, n_factor, k_factor = mnk_factors
    m = m_factor
    n = n_factor * ALLSPARK_AMPERE_N_ALIGN
    k = k_factor * ALLSPARK_AMPERE_K_ALIGN

    input = rand_data((m, k), dtype=dtype)
    weight = rand_data((k, n), dtype=dtype)

    # Quantize (and apply act_order if provided)
    w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
                                        group_size, has_zp)

    qw = qw.to(torch.uint8)
    if has_zp:
        zp = zp.to(dtype)
    properties = torch.cuda.get_device_properties(qw.device.index)
    sm_count = properties.multi_processor_count
    sm_version = properties.major * 10 + properties.minor

    n_32align = (n + 32 - 1) // 32 * 32

    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
        qw, s, zp, has_zp)
    opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
            (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
             n_32align))

    opcheck(torch.ops._C.allspark_w8a16_gemm,
            (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
             sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
    output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
                                     n, group_size, sm_count, sm_version,
                                     ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
                                     has_zp, True)

    output_ref = torch.matmul(input, w_ref)
    torch.cuda.synchronize()
    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04
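For reference, the 0.04 threshold above is a mean relative error: compute_max_diff returns mean(|output - output_ref|) / mean(|output_ref|). A minimal CPU-only sketch of the same metric on toy tensors (illustrative values only, not part of the test file):

import torch

# Same relative-error metric as compute_max_diff above, on assumed toy values.
ref = torch.tensor([1.0, -2.0, 4.0])
out = ref + torch.tensor([0.02, -0.05, 0.10])  # small perturbation
rel_err = torch.mean(torch.abs(out - ref)) / torch.mean(torch.abs(ref))
print(float(rel_err))  # ~0.024, which would pass the `< 0.04` check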
39  tests/kernels/quantization/test_aqlm.py  Normal file
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0

import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


def test_aqlm_dequant_opcheck():
    codes = torch.randint(-32768,
                          32767, (22016, 512, 1),
                          device='cuda',
                          dtype=torch.int16)
    codebooks = torch.rand((2, 65536, 1, 8),
                           device='cuda',
                           dtype=torch.float16)
    codebook_partition_sizes = [11008, 11008]

    opcheck(torch.ops._C.aqlm_dequant,
            (codes, codebooks, codebook_partition_sizes))


def test_aqlm_gemm_opcheck():
    input = torch.rand((4, 4096), device='cuda', dtype=torch.float16)
    codes = torch.randint(-32768,
                          32767, (12288, 512, 1),
                          device='cuda',
                          dtype=torch.int16)
    codebooks = torch.rand((3, 65536, 1, 8),
                           device='cuda',
                           dtype=torch.float16)
    scales = torch.rand((12288, 1, 1, 1), device='cuda', dtype=torch.float16)
    codebook_partition_sizes = [4096, 4096, 4096]
    bias = None

    opcheck(torch.ops._C.aqlm_gemm,
            (input, codes, codebooks, scales, codebook_partition_sizes, None))
    opcheck(torch.ops._C.aqlm_gemm,
            (input, codes, codebooks, scales, codebook_partition_sizes, bias))
46  tests/kernels/quantization/test_awq.py  Normal file
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
                    reason="AWQ is not supported on this GPU type.")
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_TRITON_AWQ", "0")
        qweight = torch.randint(-2000000000,
                                2000000000, (8192, 256),
                                device='cuda',
                                dtype=torch.int32)
        scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
        zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
        split_k_iters = 0
        thx = 0
        thy = 0
        opcheck(torch.ops._C.awq_dequantize,
                (qweight, scales, zeros, split_k_iters, thx, thy))


@pytest.mark.skip(reason="Not working; needs investigation.")
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
                    reason="AWQ is not supported on this GPU type.")
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_TRITON_AWQ", "0")
        input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
        qweight = torch.randint(-2000000000,
                                2000000000, (8192, 256),
                                device='cuda',
                                dtype=torch.int32)
        scales = torch.randint(-2000000000,
                               2000000000, (64, 256),
                               device='cuda',
                               dtype=torch.int32)
        qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
        split_k_iters = 8
        opcheck(torch.ops._C.awq_gemm,
                (input, qweight, qzeros, scales, split_k_iters))
163  tests/kernels/quantization/test_awq_marlin.py  Normal file
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.

Run `pytest tests/kernels/quantization/test_awq_marlin.py`.
"""
import pytest
import torch

import vllm.model_executor.layers.fused_moe  # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
                                 torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    awq_marlin_quantize)
from vllm.scalar_type import scalar_types

NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
GROUP_SIZES = [-1, 32, 128]


@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 2048])
@pytest.mark.parametrize("k", [128, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("group_size", GROUP_SIZES)
@pytest.mark.skipif(not (ops.supports_moe_ops
                         and hasattr(torch.ops._moe_C, "marlin_gemm_moe")),
                    reason="Marlin is not supported on this GPU type.")
def test_fused_marlin_moe_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w_ref1_l = []
    qweights1_l = []
    scales1_l = []
    zp1_l = []

    for i in range(w1.shape[0]):
        w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize(
            w1[i].transpose(1, 0), quant_type, group_size)
        w_ref1_l.append(w_ref1)
        qweights1_l.append(qweight1)
        scales1_l.append(scales1)
        zp1_l.append(zp1)

    w_ref1 = stack_and_dev(w_ref1_l)
    qweight1 = stack_and_dev(qweights1_l).contiguous()
    scales1 = stack_and_dev(scales1_l)
    zp1 = stack_and_dev(zp1_l)

    w_ref2_l = []
    qweights2_l = []
    scales2_l = []
    zp2_l = []

    for i in range(w2.shape[0]):
        w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize(
            w2[i].transpose(1, 0), quant_type, group_size)
        w_ref2_l.append(w_ref2)
        qweights2_l.append(qweight2)
        scales2_l.append(scales2)
        zp2_l.append(zp2)

    w_ref2 = stack_and_dev(w_ref2_l)
    qweight2 = stack_and_dev(qweights2_l).contiguous()
    scales2 = stack_and_dev(scales2_l)
    zp2 = stack_and_dev(zp2_l)

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids = fused_topk(a, score, topk, False)
    marlin_output = torch.ops.vllm.fused_marlin_moe(
        a,
        qweight1,
        qweight2,
        scales1,
        scales2,
        score,
        topk_weights,
        topk_ids,
        w1_zeros=zp1,
        w2_zeros=zp2,
        num_bits=num_bits,
    )

    torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2),
                             score, topk, None)

    assert compute_max_diff(marlin_output, torch_output) < 4e-2


@pytest.mark.skip("This test is here for the sake of debugging, "
                  "don't run it in automated tests.")
@pytest.mark.parametrize("m", [64, 512, 222, 33, 1])
@pytest.mark.parametrize("n", [128, 2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 1024, 512])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
def test_single_marlin_moe_multiply_awq(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    group_size: int,
):
    torch.manual_seed(7)

    num_bits = 4
    quant_type = scalar_types.uint4
    dtype = torch.float16
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10

    w_ref_l = []
    qweights_l = []
    scales_l = []
    zp_l = []

    for i in range(w.shape[0]):
        w_ref, qweight, scales, zp = awq_marlin_quantize(
            w[i].transpose(1, 0), quant_type, group_size)
        w_ref_l.append(w_ref)
        qweights_l.append(qweight)
        scales_l.append(scales)
        zp_l.append(zp)

    w_ref = stack_and_dev(w_ref_l)
    qweight = stack_and_dev(qweights_l).contiguous()
    scales = stack_and_dev(scales_l).contiguous()
    zp = stack_and_dev(zp_l).contiguous()

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    marlin_output = torch.ops.vllm.single_marlin_moe(a,
                                                     qweight,
                                                     scales,
                                                     score,
                                                     topk,
                                                     renormalize=False,
                                                     w_zeros=zp,
                                                     num_bits=num_bits)

    torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

    assert compute_max_diff(marlin_output, torch_output) < 1e-2
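With the file now under the quantization directory, each test above can also be targeted individually, e.g.:

pytest tests/kernels/quantization/test_awq_marlin.py -k test_fused_marlin_moe_awq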
171  tests/kernels/quantization/test_awq_triton.py  Normal file
@@ -0,0 +1,171 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the AWQ Triton kernel.

Run `pytest tests/kernels/quantization/test_awq_triton.py`.
"""
import pytest
import torch

from vllm.model_executor.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
from vllm.platforms import current_platform

device = "cuda"


def reverse_awq_order(t: torch.Tensor):
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor,
                         group_size: int) -> torch.Tensor:

    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None],
                                         shifts[None, None, :]).to(torch.int8)

    iweights = iweights.view(iweights.shape[0], -1)

    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
                                      shifts[None, None, :]).to(torch.int8)
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

    iweights = reverse_awq_order(iweights)

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024])
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):

    if group_size == -1:
        group_size = qweight_rows

    qweight_dtype = torch.int32
    scales_rows = qweight_rows // group_size
    scales_cols = qweight_cols * 8
    scales_dtype = torch.float16
    zeros_rows = scales_rows
    zeros_cols = qweight_cols
    zeros_dtype = torch.int32

    current_platform.seed_everything(0)

    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            dtype=qweight_dtype,
                            device=device)
    scales = torch.rand(scales_rows,
                        scales_cols,
                        dtype=scales_dtype,
                        device=device)
    zeros = torch.randint(0,
                          torch.iinfo(torch.int32).max,
                          (zeros_rows, zeros_cols),
                          dtype=zeros_dtype,
                          device=device)

    iweights_triton = awq_dequantize_triton(qweight, scales, zeros)

    assert (not torch.any(torch.isinf(iweights_triton))
            and not torch.any(torch.isnan(iweights_triton)))

    iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)

    torch.testing.assert_close(iweights_triton, iweights_torch)


# input   - [N, K]
# qweight - [K, M // 8]
# qzeros  - [K // G, M // 8]
# scales  - [K // G, M]
@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32])
@pytest.mark.parametrize("K", [128])
@pytest.mark.parametrize("M", [16, 24, 32])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):

    if group_size == -1:
        group_size = K

    split_k_iters = splitK

    input_rows = N
    input_cols = K
    input_dtype = torch.float32
    qweight_rows = input_cols
    qweight_cols = M // 8
    scales_rows = qweight_rows // group_size
    scales_cols = M
    scales_dtype = torch.float32
    qzeros_rows = scales_rows
    qzeros_cols = qweight_cols

    current_platform.seed_everything(0)

    input = torch.rand((input_rows, input_cols),
                       dtype=input_dtype,
                       device=device)
    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            device=device)
    qzeros = torch.randint(0,
                           torch.iinfo(torch.int32).max,
                           (qzeros_rows, qzeros_cols),
                           device=device)
    scales = torch.rand((scales_rows, scales_cols),
                        dtype=scales_dtype,
                        device=device)

    output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
                                    split_k_iters)

    assert (not torch.any(torch.isinf(output_triton))
            and not torch.any(torch.isnan(output_triton)))

    dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)

    output_torch = torch.matmul(input, dequantized_weights)

    assert (not torch.any(torch.isinf(output_torch))
            and not torch.any(torch.isnan(output_torch)))

    torch.testing.assert_close(output_triton.cpu(),
                               output_torch.cpu(),
                               atol=1e-1,
                               rtol=1e-1)
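The torch reference above (awq_dequantize_torch) recovers eight 4-bit values from each packed int32 by right-shifting in steps of 4 bits and masking with 0xF, then undoing the AWQ column interleaving. A standalone sketch of just the shift-and-mask unpacking on one packed word (toy value, CPU-only, not part of the test file):

import torch

# Pack the nibbles 0..7 into a single int32 (0x76543210) and unpack them with
# the same shift-and-mask pattern used by awq_dequantize_torch.
packed = torch.tensor([0x76543210], dtype=torch.int32)
shifts = torch.arange(0, 32, 4)
unpacked = torch.bitwise_and(
    torch.bitwise_right_shift(packed[:, None], shifts[None, :]), 0xF)
print(unpacked.tolist())  # [[0, 1, 2, 3, 4, 5, 6, 7]]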
449  tests/kernels/quantization/test_block_fp8.py  Normal file
@@ -0,0 +1,449 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils_block import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
dg_available = False
|
||||
try:
|
||||
import deep_gemm
|
||||
dg_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
GROUP_SIZE = [64, 128, 256, 512]
|
||||
M = [1, 7, 8, 83, 84, 512, 2048, 4096]
|
||||
N = [128, 512, 1024, 4096, 7168, 7748, 13824]
|
||||
K = [256, 4096, 5120, 3884, 13824, 16384]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
M_moe = [1, 2, 7, 83, 128, 512, 2048]
|
||||
M_moe_dg = [128, 192, 512, 1335, 2048]
|
||||
N_moe = [128, 256, 1024, 4608] # [13824]
|
||||
K_moe = [256, 512, 7168] # [13824]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
E = [2, 8, 16, 24] # [128, 256]
|
||||
TOP_KS = [1, 2, 6]
|
||||
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def native_per_token_group_quant_fp8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.float8_e4m3fn):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch."""
|
||||
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
|
||||
"be divisible by `group_size`")
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||
a_q = a_q.to(torch.float32)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_tokens,d,dtype,group_size,seed",
|
||||
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
|
||||
torch.manual_seed(seed)
|
||||
x = torch.rand(num_tokens, d, dtype=dtype)
|
||||
|
||||
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
|
||||
out, scale = per_token_group_quant_fp8(x, group_size)
|
||||
|
||||
assert torch.allclose(out.to(torch.float32),
|
||||
ref_out.to(torch.float32),
|
||||
rtol=0.15)
|
||||
assert torch.allclose(scale, ref_scale)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,E,topk,block_size,dtype,seed",
|
||||
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
|
||||
SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_bf16 = (torch.rand(
|
||||
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
del w1_bf16
|
||||
|
||||
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
del w2_bf16
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1_s = torch.rand(
|
||||
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
|
||||
w2_s = torch.rand(
|
||||
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.03
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(deep_gemm.ceil_div(m, 128) * 128,
|
||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
# only aligned sizes
|
||||
if M % 4 != 0 or K % 128 != 0 or N % 64 != 0:
|
||||
pytest.skip(f"Skipping test; invalid size {M}, {N}, {K}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = fp8_info.max
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
|
||||
|
||||
_, block_k = block_size[0], block_size[1]
|
||||
|
||||
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k)
|
||||
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
|
||||
|
||||
As = As_fp8.to(torch.float32)
|
||||
Bs = Bs_fp8.to(torch.float32)
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
|
||||
# Transpose earlier so that the testing will not trigger transposing kernels
|
||||
As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
|
||||
|
||||
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
|
||||
|
||||
assert As_fp8.shape == (M, (K + 127) //
|
||||
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
|
||||
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
def fp8_perm(m, idx):
|
||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
|
||||
M, K = a.shape
|
||||
|
||||
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
|
||||
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
|
||||
|
||||
num_tokens = topk * M
|
||||
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
|
||||
|
||||
a = fp8_perm(a, sorted_token_ids // topk)
|
||||
if a_s is not None:
|
||||
a_s = a_s[sorted_token_ids // topk]
|
||||
|
||||
return a, a_s, m_indices, inv_perm
|
||||
|
||||
|
||||
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
|
||||
M = topk_weight.shape[0]
|
||||
out = out[inv_perm, ...]
|
||||
tmp_out = out.view(-1, topk, K)
|
||||
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_shape):
|
||||
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
|
||||
num_groups = w1.shape[0]
|
||||
M, K = a.shape
|
||||
N = w2.shape[-1]
|
||||
|
||||
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
|
||||
a_q, a_s = per_token_group_quant_fp8(a, block_m)
|
||||
|
||||
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
|
||||
num_groups, topk, block_m)
|
||||
|
||||
inter_out = torch.zeros((a_q.shape[0], N * 2),
|
||||
dtype=torch.bfloat16,
|
||||
device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
|
||||
inter_out, m_indices)
|
||||
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
|
||||
|
||||
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
|
||||
|
||||
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
|
||||
|
||||
return final_out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,E,topk,seed",
|
||||
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
block_size = [block_m, block_m]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# only aligned sizes
|
||||
if (N % block_m != 0 or K % block_m != 0 or topk > E):
|
||||
pytest.skip(
|
||||
f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}")
|
||||
|
||||
if N <= 512:
|
||||
pytest.skip("Skipping N <= 512 until performance issues solved.")
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
|
||||
torch.manual_seed(seed)
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
|
||||
fp8_max).clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
|
||||
fp8_max).clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
|
||||
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||
|
||||
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
|
||||
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
|
||||
|
||||
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(E):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
if M >= 128:
|
||||
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
|
||||
score, topk, block_size)
|
||||
else:
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
|
||||
topk, block_size)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
|
||||
assert rel_diff < 0.03
|
||||
198  tests/kernels/quantization/test_block_int8.py  Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils_block import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
w8a8_block_int8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_int8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.int8):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch.
|
||||
|
||||
It converts the tensor values into int8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
int8_min = iinfo.min
|
||||
int8_max = iinfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
# Use float32 for scale calculation for stability
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / int8_max
|
||||
x_q = (x_.to(torch.float32) / x_s).round().clamp(
|
||||
min=int8_min, max=int8_max).to(dtype) # Round before clamping
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
# For test
|
||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""This function performs fused moe with block-wise quantization using
|
||||
native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
||||
act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8, 24]
|
||||
TOP_KS = [2, 6]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
A_fp8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
B_fp8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
|
||||
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
|
||||
|
||||
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
|
||||
out_dtype)
|
||||
out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, block_size, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
|
||||
native torch reference."""
|
||||
torch.manual_seed(seed)
|
||||
# Use a smaller factor for scale initialization to prevent large
|
||||
# values/overflow especially when output dtype might be float16
|
||||
factor_for_scale = 1e-2
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_fp32 = (torch.rand(
|
||||
(E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1_s = (torch.rand(
|
||||
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
|
||||
w2_s = (torch.rand(
|
||||
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
vllm_config = VllmConfig()
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
# Check results
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.06
|
||||
264  tests/kernels/quantization/test_cutlass_2of4_sparse.py  Normal file
@@ -0,0 +1,264 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for sparse cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/test_semi_structured.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
sparse_cutlass_supported)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return tensor.to(dtype=torch.float16)
|
||||
|
||||
|
||||
def prune_to_2_4(tensor):
|
||||
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||
original_shape = tensor.shape
|
||||
reshaped = tensor.reshape(-1, 4)
|
||||
|
||||
# Get indices of top 2 absolute values in each group of 4
|
||||
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||
|
||||
# Create binary mask
|
||||
mask = torch.zeros_like(reshaped)
|
||||
mask.scatter_(dim=1,
|
||||
index=indices,
|
||||
src=torch.ones_like(indices, dtype=mask.dtype))
|
||||
|
||||
# Apply mask and reshape back
|
||||
pruned = reshaped * mask
|
||||
|
||||
# Turn all -0.0 to 0.0
|
||||
pruned[pruned == -0.0] = 0.0
|
||||
|
||||
return pruned.reshape(original_shape)
|
||||
|
||||
|
||||
# This function checks that applying an identity matrix multiplication
|
||||
# to the compressed weights yields the original uncompressed weights.
|
||||
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
|
||||
b_compressed: torch.Tensor,
|
||||
b_metadata: torch.Tensor):
|
||||
|
||||
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
|
||||
# same dtype as its inputs. This line addresses that constraint while
|
||||
# arbitrarily using bfloat16 for the int8/fp8 cases.
|
||||
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
|
||||
|
||||
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
|
||||
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
|
||||
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
|
||||
b_compressed,
|
||||
b_metadata,
|
||||
eye_scale,
|
||||
eye_scale,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
|
||||
|
||||
|
||||
def make_rand_sparse_tensors(
|
||||
dtype: torch.dtype, m: int, n: int, k: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
a = torch.randn((m, k), device='cuda')
|
||||
b = torch.randn((n, k), device='cuda').t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
# ensure A and B aren't all zeros after rounding
|
||||
a = a * 5.0
|
||||
b = b * 5.0
|
||||
|
||||
b = prune_to_2_4(b.t()).t()
|
||||
|
||||
if dtype == torch.int8:
|
||||
a, b = to_int8(a), to_int8(b)
|
||||
elif dtype == torch.float8_e4m3fn:
|
||||
a, b = to_fp8(a), to_fp8(b)
|
||||
elif dtype == torch.float16:
|
||||
a, b = to_fp16(a), to_fp16(b)
|
||||
elif dtype == torch.bfloat16:
|
||||
a, b = to_bf16(a), to_bf16(b)
|
||||
else:
|
||||
raise ValueError("unsupported dtype")
|
||||
|
||||
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||
check_compress_decompress_invariance(dtype, b, b_compressed, e)
|
||||
|
||||
# Compressed B, Metadata, Original A, B
|
||||
return b_compressed, e, a, b
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
def test_cutlass_sparse_subset():
|
||||
|
||||
big_m = 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
|
||||
big_m, n, k)
|
||||
a = whole_a[0:m, 0:k]
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 512),
|
||||
(16, 256, 512),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 512),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 512),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
|
||||
# Test working with a subset of A and B for sparse matmul
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
|
||||
|
||||
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=dtype,
|
||||
bias=bias)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand(
|
||||
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not sparse_cutlass_supported(),
|
||||
reason="Sparse CUTLASS is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
# Create tensors
|
||||
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
|
||||
out_dtype = torch.bfloat16
|
||||
|
||||
bias = torch.rand(
|
||||
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(a,
|
||||
b_comp,
|
||||
e,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
|
||||
641  tests/kernels/quantization/test_cutlass_scaled_mm.py  Normal file
@@ -0,0 +1,641 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for cutlass kernels
|
||||
|
||||
Run `pytest tests/kernels/test_cutlass.py`.
|
||||
"""
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import baseline_scaled_mm, opcheck, to_fp8, to_int8
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import cdiv
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 256, 128),
|
||||
(1, 16384, 1024),
|
||||
(1, 24576, 496),
|
||||
(16, 256, 496),
|
||||
(16, 16384, 128),
|
||||
(16, 24576, 4096),
|
||||
(32, 8192, 4096),
|
||||
(32, 16384, 4096),
|
||||
(33, 1024, 1024),
|
||||
(33, 8192, 128),
|
||||
(64, 2048, 496),
|
||||
(64, 16384, 1024),
|
||||
(100, 8192, 496),
|
||||
(128, 32768, 4096),
|
||||
(256, 4096, 4096),
|
||||
(512, 256, 1024),
|
||||
(512, 8192, 4096),
|
||||
(512, 16384, 128),
|
||||
(512, 24576, 128),
|
||||
]
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# -1 means full extent in that dimension
|
||||
TENSORWISE_GROUP_SHAPE = (-1, -1)
|
||||
PER_TOKEN_GROUP_SHAPE = (1, -1)
|
||||
PER_OUT_CH_GROUP_SHAPE = (-1, 1)
|
||||
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
def rand_int8(shape: tuple, device: str = "cuda"):
|
||||
return to_int8(torch.rand(shape, device=device) * 255 - 128)
|
||||
|
||||
|
||||
def group_scale_helper(shape, group_shape):
|
||||
return [shape[i] if s < 0 else s for i, s in enumerate(group_shape)]
|
||||
|
||||
|
||||
def scale_shape(shape, group_shape):
|
||||
assert len(shape) == len(group_shape)
|
||||
group_shape = group_scale_helper(shape, group_shape)
|
||||
return tuple(
|
||||
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
|
||||
|
||||
|
||||
def cutlass_fp8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_fp8(torch.randn((m, k), device=device))
|
||||
b = to_fp8(torch.randn((n, k), device=device).t())
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
|
||||
# make scales M-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_a = scale_a.t().contiguous().t()
|
||||
# make scales K-major for blockwise quant, doesn't affect 1D scales
|
||||
scale_b = scale_b.t().contiguous().t()
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
def cutlass_int8_gemm_helper(m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
a_scale_group_shape: tuple,
|
||||
b_scale_group_shape: tuple,
|
||||
use_bias: bool,
|
||||
out_dtype: type[torch.dtype] = torch.bfloat16,
|
||||
device: str = "cuda"):
|
||||
# Test for a cutlass kernel with per-token activation quantization
|
||||
# and per-output channel weight quantization.
|
||||
a = to_int8(torch.randn((m, k), device=device) * 5)
|
||||
b = to_int8(torch.randn((n, k), device=device).t() * 5)
|
||||
|
||||
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
|
||||
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
|
||||
|
||||
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
|
||||
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
|
||||
else:
|
||||
bias = None
|
||||
|
||||
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm,
|
||||
(out, a, b, scale_a, scale_b, bias))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
|
||||
return
|
||||
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
|
||||
return
|
||||
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias: bool):
|
||||
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
|
||||
[((1, 128), (128, 128))])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(90),
|
||||
reason="FP8 blockwise is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
out_dtype: type[torch.dtype],
|
||||
use_bias: bool):
|
||||
cutlass_fp8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias, torch.bfloat16,
|
||||
device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool, device: str):
|
||||
cutlass_int8_gemm_helper(512,
|
||||
512,
|
||||
512,
|
||||
a_scale_group_shape,
|
||||
b_scale_group_shape,
|
||||
use_bias,
|
||||
out_dtype=torch.bfloat16,
|
||||
device=device)
|
||||
|
||||
|
||||
# For the following two tests:
|
||||
# N and K correspond to the size of the weight matrix and likely to be multiples
|
||||
# of a large power of two. In any case, the kernel will have a naive fallback
|
||||
# when N and K are not divisible by 16. But M is the number of tokens and the
|
||||
# kernel must handle any M thrown at it.
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.skipif(not current_platform.has_device_capability(89),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("a_scale_group_shape",
|
||||
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("b_scale_group_shape",
|
||||
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
|
||||
use_bias: bool):
|
||||
for nk in range(32, 128, 32):
|
||||
for m in range(1, 128):
|
||||
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
|
||||
b_scale_group_shape, use_bias)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.skip
|
||||
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
|
||||
out_dtype: torch.dtype):
|
||||
# Currently, the test is failing because folding azp into
|
||||
# 16-bit bias loses too much precision
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
|
||||
|
||||
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
|
||||
|
||||
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
|
||||
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
|
||||
assert azp_bias.shape == (1, n)
|
||||
assert azp_bias[0, :].shape == (n, )
|
||||
|
||||
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
|
||||
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
|
||||
dtype=out_dtype, device='cuda')
|
||||
|
||||
out = ops.cutlass_scaled_mm(aq_i8,
|
||||
bq_i8,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=azp_bias[0, :])
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [32, 64, 128])
|
||||
@pytest.mark.parametrize("n", [16, 32, 64])
|
||||
@pytest.mark.parametrize("k", [64, 128, 256])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
@pytest.mark.parametrize("azp_per_token", [True, False])
|
||||
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
|
||||
use_bias: bool, azp_per_token: bool):
|
||||
m_azp = m if azp_per_token else 1
|
||||
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
aq_i8 = rand_int8((m, k))
|
||||
aq_i32 = aq_i8.to(dtype=torch.int32)
|
||||
aq_f32 = aq_i8.to(dtype=torch.float32)
|
||||
|
||||
bq_i8 = rand_int8((n, k)).t()
|
||||
bq_i32 = bq_i8.to(dtype=torch.int32)
|
||||
bq_f32 = bq_i8.to(dtype=torch.float32)
|
||||
b_dq = scale_b * bq_f32
|
||||
|
||||
azp_a = torch.rand(
|
||||
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
|
||||
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
|
||||
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
|
||||
|
||||
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
|
||||
torch.testing.assert_close(a_dq,
|
||||
scale_a * aq_f32 - azp_a,
|
||||
rtol=1e-4,
|
||||
atol=1e-3)
|
||||
|
||||
if use_bias:
|
||||
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
|
||||
else:
|
||||
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
|
||||
|
||||
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
|
||||
|
||||
# int32 mm not supported on CUDA
|
||||
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
|
||||
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
|
||||
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
|
||||
|
||||
# The azp adjustment is just the per-column sum of B (see the sketch after this test)
|
||||
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
|
||||
func_bias = bias if use_bias else None
|
||||
|
||||
if azp_per_token:
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_adj_i32, azp_i32,
|
||||
func_bias)
|
||||
else:
|
||||
azp_with_adj_i32 = azp_i32 * azp_adj_i32
|
||||
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
|
||||
out_dtype, azp_with_adj_i32, None,
|
||||
func_bias)
|
||||
|
||||
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
|
||||
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
|
||||
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
|
||||
atol = 1e-3
|
||||
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
|
||||
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
|
||||
|
||||
if azp_per_token:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
|
||||
func_bias))
|
||||
else:
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
|
||||
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
|
||||
func_bias))
|
||||
|
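# Illustrative sketch (not part of the original tests): why azp_adj above is
# the per-column sum of B. With an asymmetric zero point azp, the dequantized
# activation is scale_a * (Aq - azp), so
#   (Aq - azp) @ Bq = Aq @ Bq - azp * (J @ Bq)
# where J is a row of ones, i.e. J @ Bq is the column sum of Bq. The helper
# name below is hypothetical; the integer matmul runs on CPU because int32
# mm is not supported on CUDA.
def _azp_adjustment_sketch():
    aq = torch.randint(-128, 128, (4, 8), dtype=torch.int32)
    bq = torch.randint(-128, 128, (8, 3), dtype=torch.int32)
    azp = torch.tensor([[3]], dtype=torch.int32)
    lhs = (aq - azp) @ bq
    rhs = aq @ bq - azp * bq.sum(dim=0, keepdim=True, dtype=torch.int32)
    torch.testing.assert_close(lhs, rhs)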
||||
|
||||
# Test working with a subset of A and B
|
||||
def test_cutlass_subset():
|
||||
big_m, big_n, big_k = 1024, 1024, 1024
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
|
||||
whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
|
||||
a = whole_a[0:m, 0:k]
|
||||
b = whole_b[0:k, 0:n]
|
||||
|
||||
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
out = ops.cutlass_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
baseline = baseline_scaled_mm(a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=torch.bfloat16)
|
||||
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class CutlassLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, b, scale_a, scale_b, out_dtype):
|
||||
super().__init__()
|
||||
self.b = b
|
||||
self.scale_a = scale_a
|
||||
self.scale_b = scale_b
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(self, a):
|
||||
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
|
||||
self.out_dtype)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
m, n, k = 512, 512, 512
|
||||
|
||||
a = to_int8(torch.randn((m, k), device="cuda"))
|
||||
b = to_int8(torch.randn((n, k), device="cuda").t())
|
||||
|
||||
m_a_scales = m if per_act_token else 1
|
||||
n_b_scales = n if per_out_ch else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
|
||||
|
||||
# Construct a trivial model with a single layer that calls a CUTLASS kernel
|
||||
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out = model(a)
|
||||
out.zero_()
|
||||
g.replay()
|
||||
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
|
||||
# Create separate A, B, C tensors for each group
|
||||
a_tensors = []
|
||||
b_tensors = []
|
||||
a_scales_tensors = []
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
|
||||
alignment = 16 # 128 // 8
|
||||
# For variation, each group has dimensions
|
||||
n_g = alignment * random.randint(1, 64)
|
||||
k_g = alignment * random.randint(1, 64)
|
||||
for g in range(num_experts):
|
||||
m_g = alignment * random.randint(1, 64)
|
||||
|
||||
expert_offsets[g + 1] = expert_offsets[g] + m_g  # running prefix sum; see the note after this test
|
||||
problem_sizes[g][0] = m_g
|
||||
problem_sizes[g][1] = n_g
|
||||
problem_sizes[g][2] = k_g
|
||||
|
||||
m_a_scales = m_g if per_act_token else 1
|
||||
n_b_scales = n_g if per_out_ch else 1
|
||||
|
||||
print("shape:", m_g, n_g, k_g)
|
||||
|
||||
# Create group-specific A and B (FP8) and output (FP16/FP32)
|
||||
a_g = to_fp8(torch.randn((m_g, k_g), device=device))
|
||||
b_g = to_fp8(torch.randn((n_g, k_g), device=device).t())
|
||||
a_tensors.append(a_g)
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
|
||||
None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
|
||||
1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[
|
||||
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
|
||||
ab_strides = torch.full((num_experts, ),
|
||||
a_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides = torch.full((num_experts, ),
|
||||
out_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
|
||||
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
|
||||
b_tensors_stacked, a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked, expert_offsets[:-1],
|
||||
problem_sizes, ab_strides, ab_strides, c_strides)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
|
||||
print(baseline)
|
||||
print(c)
|
||||
print("*")
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
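# Illustrative note (not part of the original test): the grouped-GEMM layout
# used above. expert_offsets is a running prefix sum of each expert's m_g, so
# rows [expert_offsets[g], expert_offsets[g + 1]) of the stacked A tensor and
# of the output belong to expert g. For example, with per-expert token counts
# [32, 16, 48] (assumed values) the offsets would be [0, 32, 48, 96].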
||||
116
tests/kernels/quantization/test_fp8_quant.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import (FP8_DTYPE,
|
||||
ref_dynamic_per_tensor_fp8_quant,
|
||||
ref_dynamic_per_token_quant)
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
|
||||
8193] # Arbitrary values for testing
|
||||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
|
||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
||||
SCALE_UBS = [True, False]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def opcheck_fp8_quant(output,
|
||||
input,
|
||||
scale=None,
|
||||
scale_ub=None,
|
||||
use_per_token_if_dynamic=False):
|
||||
if scale is not None:
|
||||
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
|
||||
elif use_per_token_if_dynamic:
|
||||
scale = torch.empty((input.shape[0], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
|
||||
(output, input, scale, scale_ub))
|
||||
else:
|
||||
scale = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
|
||||
|
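# Illustrative sketch (not part of the original file): a minimal per-token
# FP8 quantization reference, assuming the per-row scale is amax(|x|) / fp8_max.
# The real ref_dynamic_per_token_quant may differ in details (e.g. scale_ub
# clamping); the helper name here is hypothetical.
def _per_token_fp8_sketch(x: torch.Tensor):
    fp8_max = torch.finfo(FP8_DTYPE).max
    scale = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / fp8_max
    q = (x.to(torch.float32) / scale).clamp(-fp8_max, fp8_max)
    return q.to(FP8_DTYPE), scale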
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, scale_ub: bool,
|
||||
seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") + 1e-6 # avoid nans
|
||||
|
||||
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
|
||||
if scale_ub else None
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(x,
|
||||
scale_ub=scale_ub,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
torch.testing.assert_close(ref_scales, ops_scales)
|
||||
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
|
||||
opcheck_fp8_quant(ops_out,
|
||||
x,
|
||||
None,
|
||||
scale_ub,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
|
||||
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, ops_scale = ops.scaled_fp8_quant(x)
|
||||
|
||||
torch.testing.assert_close(ref_scale, ops_scale)
|
||||
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
|
||||
ops_out.to(dtype=torch.float32))
|
||||
|
||||
opcheck_fp8_quant(ops_out, x)
|
||||
|
||||
|
||||
# Regression test for a case with large activations where an int32 index cannot
|
||||
# represent the number of elements.
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
def test_fp8_quant_large(seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
num_tokens = 1024000 # Mistral-Nemo's max_position_embeddings
|
||||
hidden_size = 1152 # Smallest hidden_size to reproduce the error
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
|
||||
ref_out, scale = ref_dynamic_per_tensor_fp8_quant(x)
|
||||
ops_out, _ = ops.scaled_fp8_quant(x, scale)
|
||||
|
||||
# Minimize memory footprint in this test by freeing x and upconverting
|
||||
# the outputs in place. (torch.allclose does not support fp8)
|
||||
del x
|
||||
ref_out = ref_out.to(dtype=dtype)
|
||||
ops_out = ops_out.to(dtype=dtype)
|
||||
|
||||
torch.testing.assert_close(ref_out, ops_out)
|
||||
38
tests/kernels/quantization/test_ggml.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import gguf
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
@pytest.mark.parametrize("quant_type", [12])
|
||||
def test_ggml_opcheck(quant_type):
|
||||
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
|
||||
shape = [256, 1152]
|
||||
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
|
||||
m = qweight.shape[0]
|
||||
n = qweight.shape[1] // type_size * block_size
|
||||
opcheck(torch.ops._C.ggml_dequantize,
|
||||
(qweight, quant_type, m, n, torch.float16))
|
||||
|
||||
x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
|
||||
opcheck(torch.ops._C.ggml_mul_mat_a8,
|
||||
(qweight, x, quant_type, qweight.shape[0]))
|
||||
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
|
||||
(qweight, x, quant_type, qweight.shape[0]))
|
||||
|
||||
shape = [256, 1024, 336]
|
||||
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
|
||||
x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
|
||||
sorted_token_ids = torch.arange(776, device='cuda')
|
||||
expert_ids = torch.randint(0, 256, (194, ), device='cuda')
|
||||
num_tokens_post_padded = torch.tensor([1],
|
||||
dtype=torch.int64,
|
||||
device='cuda')
|
||||
|
||||
opcheck(torch.ops._C.ggml_moe_a8,
|
||||
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
|
||||
quant_type, qweight.shape[0], 1, x.shape[0]))
|
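# Illustrative note (not part of the original test): how the logical width n
# is recovered above. gguf.GGML_QUANT_SIZES[quant_type] gives
# (block_size, type_size) = (weights per block, bytes per block); for quant
# type 12 (Q4_K) this is believed to be (256, 144), so a packed row of 1152
# bytes holds 1152 // 144 = 8 blocks, i.e. 8 * 256 = 2048 logical weights.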
||||
198
tests/kernels/quantization/test_gguf.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from gguf import GGMLQuantizationType, GGUFReader, ReaderTensor, dequantize
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.quantization.gguf import _fused_moe_gguf
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
|
||||
GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
|
||||
|
||||
|
||||
def get_gguf_sample_tensors(
|
||||
hidden_size: int,
|
||||
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
def get_gguf_MoE_tensors(
|
||||
hidden_size: int,
|
||||
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
|
||||
sample_dir = GGUF_SAMPLE_MOE
|
||||
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
|
||||
sample_file = Path(sample_dir) / filename
|
||||
return GGUFReader(sample_file).tensors
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
# Hidden_size for testing, must match the sample file in HF repo,
|
||||
# we have `hidden_size = 256, 1024` for test in HF repo currently.
|
||||
HIDDEN_SIZES = [256, 1024]
|
||||
NUM_TOKENS = [7, 83, 128, 2048] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
QUANT_TYPES = [
|
||||
# i-matrix
|
||||
GGMLQuantizationType.IQ1_M,
|
||||
GGMLQuantizationType.IQ1_S,
|
||||
GGMLQuantizationType.IQ2_S,
|
||||
GGMLQuantizationType.IQ2_XS,
|
||||
GGMLQuantizationType.IQ3_S,
|
||||
GGMLQuantizationType.IQ3_XXS,
|
||||
GGMLQuantizationType.IQ4_NL,
|
||||
GGMLQuantizationType.IQ4_XS,
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quantization
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_dequantize(hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType):
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
for tensor in tensors:
|
||||
shape_str = tensor.name.split("_")[-1]
|
||||
shape = map(int, shape_str.split("x"))
|
||||
|
||||
ref_output = torch.tensor(dequantize(tensor.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
|
||||
quant_type, *list(shape), dtype)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
|
||||
@torch.inference_mode()
|
||||
def test_mmvq(hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
|
||||
qweight.shape[0]).to(dtype)
|
||||
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
])
|
||||
@torch.inference_mode()
|
||||
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
|
||||
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
|
||||
for tensor in tensors:
|
||||
weight = torch.tensor(dequantize(tensor.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
ref_output = x @ weight.T
|
||||
|
||||
qweight = torch.tensor(tensor.data, device="cuda")
|
||||
output = ops.ggml_mul_mat_a8(qweight, x, quant_type, qweight.shape[0])
|
||||
atols = {torch.half: 1, torch.bfloat16: 1.5, torch.float: 1.2}
|
||||
# test matrix has inputs centered around 0 and lower precision from
|
||||
# bfloat16 tends to accumulate and can greatly inflate rtol
|
||||
# since outputs are also very close to 0
|
||||
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
|
||||
torch.testing.assert_close(output,
|
||||
ref_output,
|
||||
atol=atols[dtype],
|
||||
rtol=rtols[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", [512])
|
||||
@pytest.mark.parametrize("top_k", [4, 8])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize(
|
||||
"quant_type",
|
||||
[
|
||||
# k-quants
|
||||
GGMLQuantizationType.Q2_K,
|
||||
GGMLQuantizationType.Q3_K,
|
||||
GGMLQuantizationType.Q4_K,
|
||||
GGMLQuantizationType.Q5_K,
|
||||
GGMLQuantizationType.Q6_K,
|
||||
# standard quants
|
||||
GGMLQuantizationType.Q4_0,
|
||||
GGMLQuantizationType.Q5_0,
|
||||
GGMLQuantizationType.Q8_0,
|
||||
])
|
||||
@torch.inference_mode()
|
||||
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
|
||||
quant_type: GGMLQuantizationType, top_k: int):
|
||||
current_platform.seed_everything(0)
|
||||
H, E = 1024, 256
|
||||
|
||||
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
|
||||
topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda")
|
||||
|
||||
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
|
||||
|
||||
w13 = tensors[0]
|
||||
w2 = tensors[1]
|
||||
|
||||
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
|
||||
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
|
||||
device="cuda").to(dtype)
|
||||
act = SiluAndMul()
|
||||
|
||||
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
|
||||
torch.tensor(w2.data,
|
||||
device="cuda"), topk_weights,
|
||||
topk_ids, quant_type, quant_type, act)
|
||||
|
||||
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
|
||||
topk_ids).reshape(output.shape)
|
||||
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
|
||||
31
tests/kernels/quantization/test_gptq.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops # noqa: F401
|
||||
|
||||
|
||||
def test_gptq_shuffle_opcheck():
|
||||
weight = torch.randint(-2000000,
|
||||
2000000, (1792, 4096),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
|
||||
bit = 4
|
||||
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
|
||||
|
||||
|
||||
def test_gptq_gemm_opcheck():
|
||||
a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
|
||||
weight = torch.randint(-2000000,
|
||||
2000000, (512, 6144),
|
||||
device='cuda',
|
||||
dtype=torch.int32)
|
||||
zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
|
||||
scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
|
||||
idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
|
||||
use_exllama = True
|
||||
bit = 4
|
||||
opcheck(torch.ops._C.gptq_gemm,
|
||||
(a, weight, zeros, scales, idx, use_exllama, bit))
|
||||
149
tests/kernels/quantization/test_int8_kernel.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_int8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_quant_int8)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton kernels require compute capability 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
quantization and per-column weight quantization; see the sketch after this function."""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(
|
||||
), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K, )
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
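# Illustrative sketch (not part of the original file): the identity the
# reference above relies on. With per-token scales As [M, 1] and per-column
# scales Bs [1, N], dequantization factors out of the integer matmul:
#   (As * Aq) @ (Bq * Bs) == As * (Aq @ Bq) * Bs
# The helper name below is hypothetical.
def _scale_factoring_sketch():
    aq = torch.randn(4, 16)
    bq = torch.randn(16, 3)
    a_s = torch.rand(4, 1)
    b_s = torch.rand(1, 3)
    lhs = (a_s * aq) @ (bq * b_s)
    rhs = a_s * (aq @ bq) * b_s
    torch.testing.assert_close(lhs, rhs)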
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
"""This function performs fused moe with per-column int8 quantization
|
||||
using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = per_token_quant_int8(a)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
output_dtype=a.dtype)
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = per_token_quant_int8(act_out)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
output_dtype=a.dtype)
|
||||
# Apply routing weights and sum
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_int8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
int8_max = 127
|
||||
int8_min = -128
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True, # Using int8-w8a8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.05
|
||||
192
tests/kernels/quantization/test_int8_quant.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import ref_dynamic_per_token_quant
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm._custom_ops import scaled_int8_quant
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
HIDDEN_SIZES = [16, 67, 768, 5137, 8193] # Arbitrary values for testing
|
||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
SCALE = [0.1, 2.1]
|
||||
|
||||
|
||||
def opcheck_int8_quant_static(output, input, scale, azp=None):
|
||||
if azp is None:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
else:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
|
||||
|
||||
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
scale = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
if symmetric:
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
else:
|
||||
azp = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
|
||||
# reference
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
|
||||
# kernel
|
||||
ops_out, ops_scales, _ = scaled_int8_quant(x)
|
||||
|
||||
torch.testing.assert_close(ops_scales, ref_scales)
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_dynamic(ops_out, x)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") * 1000 - 300
|
||||
|
||||
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
|
||||
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
|
||||
|
||||
# Calculate scale and azp so the per-token range maps onto the full int8 range
# (worked example after this test)
|
||||
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
|
||||
torch.int32)
|
||||
|
||||
torch_out = ((x / scales).round() + azps).clamp(
|
||||
int8_traits.min, int8_traits.max).to(torch.int8)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max(
|
||||
) <= int8_traits.max
|
||||
|
||||
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
|
||||
|
||||
if (not torch.allclose(scales_out, scales)):
|
||||
print(torch.argmax(torch.abs(scales_out - scales)))
|
||||
torch.testing.assert_close(scales_out, scales)
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
|
||||
# if AZP is off by 1, after rounding-to-even, the output may be off by 2
|
||||
torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_dynamic(ops_out, x, False)
|
||||
|
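# Illustrative note (not part of the original file): the scale/azp choice in
# the test above maps the per-token range [x_min, x_max] onto the full int8
# range. Worked example with assumed values x_min = -100.0, x_max = 155.0:
#   scale = (155 - (-100)) / 255 = 1.0
#   azp   = round(-128 - (-100) / 1.0) = -28
#   round(x_min / scale) + azp = -128   and   round(x_max / scale) + azp = 127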
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
|
||||
out1 = (x / scale_arg).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
|
||||
assert scale2 is scale_arg
|
||||
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_static(out2, x, scale_arg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("scale", SCALE)
|
||||
@pytest.mark.parametrize("azp", [-255, 54])
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int,
|
||||
scale: float, azp: int) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") * 1000 - 300
|
||||
|
||||
out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
|
||||
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
|
||||
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
|
||||
|
||||
out2, scale2, azp2 = scaled_int8_quant(x,
|
||||
scale_arg,
|
||||
azp_arg,
|
||||
symmetric=False)
|
||||
assert scale2 is scale_arg
|
||||
assert azp2 is azp_arg
|
||||
|
||||
# big atol to account for rounding errors
|
||||
torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
|
||||
|
||||
opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_max", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
|
||||
# Test that the saturating cast works correctly for values near i32 max/min
|
||||
|
||||
from numpy import inf, nextafter
|
||||
|
||||
int32_traits = torch.iinfo(torch.int32)
|
||||
val = float(int32_traits.max if is_max else int32_traits.min)
|
||||
|
||||
x_vals = [[
|
||||
nextafter(val, inf), val + 1, val, val - 1,
|
||||
nextafter(val, -inf)
|
||||
]]
|
||||
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
|
||||
|
||||
# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
|
||||
# where cast<T> is a saturating cast to type T.
|
||||
# Scale is set to 1.0 so that the input values are the ones that are cast.
|
||||
# AZP is set to 0 to make sure the int8 saturating cast is tested as well.
|
||||
scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
|
||||
azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")
|
||||
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
val_i8 = int8_traits.max if is_max else int8_traits.min
|
||||
expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
|
||||
|
||||
out, _, _ = scaled_int8_quant(x, scale, azp, symmetric=False)
|
||||
torch.testing.assert_close(expected, out, atol=0, rtol=0)
|
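# Illustrative note (not part of the original file): with scale = 1.0 and
# azp = 0, an input just above int32 max (e.g. nextafter(2**31 - 1, inf))
# saturates to int32 max in the inner cast and then to 127 in the int8 cast,
# which is why every element of `expected` above is int8 max (or int8 min for
# the is_max=False case).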
||||
407
tests/kernels/quantization/test_machete_mm.py
Normal file
@@ -0,0 +1,407 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the machete kernel.
|
||||
|
||||
Run `pytest tests/kernels/test_machete_mm.py`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
pack_rows, quantize_weights)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantization methods can have
|
||||
# multiple kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
|
||||
|
||||
MNK_SHAPES = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 1024),
|
||||
(1, 4096, 4096),
|
||||
(1, 8192, 28672),
|
||||
(13, 8192, 4096),
|
||||
(26, 4096, 8192),
|
||||
(64, 4096, 4096),
|
||||
(64, 8192, 28672),
|
||||
(257, 128, 4096),
|
||||
(257, 4224, 4160),
|
||||
(257, 4096, 4096),
|
||||
(1024, 4096, 8192),
|
||||
(1024, 8192, 4096),
|
||||
]
|
||||
|
||||
GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeConfig:
|
||||
act_type: torch.dtype
|
||||
weight_type: ScalarType
|
||||
output_type: Optional[torch.dtype]
|
||||
group_scale_type: Optional[torch.dtype]
|
||||
group_zero_type: Optional[torch.dtype]
|
||||
channel_scale_type: Optional[torch.dtype]
|
||||
token_scale_type: Optional[torch.dtype]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tensors:
|
||||
w_ref: torch.Tensor
|
||||
a_ref: torch.Tensor
|
||||
a: torch.Tensor
|
||||
w_q: torch.Tensor
|
||||
w_g_s: Optional[torch.Tensor]
|
||||
w_g_zp: Optional[torch.Tensor]
|
||||
w_ch_s: Optional[torch.Tensor]
|
||||
w_tok_s: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
|
||||
# Ch Scales Type, Tok Scales Type)
|
||||
# NOTE: None "Scale Type" means the act type is floating point
|
||||
# None "Output Type" means the output type is the same as the act type
|
||||
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
|
||||
Optional[torch.dtype], bool]
|
||||
TEST_TYPES = [
|
||||
# GPTQ style
|
||||
*(TypeConfig(act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
for a_type in [torch.float16, torch.bfloat16]),
|
||||
# AWQ style
|
||||
*(TypeConfig(act_type=a_type,
|
||||
weight_type=w_type,
|
||||
output_type=None,
|
||||
group_scale_type=a_type,
|
||||
group_zero_type=a_type,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
for w_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
for a_type in [torch.float16, torch.bfloat16]),
|
||||
# QQQ style
|
||||
*(TypeConfig(act_type=torch.int8,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
*(TypeConfig(act_type=torch.float8_e4m3fn,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=torch.float16,
|
||||
group_scale_type=group_scale_type,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=torch.float,
|
||||
token_scale_type=torch.float)
|
||||
for group_scale_type in [None, torch.float16]),
|
||||
]
|
||||
|
||||
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
|
||||
# unit tests to a common utility function. Currently the use of
|
||||
# `is_quant_method_supported` conflates kernels with quantization methods
|
||||
# an assumption which is breaking down as quantization methods can have
|
||||
# multiple kernels and some kernels support multiple quantization methods.
|
||||
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16, scale=1, offset=0):
|
||||
if dtype.is_floating_point:
|
||||
return (scale * torch.rand(shape, device="cuda") - offset).to(dtype)
|
||||
else:
|
||||
return torch.randint(-8, 7, shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
|
||||
return zps if zps is None else -1 * s * (zps.to(s.dtype))
|
||||
|
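# Illustrative note (not part of the original file): why zero points are
# converted to -s * zp above. For group-quantized weights
#   w ~= s * (q - zp) = s * q + (-s * zp)
# so the zero point can be applied as an additive term after scaling, which
# matches ref_zero_points_after_scales=True used in machete_quantize_and_pack
# below.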
||||
|
||||
def group_size_valid(shape: tuple[int, int, int],
|
||||
group_size: Optional[int]) -> bool:
|
||||
return group_size is None or group_size == -1 or group_size % shape[2] == 0
|
||||
|
||||
|
||||
def machete_quantize_and_pack(atype: torch.dtype,
|
||||
w: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
stype: Optional[torch.dtype],
|
||||
group_size: Optional[int],
|
||||
zero_points: bool = False):
|
||||
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||
|
||||
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||
w,
|
||||
wtype,
|
||||
group_size=group_size,
|
||||
zero_points=zero_points,
|
||||
# to match how the kernel applies zps
|
||||
ref_zero_points_after_scales=True)
|
||||
|
||||
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||
w_q = w_q.t().contiguous().t() # convert to col major
|
||||
|
||||
w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
|
||||
opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))
|
||||
|
||||
return w_ref, w_q_machete, w_s, w_zp
|
||||
|
||||
|
||||
def create_test_tensors(shape: tuple[int, int, int],
|
||||
types: TypeConfig,
|
||||
group_size: Optional[int],
|
||||
subset_stride_factor: Optional[int] = None) -> Tensors:
|
||||
m, n, k = shape
|
||||
factor = subset_stride_factor or 1
|
||||
|
||||
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
|
||||
group_size)
|
||||
|
||||
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
|
||||
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
|
||||
|
||||
if factor > 1:
|
||||
a = a[0:m, 0:k]
|
||||
w = w[0:k, 0:n]
|
||||
|
||||
if types.group_scale_type is not None:
|
||||
w = w.to(types.group_scale_type)
|
||||
if w.dtype.itemsize == 1:
|
||||
w = w.to(torch.float16)
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
|
||||
types.group_zero_type is not None)
|
||||
|
||||
if not a.dtype.is_floating_point:
|
||||
aiinfo = torch.iinfo(a.dtype)
|
||||
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
|
||||
|
||||
a_ref = a.to(torch.float32)
|
||||
w_ref = w_ref.to(torch.float32)
|
||||
|
||||
w_ch_s = None if types.channel_scale_type is None else\
|
||||
rand_data((n,), types.channel_scale_type)
|
||||
w_tok_s = None if types.token_scale_type is None else\
|
||||
rand_data((m,), types.token_scale_type)
|
||||
|
||||
return Tensors(w_ref=w_ref,
|
||||
a_ref=a_ref,
|
||||
a=a,
|
||||
w_q=w_q_packed,
|
||||
w_g_s=w_s,
|
||||
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
|
||||
w_ch_s=w_ch_s,
|
||||
w_tok_s=w_tok_s)
|
||||
|
||||
|
||||
# None stype means scales use the same dtype as a
|
||||
def machete_mm_test_helper(types: TypeConfig,
|
||||
tensors: Tensors,
|
||||
group_size: Optional[int] = None,
|
||||
schedule: Optional[str] = None):
|
||||
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
|
||||
output_ref_type = output_ref.dtype
|
||||
|
||||
if tensors.w_ch_s is not None:
|
||||
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
|
||||
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
|
||||
if tensors.w_tok_s is not None:
|
||||
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
|
||||
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
|
||||
|
||||
output = ops.machete_mm(
|
||||
a=tensors.a,
|
||||
b_q=tensors.w_q,
|
||||
b_type=types.weight_type,
|
||||
b_group_scales=tensors.w_g_s,
|
||||
b_group_zeros=tensors.w_g_zp,
|
||||
b_group_size=group_size,
|
||||
b_channel_scales=tensors.w_ch_s,
|
||||
a_token_scales=tensors.w_tok_s,
|
||||
out_type=types.output_type,
|
||||
schedule=schedule,
|
||||
)
|
||||
|
||||
print(output)
|
||||
print(output_ref)
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if tensors.w_g_zp is not None\
|
||||
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
|
||||
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
|
||||
torch.testing.assert_close(output,
|
||||
output_ref.to(output.dtype),
|
||||
rtol=rtol,
|
||||
atol=atol)
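    # Worked example for the tolerance above (added for clarity, not in the
    # original): with K = a.shape[1] = 4096 and no zero-points,
    # atol = min(5e-2 * sqrt(4096), 1) = min(3.2, 1) = 1, so large reduction
    # dimensions are effectively capped at an absolute tolerance of 1.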
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_all_schedules(shape, types: TypeConfig):
|
||||
|
||||
group_sizes: list[Optional[int]] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = GROUP_SIZES_TO_TEST
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
print(f"MNK = {shape}")
|
||||
for schedule in ops.machete_supported_schedules(
|
||||
types.act_type,
|
||||
types.weight_type,
|
||||
group_scales_type=types.group_scale_type,
|
||||
group_zeros_type=types.group_scale_type,
|
||||
out_type=types.output_type):
|
||||
print(f"Testing schedule {schedule}")
|
||||
machete_mm_test_helper(types, tensors, group_size, schedule)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("shape",
|
||||
MNK_SHAPES,
|
||||
ids=lambda x: "x".join(str(v) for v in x))
|
||||
@pytest.mark.parametrize("types", TEST_TYPES)
|
||||
def test_machete_heuristic(shape, types: TypeConfig):
|
||||
group_sizes: list[Optional[int]] = []
|
||||
if types.group_scale_type is None:
|
||||
group_sizes = [None]
|
||||
else:
|
||||
group_sizes = GROUP_SIZES_TO_TEST
|
||||
|
||||
for group_size in group_sizes:
|
||||
if not group_size_valid(shape, group_size):
|
||||
continue
|
||||
|
||||
tensors = create_test_tensors(shape, types, group_size)
|
||||
machete_mm_test_helper(types, tensors, group_size)
|
||||
|
||||
|
||||
# Test working on other devices
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_machete_devices(device: str):
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
|
||||
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
|
||||
|
||||
for field in fields(Tensors):
|
||||
tensor = getattr(tensors, field.name)
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
setattr(tensors, field.name, tensor.to(device))
|
||||
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test working with a subset of A and B
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_subset():
|
||||
group_size = 128
|
||||
|
||||
type_config = TypeConfig(act_type=torch.float16,
|
||||
weight_type=scalar_types.uint4b8,
|
||||
output_type=None,
|
||||
group_scale_type=torch.float16,
|
||||
group_zero_type=None,
|
||||
channel_scale_type=None,
|
||||
token_scale_type=None)
|
||||
|
||||
tensors = create_test_tensors((512, 4096, 4096),
|
||||
type_config,
|
||||
group_size,
|
||||
subset_stride_factor=2)
|
||||
machete_mm_test_helper(type_config, tensors, group_size)
|
||||
|
||||
|
||||
# Test to make sure cuda graphs work
|
||||
class MacheteLayer(torch.nn.Module):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def forward(self, a):
|
||||
return ops.machete_mm(a=a, **self.kwargs)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
|
||||
reason="Machete is not supported on this GPU type.")
|
||||
def test_machete_cuda_graph():
|
||||
m, n, k = 512, 4096, 4096
|
||||
|
||||
a = rand_data((m, k), torch.float16)
|
||||
b = rand_data((k, n), torch.float16)
|
||||
wtype = scalar_types.uint4b8
|
||||
stype = torch.float16
|
||||
group_size = 128
|
||||
zero_points = False
|
||||
|
||||
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
|
||||
a.dtype, b, wtype, stype, group_size, zero_points)
|
||||
|
||||
# Construct a trivial model with a single layer that calls a machete kernel
|
||||
model = MacheteLayer(
|
||||
b_q=w_q_packed,
|
||||
b_type=wtype,
|
||||
b_group_scales=w_s,
|
||||
b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
|
||||
b_group_size=group_size,
|
||||
)
|
||||
|
||||
output_ref = torch.matmul(a, w_ref)
|
||||
|
||||
# Run the model with a cuda graph
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
output = model(a)
|
||||
output.zero_()
|
||||
g.replay()
|
||||
|
||||
# Relax atol as our reduction dim becomes larger (more rounding error)
|
||||
# Relax atol when we have zeropoints since the way machete applies
|
||||
# zeropoints (after scales) causes noise around 0
|
||||
atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
|
||||
torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)
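    # Note (added for clarity, not in the original): zeroing the output buffer
    # and replaying the graph above checks that the machete kernel was really
    # captured into the CUDA graph and rewrites its result into the same
    # memory, rather than the result only existing from the capture pass.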
|
||||
666
tests/kernels/quantization/test_marlin_gemm.py
Normal file
@@ -0,0 +1,666 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the marlin kernel.
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.qqq import (
|
||||
MARLIN_QQQ_MAX_PARALLEL, MARLIN_QQQ_MIN_THREAD_N,
|
||||
MARLIN_QQQ_SUPPORTED_GROUP_SIZES, MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
|
||||
marlin_permute_scales, query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
pack_fp8_to_int32)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
|
||||
marlin_weights)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_qqq import ( # noqa: E501
|
||||
marlin_qqq_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
USE_ATOMIC_ADD_OPTS = [False, True]
|
||||
USE_FP32_REDUCE_OPTS = [False, True]
|
||||
|
||||
MARLIN_K_CHUNKS = [128]
|
||||
MARLIN_N_CHUNKS = [64, 256]
|
||||
|
||||
MARLIN_24_K_CHUNKS = [128]
|
||||
MARLIN_24_N_CHUNKS = [512]
|
||||
|
||||
HQQ_SUPPORTED_GROUP_SIZES = [64]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 1, 1),
|
||||
(1, 4, 8),
|
||||
(1, 7, 5),
|
||||
(13, 17, 67),
|
||||
(26, 37, 13),
|
||||
(67, 13, 11),
|
||||
(257, 13, 11),
|
||||
(658, 13, 11),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
|
||||
def compute_max_diff(output, output_ref):
|
||||
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
|
||||
torch.abs(output_ref))
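# Worked example (added for clarity, not in the original): this is a relative
# mean-absolute difference. For output = [1.0, 2.0] and output_ref = [1.1, 1.9]
# it gives mean(|diff|) / mean(|ref|) = 0.1 / 1.5 ~= 0.067, which would exceed
# the 0.04 threshold used by the GEMM tests below.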
|
||||
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
act_order, mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize (and apply act_order if provided)
|
||||
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
|
||||
# Pack to GPTQ format
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx" so that group ids are
|
||||
# increasing
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
|
||||
if act_order:
|
||||
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
|
||||
weight_perm)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_repack,
|
||||
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.gptq_marlin_repack(
|
||||
q_w_gptq,
|
||||
sort_indices,
|
||||
size_k,
|
||||
size_n,
|
||||
quant_type.size_bits,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Create input
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize
|
||||
w_ref, q_w, s, zp = quantize_weights(b_weight,
|
||||
quant_type,
|
||||
group_size,
|
||||
zero_points=True)
|
||||
|
||||
# Pack to AWQ format
|
||||
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# Pack to Marlin format
|
||||
weight_perm = get_weight_perm(quant_type.size_bits)
|
||||
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
|
||||
weight_perm)
|
||||
|
||||
opcheck(torch.ops._C.awq_marlin_repack,
|
||||
(q_w_awq, size_k, size_n, quant_type.size_bits))
|
||||
|
||||
# Run Marlin repack GPU kernel
|
||||
marlin_q_w_2 = ops.awq_marlin_repack(
|
||||
q_w_awq,
|
||||
size_k,
|
||||
size_n,
|
||||
quant_type.size_bits,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(False))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
|
||||
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
|
||||
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_gptq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
act_order,
|
||||
is_k_full,
|
||||
use_atomic_add,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size == size_k:
|
||||
return
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, act_order)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_gemm,
|
||||
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
|
||||
workspace.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1], is_k_full, False,
|
||||
use_atomic_add, use_fp32_reduce, False),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=False,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
# TODO: find better way to test this?
|
||||
@torch.compile(fullgraph=True)
|
||||
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m, size_n,
|
||||
size_k):
|
||||
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s, scratch, quant_type, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
|
||||
mnk_factors):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
|
||||
|
||||
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
|
||||
output_ref = torch.matmul(a_input, w_24_ref)
|
||||
|
||||
opcheck(torch.ops._C.gptq_marlin_24_gemm,
|
||||
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
|
||||
workspace_24.scratch, quant_type.id, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1]),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
|
||||
|
||||
output = marlin_24_gemm_tester(
|
||||
a_input,
|
||||
marlin_24_q_w_comp,
|
||||
marlin_24_meta,
|
||||
marlin_24_s,
|
||||
workspace_24.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", [8])
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_fp8_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
dtype,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
||||
|
||||
# WEIGHTS
|
||||
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(s=scales,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=-1)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.fp8_marlin_gemm,
|
||||
(a_input, marlin_qweight, marlin_scales, workspace.scratch,
|
||||
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=a_input,
|
||||
b_q_weight=marlin_qweight,
|
||||
b_scales=marlin_scales,
|
||||
workspace=workspace.scratch,
|
||||
num_bits=num_bits,
|
||||
size_m=a_input.shape[0],
|
||||
size_n=b_weight.shape[1],
|
||||
size_k=a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(a_input, b_weight)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("quant_type",
|
||||
query_marlin_supported_quant_types(True))
|
||||
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_awq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
quant_type,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b_weight, quant_type, group_size)
|
||||
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device)
|
||||
is_k_full = True
|
||||
has_zp = True
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=is_k_full,
|
||||
has_zp=has_zp,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
|
||||
def test_hqq_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
use_fp32_reduce,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
quant_type = scalar_types.uint4
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
dev = a_input.device
|
||||
|
||||
b_weight = torch.randint(0,
|
||||
10, (size_n, size_k),
|
||||
dtype=torch.uint8,
|
||||
device=dev)
|
||||
scale = rand_data((size_n, size_k // group_size))
|
||||
zero = rand_data((size_n, size_k // group_size))
|
||||
|
||||
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
|
||||
|
||||
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
|
||||
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
|
||||
4).to(dev)
|
||||
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
|
||||
group_size).to(dev)
|
||||
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
|
||||
group_size).to(dev)
|
||||
|
||||
g_idx = marlin_make_empty_g_idx(dev)
|
||||
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
marlin_w_q,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[0],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=True,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=True,
|
||||
)
|
||||
|
||||
b_flat = b_weight.reshape(-1, group_size)
|
||||
zp_flat = zero.reshape(-1, 1)
|
||||
s_flat = scale.reshape(-1, 1)
|
||||
dequant = (b_flat - zp_flat) * s_flat
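    # Note (added for clarity, not in the original): HQQ keeps floating-point
    # zero-points, so the reference dequantizes as (w_q - zp) * s, and the
    # kernel call above passes is_zp_float=True to match that convention.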
|
||||
|
||||
output_ref = torch.matmul(a_input,
|
||||
dequant.reshape(b_weight.shape).transpose(1, 0))
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("qqq"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", MARLIN_QQQ_SUPPORTED_NUM_BITS)
|
||||
@pytest.mark.parametrize("group_size", MARLIN_QQQ_SUPPORTED_GROUP_SIZES)
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
def test_marlin_qqq_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
):
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
a_input = rand_data((size_m, size_k))
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
# Quantize activations
|
||||
s_a = a_input.abs().max(dim=-1, keepdim=True)[0].div(int8_traits.max).to(
|
||||
torch.float)
|
||||
q_a = (a_input / s_a).round().clamp(int8_traits.min,
|
||||
int8_traits.max).to(torch.int8)
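    # Note (added for clarity, not in the original): this is symmetric
    # per-token int8 quantization, s_a[i] = max_j |a[i, j]| / 127, and the
    # reference path below dequantizes the activations as q_a * s_a.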
|
||||
|
||||
# Quantize weights
|
||||
w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = \
|
||||
marlin_qqq_quantize(b_weight, num_bits, group_size)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
|
||||
MARLIN_QQQ_MAX_PARALLEL)
|
||||
|
||||
opcheck(torch.ops._C.marlin_qqq_gemm,
|
||||
(q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
|
||||
marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
|
||||
b_weight.shape[1], a_input.shape[1]))
|
||||
|
||||
output = ops.marlin_qqq_gemm(
|
||||
q_a,
|
||||
marlin_qqq_q_w,
|
||||
s_a,
|
||||
marlin_qqq_s_channel,
|
||||
marlin_qqq_s_group,
|
||||
workspace.scratch,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_subset_input():
|
||||
quant_type = scalar_types.uint4b8
|
||||
group_size = 128
|
||||
|
||||
size_m, size_k, size_n = 32, 1024, 2048
|
||||
big_m = size_m * 2
|
||||
big_k = size_k * 2
|
||||
|
||||
a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
|
||||
b_weight = rand_data((size_k, size_n))
|
||||
|
||||
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
|
||||
b_weight, quant_type, group_size, False)
|
||||
|
||||
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
a_input,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_zp,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
workspace.scratch,
|
||||
quant_type,
|
||||
a_input.shape[0],
|
||||
b_weight.shape[1],
|
||||
a_input.shape[1],
|
||||
is_k_full=True,
|
||||
has_zp=False,
|
||||
use_atomic_add=False,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
)
|
||||
output_ref = torch.matmul(a_input, w_ref)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
def test_marlin_gemm_opcheck():
|
||||
size_m = 2048
|
||||
size_n = 4096
|
||||
size_k = 4096
|
||||
a = torch.rand((size_m, size_n), device='cuda', dtype=torch.float16)
|
||||
w = torch.randint(-5, 5, (256, 8192), device='cuda', dtype=torch.int32)
|
||||
s = torch.full((32, size_k), 0.125, device='cuda', dtype=torch.float16)
|
||||
wk = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL).scratch
|
||||
x = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
||||
y = torch.ops._C.marlin_gemm(a, w, s, wk, size_m, size_n, size_k)
|
||||
torch.testing.assert_close(x, y)
|
||||
opcheck(torch.ops._C.marlin_gemm, (a, w, s, wk, size_m, size_n, size_k))
|
||||
149
tests/kernels/quantization/test_nvfp4_quant.py
Normal file
@@ -0,0 +1,149 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="NVFP4 requires compute capability 10.0 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
|
||||
PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
|
||||
(90, 128), (150, 128), (150, 48), (90, 80)]
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
# E2M1 to float
|
||||
# 0111 -> 6
|
||||
# 0110 -> 4
|
||||
# 0101 -> 3
|
||||
# 0100 -> 2
|
||||
# 0011 -> 1.5
|
||||
# 0010 -> 1
|
||||
# 0001 -> 0.5
|
||||
# 0000 -> 0
|
||||
E2M1_TO_FLOAT32 = [
|
||||
0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
|
||||
]
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
|
||||
def cast_from_fp4(x, m, n):
|
||||
# The fp4 values are packed in uint8 as [v_1st | v_2nd]
|
||||
v_2nd = x & 0xF
|
||||
v_1st = (x >> 4) & 0xF
|
||||
c = torch.stack((v_2nd, v_1st), dim=-1)
|
||||
out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
|
||||
out = out.reshape(m, n).to(torch.float32)
|
||||
return out
|
||||
|
||||
|
||||
def cast_to_fp4(x):
|
||||
sign = torch.sign(x)
|
||||
x = torch.abs(x)
|
||||
x[(x >= 0.0) & (x <= 0.25)] = 0.0
|
||||
x[(x > 0.25) & (x < 0.75)] = 0.5
|
||||
x[(x >= 0.75) & (x <= 1.25)] = 1.0
|
||||
x[(x > 1.25) & (x < 1.75)] = 1.5
|
||||
x[(x >= 1.75) & (x <= 2.5)] = 2.0
|
||||
x[(x > 2.5) & (x < 3.5)] = 3.0
|
||||
x[(x >= 3.5) & (x <= 5.0)] = 4.0
|
||||
x[x > 5.0] = 6.0
|
||||
return x * sign
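# Note (added for clarity, not in the original): the thresholds above snap a
# value onto the E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6} with round-to-nearest,
# resolving ties toward the grid point with a zero mantissa bit
# (round-to-nearest-even). For example, cast_to_fp4 maps 1.25 -> 1.0,
# 1.3 -> 1.5 and 5.7 -> 6.0.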
|
||||
|
||||
|
||||
def get_reciprocal(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
|
||||
elif isinstance(x, (float, int)):
|
||||
return 0.0 if x == 0 else 1.0 / x
|
||||
else:
|
||||
raise TypeError("Input must be a float, int, or a torch.Tensor.")
|
||||
|
||||
|
||||
def ref_nvfp4_quant(x, global_scale):
|
||||
assert global_scale.dtype == torch.float32
|
||||
assert x.ndim == 2
|
||||
m, n = x.shape
|
||||
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
|
||||
vec_max = torch.max(torch.abs(x), dim=-1,
|
||||
keepdim=True)[0].to(torch.float32)
|
||||
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
|
||||
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
|
||||
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
|
||||
|
||||
scaled_x = x.to(torch.float32) * output_scale
|
||||
clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
|
||||
return cast_to_fp4(clipped_x), scale.squeeze(-1)
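# Summary (added for clarity, not in the original): NVFP4 uses two-level
# scaling. Each 16-element block gets an FP8-E4M3 scale derived from its
# absolute maximum,
#   block_scale = global_scale * block_absmax / FLOAT4_E2M1_MAX,
# the block is divided by block_scale / global_scale before casting to E2M1,
# and dequantization later multiplies by block_scale / global_scale again.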
|
||||
|
||||
|
||||
def recover_swizzled_scales(scale, m, n):
|
||||
round_up = lambda x, y: (x + y - 1) // y * y
|
||||
rounded_m = round_up(m, 128)
|
||||
scale_n = n // BLOCK_SIZE
|
||||
rounded_n = round_up(scale_n, 4)
|
||||
# Recover the swizzled scaling factor to linear layout
|
||||
tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
|
||||
return result[:m, :scale_n]
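# Note (added for clarity, not in the original): the kernel returns block
# scales in a swizzled layout of 128-row by 4-column tiles interleaved as
# (32, 4, 4); the reshape + permute above undoes that interleaving and the
# final slice drops rows/columns that only exist because of padding.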
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
m, n = shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
|
||||
|
||||
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
|
||||
scale_ans = recover_swizzled_scales(out_scale, m, n)
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
|
||||
@torch.inference_mode()
|
||||
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
|
||||
dtype = torch.float16
|
||||
current_platform.seed_everything(42)
|
||||
torch.set_default_device('cuda:0')
|
||||
|
||||
m, n = pad_shape
|
||||
|
||||
x = torch.randn((m, n), dtype=dtype)
|
||||
|
||||
tensor_amax = torch.abs(x).max().to(torch.float32)
|
||||
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
|
||||
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
|
||||
|
||||
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
|
||||
|
||||
scale_ans = recover_swizzled_scales(out_scale, m, n)
|
||||
out_ans = cast_from_fp4(out, m, n)
|
||||
|
||||
torch.testing.assert_close(out_ans, out_ref)
|
||||
torch.testing.assert_close(scale_ans, scale_ref)
|
||||
150
tests/kernels/quantization/test_nvfp4_scaled_mm.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="NVFP4 requires compute capability 10.0 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
# m, n, k
|
||||
SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
|
||||
PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
|
||||
SHAPES.extend(PAD_SHAPES)
|
||||
|
||||
SEEDS = [42]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
kE2M1ToFloatArray = [
|
||||
0.,
|
||||
0.5,
|
||||
1.,
|
||||
1.5,
|
||||
2.,
|
||||
3.,
|
||||
4.,
|
||||
6.,
|
||||
]
|
||||
|
||||
|
||||
def e2m1_to_fp32(int4_value):
|
||||
signBit = (int4_value & 0x8)
|
||||
int4_absValue = int4_value & 0x7
|
||||
float_result = kE2M1ToFloatArray[int4_absValue]
|
||||
if (signBit):
|
||||
float_result = -float_result
|
||||
return float_result
|
||||
|
||||
|
||||
def break_fp4_bytes(a, dtype):
|
||||
assert (a.dtype == torch.uint8)
|
||||
m, n = a.shape
|
||||
a = a.flatten()
|
||||
# Get upper 4 bits
|
||||
highHalfByte = (a & 0xF0) >> 4
|
||||
# Get lower 4 bits
|
||||
lowHalfByte = a & 0x0F
|
||||
fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
|
||||
fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
|
||||
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
|
||||
out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
|
||||
return out
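# Worked example (added for clarity, not in the original): for a packed byte
# 0x2C the low nibble 0xC decodes to -2.0 (sign bit set, magnitude index 4)
# and the high nibble 0x2 decodes to 1.0, so break_fp4_bytes expands the
# (1, 1) uint8 tensor [[0x2C]] into the (1, 2) float tensor [[-2.0, 1.0]].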
|
||||
|
||||
|
||||
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
|
||||
sf_m, sf_k = a_sf_swizzled.shape
|
||||
m_tiles = (m + 128 - 1) // 128
|
||||
f = block_size * 4
|
||||
k_tiles = (k + f - 1) // f
|
||||
tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
|
||||
tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
|
||||
out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
|
||||
return out[0:m, 0:k]
|
||||
|
||||
|
||||
def dequantize_to_dtype(tensor_fp4,
|
||||
tensor_sf,
|
||||
global_scale,
|
||||
dtype,
|
||||
device,
|
||||
block_size=16):
|
||||
"""Dequantize the fp4 tensor back to high precision."""
|
||||
# Two fp4 values are packed into one uint8.
|
||||
assert tensor_fp4.dtype == torch.uint8
|
||||
m, packed_k = tensor_fp4.shape
|
||||
k = packed_k * 2
|
||||
tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
|
||||
tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
|
||||
tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
|
||||
tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
|
||||
tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
|
||||
|
||||
# scale the tensor
|
||||
out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
|
||||
return out
|
||||
|
||||
|
||||
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
|
||||
m, n, dtype, block_size, device):
|
||||
_, m_k = a_fp4.shape
|
||||
_, n_k = b_fp4.shape
|
||||
assert (m_k == n_k)
|
||||
a_in_dtype = dequantize_to_dtype(a_fp4,
|
||||
a_sf,
|
||||
a_global_scale,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
block_size=block_size)
|
||||
b_in_dtype = dequantize_to_dtype(b_fp4,
|
||||
b_sf,
|
||||
b_global_scale,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
block_size=block_size)
|
||||
return torch.matmul(a_in_dtype, b_in_dtype.t())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("shape", SHAPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_nvfp4_gemm(
|
||||
dtype: torch.dtype,
|
||||
shape: tuple[int, int, int],
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
m, n, packed_k = shape
|
||||
k = packed_k * 2
|
||||
block_size = 16
|
||||
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
|
||||
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
|
||||
|
||||
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
|
||||
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
|
||||
alpha = 1. / (a_global_scale * b_global_scale)
|
||||
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
|
||||
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
|
||||
|
||||
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
|
||||
b_scale_interleaved, a_global_scale,
|
||||
b_global_scale, m, n, dtype, block_size,
|
||||
device)
|
||||
out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
|
||||
b_scale_interleaved, alpha, dtype)
|
||||
|
||||
torch.testing.assert_close(out,
|
||||
expected_out.to(dtype=dtype),
|
||||
atol=1e-1,
|
||||
rtol=1e-1)
|
||||
121
tests/kernels/quantization/test_triton_scaled_mm.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the triton_scaled_mm kernel
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
|
||||
"""
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def scaled_mm_torch(a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: type[torch.dtype],
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
|
||||
out = scale_a * out
|
||||
out = scale_b.T * out
|
||||
out = out.to(out_dtype)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
|
||||
return out
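# Note (added for clarity, not in the original): scale_a is (M, 1) or (1, 1)
# and broadcasts across rows, while scale_b is (N, 1) so scale_b.T broadcasts
# across columns; the reference therefore computes
#   out[i, j] = scale_a[i] * scale_b[j] * sum_k(a[i, k] * b[k, j]) (+ bias[j]).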
|
||||
|
||||
|
||||
def get_8bit_types():
|
||||
types = [torch.int8]
|
||||
if current_platform.supports_fp8():
|
||||
types.append(current_platform.fp8_dtype())
|
||||
return types
|
||||
|
||||
|
||||
# This test is to check regressions for int8 support on ROCm.
|
||||
@pytest.mark.parametrize("model_path", [
|
||||
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
|
||||
])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
@pytest.mark.skipif(not current_platform.is_rocm(),
|
||||
reason="Should only run on ROCm")
|
||||
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
|
||||
max_tokens, num_logprobs):
|
||||
dtype = "bfloat16"
|
||||
|
||||
with vllm_runner(model_path, dtype=dtype) as vllm_model:
|
||||
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
|
||||
num_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [1, 33, 64, 512])
|
||||
@pytest.mark.parametrize("N", [256, 971, 20486])
|
||||
@pytest.mark.parametrize("K", [128, 496, 1024])
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("in_dtype", get_8bit_types())
|
||||
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
|
||||
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
|
||||
@pytest.mark.parametrize("use_bias", [True, False])
|
||||
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
|
||||
use_scalar_scale_b, use_bias):
|
||||
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
|
||||
).is_floating_point()
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
# NOTE: There are cases, where if the matrix is large enough, an output
|
||||
# like 65504.4 can be produced, and can easily turn into inf when
|
||||
# multiplied when using float16/bfloat16. This means one function, e.g.,
|
||||
# testing function, and another function, e.g. golden function, can
|
||||
# produce a non-inf value while the other produces an inf value, and
|
||||
# will cause assert_close/allclose to fail, even though if overflow
|
||||
# wouldn't have occurred, the values would have been "close."
|
||||
#
|
||||
# So, the values here are kept small enough to avoid this situation.
|
||||
if is_floating_point_type(in_dtype):
|
||||
a = (0.25 * torch.rand(
|
||||
(M, K), dtype=torch.float32, device=device)).to(in_dtype)
|
||||
b = (0.25 * torch.rand(
|
||||
(K, N), dtype=torch.float32, device=device)).to(in_dtype)
|
||||
else:
|
||||
a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
|
||||
b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
|
||||
|
||||
if use_scalar_scale_a:
|
||||
scale_a = torch.rand((1, 1), device=device)
|
||||
else:
|
||||
scale_a = 0.25 * torch.rand((M, 1), device=device)
|
||||
|
||||
if use_scalar_scale_b:
|
||||
scale_b = torch.rand((1, 1), device=device)
|
||||
else:
|
||||
scale_b = 0.25 * torch.rand((N, 1), device=device)
|
||||
|
||||
bias = None
|
||||
if use_bias:
|
||||
bias = torch.rand((N, ), device=device, dtype=out_dtype)
|
||||
|
||||
triton_scaled_mm_module = importlib.import_module(
|
||||
"vllm.model_executor.layers.quantization.compressed_tensors."
|
||||
"triton_scaled_mm")
|
||||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
|
||||
|
||||
c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
|
||||
|
||||
a_cpu = a.cpu()
|
||||
b_cpu = b.cpu()
|
||||
scale_a_cpu = scale_a.cpu()
|
||||
scale_b_cpu = scale_b.cpu()
|
||||
bias_cpu = None if bias is None else bias.cpu()
|
||||
|
||||
c_actual = scaled_mm_torch(a_cpu, b_cpu, scale_a_cpu, scale_b_cpu,
|
||||
out_dtype, bias_cpu)
|
||||
|
||||
c_check_cpu = c_check.cpu()
|
||||
torch.testing.assert_close(c_check_cpu, c_actual, rtol=1e-1, atol=1e-1)
|
||||