Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -8,8 +8,9 @@ from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
dtype=torch.float32)
kE2M1ToFloat = torch.tensor(
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
@@ -22,12 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
return out[0:m, 0:k]
def dequantize_nvfp4_to_dtype(tensor_fp4,
tensor_sf,
global_scale,
dtype,
device,
block_size=16):
def dequantize_nvfp4_to_dtype(
tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
@@ -69,7 +67,8 @@ def break_fp4_bytes(a, dtype):
def quant_nvfp4_tensor(a: torch.Tensor):
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(a).max().to(torch.float32))
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
torch.float32
)
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale

View File

@@ -6,24 +6,25 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
quantize_weights)
ALLSPARK_AMPERE_K_ALIGN,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
ALLSPARK_AMPERE_N_ALIGN,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
def is_gptq_allspark_supported(min_capability: int,
max_capability: int) -> bool:
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
if not current_platform.is_cuda():
return False
capability = current_platform.get_device_capability()
assert capability is not None
return capability.to_int() >= min_capability \
and capability.to_int() <= max_capability
return (
capability.to_int() >= min_capability and capability.to_int() <= max_capability
)
MNK_FACTORS = [
@@ -43,7 +44,8 @@ HAS_ZP_OPTS = [False, True]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
torch.abs(output_ref)
)
def rand_data(shape, dtype=torch.float16):
@@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16):
@pytest.mark.skipif(
not is_gptq_allspark_supported(80, 89),
reason="AllSpark Ampere kernel is not supported on this GPU type.")
reason="AllSpark Ampere kernel is not supported on this GPU type.",
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
weight = rand_data((k, n), dtype=dtype)
# Quantize (and apply act_order if provided)
w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
group_size, has_zp)
w_ref, qw, s, zp = quantize_weights(
weight, scalar_types.uint8b128, group_size, has_zp
)
qw = qw.to(torch.uint8)
if has_zp:
@@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
n_32align = (n + 32 - 1) // 32 * 32
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
qw, s, zp, has_zp)
opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
n_32align))
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
opcheck(
torch.ops._C.rearrange_kn_weight_as_n32k16_order,
(qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align),
)
opcheck(torch.ops._C.allspark_w8a16_gemm,
(input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
n, group_size, sm_count, sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp, True)
opcheck(
torch.ops._C.allspark_w8a16_gemm,
(
input,
qw_reorder,
s_reorder,
zp_reorder,
n,
group_size,
sm_count,
sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp,
True,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.allspark_w8a16_gemm(
input,
qw_reorder,
s_reorder,
zp_reorder,
n,
group_size,
sm_count,
sm_version,
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp,
True,
)
output_ref = torch.matmul(input, w_ref)
torch.cuda.synchronize()

View File

@@ -8,40 +8,42 @@ from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
not hasattr(torch.ops._C, "awq_dequantize"),
reason="AWQ is not supported on this GPU type.",
)
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_TRITON_AWQ", "0")
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32)
split_k_iters = 0
thx = 0
thy = 0
opcheck(torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy))
opcheck(
torch.ops._C.awq_dequantize,
(qweight, scales, zeros, split_k_iters, thx, thy),
)
@pytest.mark.skip(reason="Not working; needs investigation.")
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
not hasattr(torch.ops._C, "awq_gemm"),
reason="AWQ is not supported on this GPU type.",
)
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_TRITON_AWQ", "0")
input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
qweight = torch.randint(-2000000000,
2000000000, (8192, 256),
device='cuda',
dtype=torch.int32)
scales = torch.randint(-2000000000,
2000000000, (64, 256),
device='cuda',
dtype=torch.int32)
qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
input = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
qweight = torch.randint(
-2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
)
scales = torch.randint(
-2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
)
qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
split_k_iters = 8
opcheck(torch.ops._C.awq_gemm,
(input, qweight, qzeros, scales, split_k_iters))
opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))

View File

@@ -4,11 +4,15 @@
Run `pytest tests/kernels/quantization/test_awq_triton.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.quantization.awq_triton import (
AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
awq_dequantize_triton,
awq_gemm_triton,
)
from vllm.platforms import current_platform
device = "cuda"
@@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor):
# qweights - [R , C // 8], int32
# scales - [R // G, C ], float16
# zeros - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor,
group_size: int) -> torch.Tensor:
def awq_dequantize_torch(
qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
if group_size == -1:
group_size = qweight.shape[0]
bits = 4
shifts = torch.arange(0, 32, bits, device=qzeros.device)
iweights = torch.bitwise_right_shift(qweight[:, :, None],
shifts[None, None, :]).to(torch.int8)
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8
)
iweights = iweights.view(iweights.shape[0], -1)
zeros = torch.bitwise_right_shift(qzeros[:, :, None],
shifts[None, None, :]).to(torch.int8)
zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
torch.int8
)
zeros = zeros.view(qzeros.shape[0], -1)
zeros = reverse_awq_order(zeros)
@@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):
if group_size == -1:
group_size = qweight_rows
@@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
current_platform.seed_everything(0)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device)
scales = torch.rand(scales_rows,
scales_cols,
dtype=scales_dtype,
device=device)
zeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device)
qweight = torch.randint(
0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
dtype=qweight_dtype,
device=device,
)
scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
zeros = torch.randint(
0,
torch.iinfo(torch.int32).max,
(zeros_rows, zeros_cols),
dtype=zeros_dtype,
device=device,
)
iweights_triton = awq_dequantize_triton(qweight, scales, zeros)
assert (not torch.any(torch.isinf(iweights_triton))
and not torch.any(torch.isnan(iweights_triton)))
assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
torch.isnan(iweights_triton)
)
iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)
@@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):
if group_size == -1:
group_size = K
@@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size):
current_platform.seed_everything(0)
input = torch.rand((input_rows, input_cols),
dtype=input_dtype,
device=device)
qweight = torch.randint(0,
torch.iinfo(torch.int32).max,
(qweight_rows, qweight_cols),
device=device)
qzeros = torch.randint(0,
torch.iinfo(torch.int32).max,
(qzeros_rows, qzeros_cols),
device=device)
scales = torch.rand((scales_rows, scales_cols),
dtype=scales_dtype,
device=device)
input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
qweight = torch.randint(
0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
)
qzeros = torch.randint(
0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
)
scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)
output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
split_k_iters)
output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
assert (not torch.any(torch.isinf(output_triton))
and not torch.any(torch.isnan(output_triton)))
assert not torch.any(torch.isinf(output_triton)) and not torch.any(
torch.isnan(output_triton)
)
dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)
output_torch = torch.matmul(input, dequantized_weights)
assert (not torch.any(torch.isinf(output_torch))
and not torch.any(torch.isnan(output_torch)))
assert not torch.any(torch.isinf(output_torch)) and not torch.any(
torch.isnan(output_torch)
)
torch.testing.assert_close(output_triton.cpu(),
output_torch.cpu(),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(
output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
)

View File

@@ -7,20 +7,26 @@ import itertools
import pytest
import torch
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
cutlass_scaled_mm,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8)
from vllm.utils.deep_gemm import (
fp8_gemm_nt,
get_col_major_tma_aligned_tensor,
per_block_cast_to_fp8,
)
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
allow_module_level=True)
pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -51,7 +57,8 @@ def setup_cuda():
@pytest.mark.parametrize(
"num_tokens,d,dtype,group_size,seed",
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
torch.manual_seed(seed)
@@ -60,15 +67,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
out, scale = per_token_group_quant_fp8(x, group_size)
assert torch.allclose(out.to(torch.float32),
ref_out.to(torch.float32),
rtol=0.15)
assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
assert torch.allclose(scale, ref_scale)
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
@@ -89,14 +95,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@@ -127,32 +131,32 @@ def test_w8a8_block_fp8_cutlass_matmul():
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
# Hopper requires row-major format for scales
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(
90) else Bs
Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs
A_fp8, As = per_token_group_quant_fp8(A_fp32,
block_size[1],
column_major_scales=False)
A_fp8, As = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=False
)
# CUTLASS uses column-major format for scales
A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
A_fp32, block_size[1], column_major_scales=True)
A_fp32, block_size[1], column_major_scales=True
)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass,
block_size, out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = cutlass_scaled_mm(
A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not has_deep_gemm(),
reason="DeepGemm kernels not available.")
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
# only aligned sizes
@@ -172,20 +176,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
As = As_fp8.to(torch.float32)
Bs = Bs_fp8.to(torch.float32)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
# Transpose earlier so that the testing will not trigger transposing kernels
As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)
out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
out = torch.zeros((M, N), device="cuda", dtype=out_dtype)
assert As_fp8.shape == (M, (K + 127) //
128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
assert As_fp8.shape == (M, (K + 127) // 128), (
f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
)
fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001

View File

@@ -10,12 +10,12 @@ import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul)
w8a8_block_int8_matmul,
)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -36,8 +36,10 @@ def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
torch.manual_seed(seed)
@@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.001

View File

@@ -11,12 +11,11 @@ import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
sparse_cutlass_supported,
)
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
@@ -40,9 +39,7 @@ def prune_to_2_4(tensor):
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1,
index=indices,
src=torch.ones_like(indices, dtype=mask.dtype))
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
@@ -55,32 +52,31 @@ def prune_to_2_4(tensor):
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor):
def check_compress_decompress_invariance(
dtype: torch.dtype,
b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor,
):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(eye,
b_compressed,
b_metadata,
eye_scale,
eye_scale,
out_dtype=out_dtype)
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
)
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda')
b = torch.randn((n, k), device='cuda').t()
a = torch.randn((m, k), device="cuda")
b = torch.randn((n, k), device="cuda").t()
if dtype == torch.int8:
# ensure A and B aren't all zeros after rounding
@@ -107,32 +103,25 @@ def make_rand_sparse_tensors(
return b_compressed, e, a, b
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512
# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
big_m, n, k)
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
@@ -161,105 +150,87 @@ MNK_FACTORS = [
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
use_bias: bool):
def test_cutlass_sparse_gemm(
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=dtype,
bias=bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
def test_cutlass_sparse_int8_gemm(
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand(
(n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(a,
b_comp,
e,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=bias)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
"""
import random
import pytest
@@ -36,9 +37,7 @@ MNK_FACTORS = [
(512, 24576, 128),
]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
@@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape):
def scale_shape(shape, group_shape):
assert len(shape) == len(group_shape)
group_shape = group_scale_helper(shape, group_shape)
return tuple(
cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
def cutlass_fp8_gemm_helper(m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
def cutlass_fp8_gemm_helper(
m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda",
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a = to_fp8(torch.randn((m, k), device=device))
@@ -80,8 +80,8 @@ def cutlass_fp8_gemm_helper(m: int,
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
# make scales M-major for blockwise quant, doesn't affect 1D scales
scale_a = scale_a.t().contiguous().t()
@@ -89,7 +89,7 @@ def cutlass_fp8_gemm_helper(m: int,
scale_b = scale_b.t().contiguous().t()
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
@@ -98,18 +98,19 @@ def cutlass_fp8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
def cutlass_int8_gemm_helper(m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda"):
def cutlass_int8_gemm_helper(
m: int,
n: int,
k: int,
a_scale_group_shape: tuple,
b_scale_group_shape: tuple,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda",
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a = to_int8(torch.randn((m, k), device=device) * 5)
@@ -118,11 +119,11 @@ def cutlass_int8_gemm_helper(m: int,
a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
b_scales_shape = scale_shape(b.shape, b_scale_group_shape)
scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
if use_bias:
bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
else:
bias = None
@@ -131,145 +132,192 @@ def cutlass_int8_gemm_helper(m: int,
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
opcheck(torch.ops._C.cutlass_scaled_mm,
(out, a, b, scale_a, scale_b, bias))
opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize(
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
b_scale_group_shape, use_bias: bool):
cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
use_bias)
def test_cutlass_int8_gemm(
m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
cutlass_int8_gemm_helper(
m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_int8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
def test_cutlass_int8_gemm_output_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_int8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_output_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
[((1, 128), (128, 128))])
@pytest.mark.parametrize(
"a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool):
cutlass_fp8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype)
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
a_scale_group_shape,
b_scale_group_shape,
out_dtype: type[torch.dtype],
use_bias: bool,
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=out_dtype,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
b_scale_group_shape, use_bias, torch.bfloat16,
device)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_devices(
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
cutlass_fp8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
torch.bfloat16,
device,
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
use_bias: bool, device: str):
cutlass_int8_gemm_helper(512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device)
def test_cutlass_int8_gemm_devices(
a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
cutlass_int8_gemm_helper(
512,
512,
512,
a_scale_group_shape,
b_scale_group_shape,
use_bias,
out_dtype=torch.bfloat16,
device=device,
)
# For the following two tests:
@@ -277,32 +325,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
cutlass_fp8_gemm_helper(
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("a_scale_group_shape",
[PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
[PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
"a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
"b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
use_bias: bool):
def test_cutlass_int8_gemm_m_sweep(
a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
for nk in range(32, 128, 32):
for m in range(1, 128):
cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
b_scale_group_shape, use_bias)
cutlass_int8_gemm_helper(
m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
)
@pytest.mark.parametrize("m", [32, 64, 128])
@@ -310,8 +368,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
out_dtype: torch.dtype):
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
@@ -328,7 +385,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
b_dq = scale_b * bq_f32
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
@@ -340,18 +397,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
assert azp_bias.shape == (1, n)
assert azp_bias[0, :].shape == (n, )
assert azp_bias[0, :].shape == (n,)
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
dtype=out_dtype, device='cuda')
baseline_q = (
scale_a.to(device="cpu")
* scale_b.to(device="cpu")
* ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
).to(dtype=out_dtype, device="cuda")
out = ops.cutlass_scaled_mm(aq_i8,
bq_i8,
scale_a,
scale_b,
out_dtype=out_dtype,
bias=azp_bias[0, :])
out = ops.cutlass_scaled_mm(
aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
)
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
@@ -362,8 +418,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
use_bias: bool, azp_per_token: bool):
def test_cutlass_int8_azp(
m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
):
m_azp = m if azp_per_token else 1
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
@@ -377,16 +434,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
bq_f32 = bq_i8.to(dtype=torch.float32)
b_dq = scale_b * bq_f32
azp_a = torch.rand(
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
torch.testing.assert_close(a_dq,
scale_a * aq_f32 - azp_a,
rtol=1e-4,
atol=1e-3)
torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -396,8 +449,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
# int32 mm not supported on CUDA
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
# Hadamard is just the sum of the cols
@@ -406,14 +459,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
func_bias = bias if use_bias else None
if azp_per_token:
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_adj_i32, azp_i32,
func_bias)
out = ops.cutlass_scaled_mm_azp(
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
)
else:
azp_with_adj_i32 = azp_i32 * azp_adj_i32
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
out_dtype, azp_with_adj_i32, None,
func_bias)
out = ops.cutlass_scaled_mm_azp(
aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
@@ -423,13 +476,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
if azp_per_token:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
func_bias))
opcheck(
torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
)
else:
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
func_bias))
opcheck(
torch.ops._C.cutlass_scaled_mm_azp,
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
)
# Test working with a subset of A and B
@@ -445,23 +500,14 @@ def test_cutlass_subset():
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a,
b,
scale_a,
scale_b,
out_dtype=torch.bfloat16)
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
def __init__(self, b, scale_a, scale_b, out_dtype):
super().__init__()
self.b = b
@@ -470,8 +516,9 @@ class CutlassLayer(torch.nn.Module):
self.out_dtype = out_dtype
def forward(self, a):
return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
self.out_dtype)
return ops.cutlass_scaled_mm(
a, self.b, self.scale_a, self.scale_b, self.out_dtype
)
@pytest.mark.parametrize("per_act_token", [True, False])
@@ -485,10 +532,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
m_a_scales = m if per_act_token else 1
n_b_scales = n if per_out_ch else 1
scale_a = (torch.randn(
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
scale_b = (torch.randn(
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
@@ -502,13 +547,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
out.zero_()
g.replay()
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
baseline = torch.mm(
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
).to(torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
def test_cutlass_support_opcheck():
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
@pytest.mark.parametrize("num_experts", [8, 64])
@@ -517,11 +563,13 @@ def test_cutlass_support_opcheck():
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
reason="Grouped gemm is not supported on this GPU type.")
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
per_out_ch: bool, use_bias: bool):
current_platform.get_device_capability()
),
reason="Grouped gemm is not supported on this GPU type.",
)
def test_cutlass_fp8_group_gemm(
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
# Device and dtype setup
device = "cuda"
out_dtype = torch.half
@@ -533,13 +581,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
b_scales_tensors = []
baseline_tensors = []
expert_offsets = torch.zeros((num_experts + 1),
device=device,
dtype=torch.int64)
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
problem_sizes = torch.zeros((num_experts, 3),
device=device,
dtype=torch.int32)
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
if not per_act_token:
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
@@ -566,75 +610,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
b_tensors.append(b_g)
# Set up A/B scales
scale_b = torch.randn((1, n_b_scales),
device=device,
dtype=torch.float32)
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
b_scales_tensors.append(scale_b)
if per_act_token:
scale_a = torch.randn((m_a_scales, 1),
device=device,
dtype=torch.float32)
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
a_scales_tensors.append(scale_a)
else:
scale_a = one_scale_a
# Compute baseline result for this group
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
None)
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
baseline_tensors.append(baseline_g)
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
device=device,
dtype=torch.float8_e4m3fn)
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
device=device,
dtype=torch.float8_e4m3fn)
a_tensors_stacked = torch.empty(
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
)
b_tensors_stacked = torch.empty(
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
)
for g in range(num_experts):
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
1]] = a_tensors[g]
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
b_tensors_stacked[g] = b_tensors[g].t()
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
if per_act_token:
a_scales_tensors_stacked = torch.empty(
(expert_offsets[num_experts], 1),
device=device,
dtype=torch.float32)
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
)
for g in range(num_experts):
a_scales_tensors_stacked[
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
a_scales_tensors[g]
)
else:
a_scales_tensors_stacked = one_scale_a
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
device=device,
dtype=torch.float32)
b_scales_tensors_stacked = torch.empty(
(num_experts, n_b_scales), device=device, dtype=torch.float32
)
for g in range(num_experts):
b_scales_tensors_stacked[g] = b_scales_tensors[g]
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
device=device,
dtype=out_dtype)
out_tensors_stacked = torch.zeros(
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
)
ab_strides = torch.full((num_experts, ),
a_tensors_stacked.stride(0),
device="cuda",
dtype=torch.int64)
c_strides = torch.full((num_experts, ),
out_tensors_stacked.stride(0),
device="cuda",
dtype=torch.int64)
ab_strides = torch.full(
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
)
c_strides = torch.full(
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
)
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
b_tensors_stacked, a_scales_tensors_stacked,
b_scales_tensors_stacked, expert_offsets[:-1],
problem_sizes, ab_strides, ab_strides, c_strides,
per_act_token, per_out_ch)
ops.cutlass_moe_mm(
out_tensors_stacked,
a_tensors_stacked,
b_tensors_stacked,
a_scales_tensors_stacked,
b_scales_tensors_stacked,
expert_offsets[:-1],
problem_sizes,
ab_strides,
ab_strides,
c_strides,
per_act_token,
per_out_ch,
)
# Validate each group's result against the baseline
for g in range(num_experts):
baseline = baseline_tensors[g]
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)

View File

@@ -13,7 +13,9 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
@@ -24,16 +26,33 @@ from vllm.scalar_type import ScalarType, scalar_types
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9
MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
(13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
(64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
(1024, 4096, 8192), (1024, 8192, 4096)]
MNK_SHAPES = [
(1, 128, 128),
(1, 512, 1024),
(1, 4096, 4096),
(1, 8192, 28672),
(13, 8192, 4096),
(26, 4096, 8192),
(64, 4096, 4096),
(64, 8192, 28672),
(257, 128, 4096),
(257, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
]
# TODO(czhu): get supported schedules from fn
SCHEDULES = [
'128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
'128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
'128x256_1x1x1', '128x256_2x1x1'
"128x16_1x1x1",
"256x16_1x1x1",
"128x32_1x1x1",
"256x32_1x1x1",
"128x64_1x1x1",
"256x64_1x1x1",
"128x128_1x1x1",
"256x128_1x1x1",
"128x256_1x1x1",
"128x256_2x1x1",
]
@@ -60,19 +79,23 @@ class Tensors:
# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
*(
TypeConfig(act_type=torch.float8_e4m3fn,
weight_type=w_type,
output_type=o_type,
group_scale_type=torch.float8_e4m3fn,
channel_scale_type=torch.float32,
token_scale_type=torch.float32)
TypeConfig(
act_type=torch.float8_e4m3fn,
weight_type=w_type,
output_type=o_type,
group_scale_type=torch.float8_e4m3fn,
channel_scale_type=torch.float32,
token_scale_type=torch.float32,
)
for w_type in [scalar_types.int4]
# TODO(czhu): fp16 out type
for o_type in [torch.bfloat16]),
for o_type in [torch.bfloat16]
),
]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
@@ -86,26 +109,28 @@ IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
finfo = torch.finfo(torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min,
max=finfo.max).to(dtype=torch.float8_e4m3fn)
return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
def cutlass_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
def cutlass_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(w,
wtype,
group_size=group_size,
zero_points=zero_points)
w_ref, w_q, w_s, w_zp = quantize_weights(
w, wtype, group_size=group_size, zero_points=zero_points
)
# since scales are cast to fp8, we need to compute w_ref this way
w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to(
torch.float32).repeat_interleave(group_size, dim=0)).to(atype)
w_ref = (
(w_q).to(torch.float32)
* w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
).to(atype)
# bit mask prevents sign extending int4 when packing
w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
@@ -117,12 +142,14 @@ def cutlass_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_packed, w_s_packed, w_zp
def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
group_size: Optional[int]) -> Tensors:
def create_test_tensors(
shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> Tensors:
m, n, k = shape
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
print(
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
)
a = to_fp8(torch.randn((m, k), device="cuda"))
w = to_fp8(torch.randn((k, n), device="cuda"))
@@ -133,30 +160,34 @@ def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
False)
a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
)
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)
# for the practical use case we need per-tok scales for fp8 activations
w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
# weights are already per-group quantized, use placeholder here
w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)
w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)
return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
return Tensors(
w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)
def mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
def mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
):
# CUTLASS upstream uses fp8 with fastaccum as reference
# https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
output_ref = torch._scaled_mm(
@@ -165,7 +196,8 @@ def mm_test_helper(types: TypeConfig,
tensors.w_tok_s.unsqueeze(1),
tensors.w_ch_s.unsqueeze(0),
out_dtype=types.output_type,
use_fast_accum=True)
use_fast_accum=True,
)
output = ops.cutlass_w4a8_mm(
a=tensors.a,
@@ -179,17 +211,15 @@ def mm_test_helper(types: TypeConfig,
print(output)
print(output_ref)
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=1e-3,
atol=1e-3)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
@@ -201,7 +231,6 @@ def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
@@ -210,8 +239,9 @@ class W4A8Layer(torch.nn.Module):
return ops.cutlass_w4a8_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
def test_w4a8_cuda_graph():
m, n, k = 512, 4096, 4096
@@ -224,10 +254,11 @@ def test_w4a8_cuda_graph():
zero_points = False
w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)
a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
)
w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)
w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)
# Construct a trivial model with a single layer that calls the kernel
model = W4A8Layer(
@@ -244,7 +275,8 @@ def test_w4a8_cuda_graph():
w_tok_s.unsqueeze(1),
w_ch_s.unsqueeze(0),
out_dtype=torch.bfloat16,
use_fast_accum=True)
use_fast_accum=True,
)
# Run the model with a cuda graph
stream = torch.cuda.Stream()

View File

@@ -2,8 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
from nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
convert_swizzled_to_linear,
dequantize_nvfp4_to_dtype,
)
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@@ -41,18 +45,12 @@ def get_ref_results(
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -72,8 +70,7 @@ def test_flashinfer_nvfp4_gemm(
autotune: bool,
) -> None:
if backend == "trtllm" and dtype == torch.float16:
pytest.skip(
"Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
current_platform.seed_everything(seed)
m, n, packed_k = shape
@@ -82,10 +79,12 @@ def test_flashinfer_nvfp4_gemm(
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
@@ -113,14 +112,18 @@ def test_flashinfer_nvfp4_gemm(
if backend == "trtllm":
epilogue_tile_m = 128
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
epilogue_tile_m)
b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)
b_scale_interleaved = convert_swizzled_to_linear(
b_scale_interleaved, n, k, block_size)
b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
b_scale_interleaved.shape).view(torch.float8_e4m3fn))
b_scale_interleaved, n, k, block_size
)
b_scale_interleaved = (
flashinfer.shuffle_matrix_sf_a(
b_scale_interleaved.view(torch.uint8), epilogue_tile_m
)
.reshape(b_scale_interleaved.shape)
.view(torch.float8_e4m3fn)
)
with flashinfer.autotune(autotune):
out = flashinfer_scaled_fp4_mm(
@@ -133,7 +136,4 @@ def test_flashinfer_nvfp4_gemm(
backend=backend,
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

View File

@@ -9,8 +9,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm
if not current_platform.has_device_capability(100):
pytest.skip(
reason=
"Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
allow_module_level=True,
)
@@ -53,7 +52,7 @@ def test_flashinfer_fp8_gemm(
).to(dtype=dtype)
if use_bias:
bias = torch.randn((n, ), dtype=dtype, device=device)
bias = torch.randn((n,), dtype=dtype, device=device)
expected_out = expected_out + bias
else:
bias = None

View File

@@ -5,9 +5,11 @@ import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant)
from tests.kernels.quant_utils import (
FP8_DTYPE,
ref_dynamic_per_tensor_fp8_quant,
ref_dynamic_per_token_quant,
)
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform
@@ -18,23 +20,25 @@ SCALE_UBS = [True, False]
SEEDS = [0]
def opcheck_fp8_quant(output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False):
def opcheck_fp8_quant(
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
):
if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
elif use_per_token_if_dynamic:
scale = torch.empty((input.shape[0], 1),
device=input.device,
dtype=torch.float32)
opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub))
scale = torch.empty(
(input.shape[0], 1), device=input.device, dtype=torch.float32
)
opcheck(
torch.ops._C.dynamic_per_token_scaled_fp8_quant,
(output, input, scale, scale_ub),
)
else:
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
scale = torch.empty(
(input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32,
)
opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))
@@ -44,30 +48,29 @@ def opcheck_fp8_quant(output,
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, scale_ub: bool,
seed: int) -> None:
def test_dynamic_per_token_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") + 1e-6 # avoid nans
x = (
torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6
) # avoid nans
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
if scale_ub else None
scale_ub = (
torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
)
ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
ops_out, ops_scales = ops.scaled_fp8_quant(x,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
ops_out, ops_scales = ops.scaled_fp8_quant(
x, scale_ub=scale_ub, use_per_token_if_dynamic=True
)
torch.testing.assert_close(ref_scales, ops_scales)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
opcheck_fp8_quant(ops_out,
x,
None,
scale_ub,
use_per_token_if_dynamic=True)
opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,8 +78,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_per_tensor_fp8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -85,8 +89,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ops_out, ops_scale = ops.scaled_fp8_quant(x)
torch.testing.assert_close(ref_scale, ops_scale)
torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
torch.testing.assert_close(
ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
)
opcheck_fp8_quant(ops_out, x)

View File

@@ -6,8 +6,7 @@ import pytest
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
@@ -18,13 +17,14 @@ from vllm.platforms import current_platform
(64, 1024, 64), # Medium
(128, 2048, 128), # Large
(8, 513, 64), # Non-divisible (native only)
])
],
)
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int,
use_ue8m0: bool) -> None:
def test_quantfp8_group_functionality(
batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
) -> None:
"""Test QuantFP8 group quantization with various configurations.
Tests both CUDA and native implementations, column-major scales,
@@ -32,16 +32,17 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
"""
current_platform.seed_everything(seed)
x = torch.randn(
(batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
expected_num_groups = (hidden_dim + group_size - 1) // group_size
is_divisible = hidden_dim % group_size == 0
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
quant_op = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0,
)
# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -49,10 +50,12 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
assert scales_native.shape == (batch_size, expected_num_groups)
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
quant_op_col = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0,
)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
@@ -86,41 +89,48 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
# Test with 3D input
batch1, batch2, hidden_dim = 4, 8, 1024
x_3d = torch.randn(
(batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
x_3d = (
torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
* 8
)
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
quant_op = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0,
)
x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
assert scales.shape == (batch1, batch2, hidden_dim // group_size)
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
quant_op_col = QuantFP8(
static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0,
)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)
# Test with 4D input
batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
dtype=torch.bfloat16,
device="cuda") * 8
x_4d = (
torch.randn(
(batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
)
* 8
)
x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
assert x_quant_4d.shape == x_4d.shape
assert scales_4d.shape == (batch1, batch2, batch3,
hidden_dim // group_size)
assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)
_, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
batch3)
assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)
@pytest.mark.parametrize("seed", [42])
@@ -132,30 +142,24 @@ def test_quantfp8_group_edge_cases(seed: int) -> None:
group_size = 64
# Test with single group (group_size >= hidden_dim)
x_small = torch.randn(
(batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False)
quant_op = QuantFP8(
static=False, group_shape=group_shape, column_major_scales=False
)
x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
assert x_quant_small.shape == x_small.shape
assert scales_small.shape == (batch_size, 1)
# Test with zero inputs
x_zero = torch.zeros((batch_size, 256),
dtype=torch.bfloat16,
device="cuda")
x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
assert x_quant_zero.shape == x_zero.shape
assert (scales_zero > 0).all(), "Scales should be clamped to minimum"
# Test very large values
x_large = torch.full((batch_size, 256),
1000.0,
dtype=torch.bfloat16,
device="cuda")
x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
assert x_quant_large.shape == x_large.shape
# FP8 max is typically 448 or 224, so scales should be > 1

View File

@@ -13,33 +13,42 @@ from vllm import _custom_ops as ops # noqa: F401
def test_ggml_opcheck(quant_type):
block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
shape = [256, 1152]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize,
(qweight, quant_type, m, n, torch.float16))
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16))
x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8,
(qweight, x, quant_type, qweight.shape[0]))
opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
(qweight, x, quant_type, qweight.shape[0]))
x = torch.rand((m, 512), device="cuda", dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0]))
opcheck(
torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])
)
shape = [256, 1024, 336]
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
sorted_token_ids = torch.arange(776, device='cuda')
expert_ids = torch.randint(0, 256, (194, ), device='cuda')
num_tokens_post_padded = torch.tensor([1],
dtype=torch.int64,
device='cuda')
qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
x = torch.rand((1, 1024), device="cuda", dtype=torch.float16)
sorted_token_ids = torch.arange(776, device="cuda")
expert_ids = torch.randint(0, 256, (194,), device="cuda")
num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda")
opcheck(torch.ops._C.ggml_moe_a8,
(x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
quant_type, qweight.shape[0], 1, x.shape[0]))
opcheck(
torch.ops._C.ggml_moe_a8,
(
x,
qweight,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
quant_type,
qweight.shape[0],
1,
x.shape[0],
),
)
topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32)
topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32)
opcheck(
torch.ops._C.ggml_moe_a8_vec,
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]))
(x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]),
)

View File

@@ -18,8 +18,8 @@ GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
def get_gguf_sample_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
@@ -27,8 +27,8 @@ def get_gguf_sample_tensors(
def get_gguf_MoE_tensors(
hidden_size: int,
quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
sample_dir = GGUF_SAMPLE_MOE
filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
sample_file = Path(sample_dir) / filename
@@ -68,17 +68,20 @@ QUANT_TYPES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_dequantize(
hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
):
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
for tensor in tensors:
shape_str = tensor.name.split("_")[-1]
shape = map(int, shape_str.split("x"))
ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape), dtype)
ref_output = torch.tensor(
dequantize(tensor.data, quant_type), device="cuda"
).to(dtype)
output = ops.ggml_dequantize(
torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
@@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
qweight.shape[0]).to(dtype)
output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(
dtype
)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)
@@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
GGMLQuantizationType.Q4_0,
GGMLQuantizationType.Q5_0,
GGMLQuantizationType.Q8_0,
])
],
)
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType):
def test_mmq(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
):
current_platform.seed_everything(0)
tensors = get_gguf_sample_tensors(hidden_size, quant_type)
x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
for tensor in tensors:
weight = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype)
weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
dtype
)
ref_output = x @ weight.T
qweight = torch.tensor(tensor.data, device="cuda")
@@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
# bfloat16 tends to accumulate and can greatly inflate rtol
# since outputs are also very close to 0
rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
torch.testing.assert_close(output,
ref_output,
atol=atols[dtype],
rtol=rtols[dtype])
torch.testing.assert_close(
output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
quant_type: GGMLQuantizationType, top_k: int):
def test_moe(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
quant_type: GGMLQuantizationType,
top_k: int,
):
current_platform.seed_everything(0)
H, E = 1024, 256
x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")
topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
topk_ids = torch.randint(0,
E, (num_tokens, top_k),
device="cuda",
dtype=torch.int32)
topk_ids = torch.randint(
0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32
)
tensors = get_gguf_MoE_tensors(hidden_size, quant_type)
w13 = tensors[0]
w2 = tensors[1]
w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
device="cuda").to(dtype)
w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to(
dtype
)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
device="cuda").to(dtype)
w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype)
output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data,
device="cuda"), topk_weights,
topk_ids, quant_type, quant_type, "silu")
output = _fused_moe_gguf(
x,
torch.tensor(w13.data, device="cuda"),
torch.tensor(w2.data, device="cuda"),
topk_weights,
topk_ids,
quant_type,
quant_type,
"silu",
)
ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
topk_ids).reshape(output.shape)
ref_output = fused_experts(
x, w13_dequant, w2_dequant, topk_weights, topk_ids
).reshape(output.shape)
torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)

View File

@@ -8,25 +8,22 @@ from vllm import _custom_ops as ops # noqa: F401
def test_gptq_shuffle_opcheck():
weight = torch.randint(-2000000,
2000000, (1792, 4096),
device='cuda',
dtype=torch.int32)
perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
weight = torch.randint(
-2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32
)
perm = torch.empty((0,), device="cuda", dtype=torch.int32)
bit = 4
opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))
def test_gptq_gemm_opcheck():
a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
weight = torch.randint(-2000000,
2000000, (512, 6144),
device='cuda',
dtype=torch.int32)
zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
a = torch.rand((240, 4096), device="cuda", dtype=torch.float16)
weight = torch.randint(
-2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32
)
zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32)
scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16)
idx = torch.empty((0,), device="cuda", dtype=torch.int32)
use_exllama = True
bit = 4
opcheck(torch.ops._C.gptq_gemm,
(a, weight, zeros, scales, idx, use_exllama, bit))
opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit))

View File

@@ -15,7 +15,8 @@ from vllm import _custom_ops as ops
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
x = torch.eye(hidden_dim, dtype=dtype, device=device)
hadamard = deterministic_hadamard_matrix(
hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim)
hidden_dim, dtype=torch.float64, device="cuda"
) / math.sqrt(hidden_dim)
y = ops.hadacore_transform(x.clone())
y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)

View File

@@ -11,12 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_quant_int8)
per_token_quant_int8,
)
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
allow_module_level=True)
pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
@@ -26,14 +26,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
assert B.ndim == 2 and B.is_contiguous(
), "B must be a 2D contiguous tensor"
assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"
# Reshape input
M = A.numel() // A.shape[-1]
B = B.t() # Transpose weight matrix
N, K = B.shape
origin_C_shape = A.shape[:-1] + (K, )
origin_C_shape = A.shape[:-1] + (K,)
A = A.reshape(M, N)
# As is per-token [M, 1], Bs is per-column [1, K]
@@ -43,8 +42,7 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
return C.reshape(origin_C_shape).to(output_dtype)
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
topk_ids):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
"""This function performs fused moe with per-column int8 quantization
using native torch."""
@@ -66,25 +64,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
mask = topk_ids == i
if mask.sum():
# First MLP layer: note that a_s is now per-token
inter_out = native_w8a8_per_token_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
output_dtype=a.dtype)
inter_out = native_w8a8_per_token_matmul(
a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
)
# Activation function
act_out = SiluAndMul().forward_native(inter_out)
# Quantize activation output with per-token
act_out_q, act_out_s = per_token_quant_int8(act_out)
# Second MLP layer
out[mask] = native_w8a8_per_token_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
output_dtype=a.dtype)
out[mask] = native_w8a8_per_token_matmul(
act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
)
# Apply routing weights and sum
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
@@ -102,8 +97,10 @@ TOP_KS = [2, 6]
SEEDS = [0]
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@pytest.mark.parametrize(
"M, N, K, E, topk, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
torch.manual_seed(seed)
@@ -130,8 +127,9 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(score, topk)
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
topk_weights, topk_ids)
ref_out = torch_w8a8_per_column_moe(
a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
)
quant_config = FusedMoEQuantConfig.make(
torch.int8,
@@ -151,7 +149,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
rel_diff = torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
) / torch.mean(torch.abs(ref_out.to(torch.float32)))
assert rel_diff < 0.05

View File

@@ -18,26 +18,24 @@ SCALE = [0.1, 2.1]
def opcheck_int8_quant_static(output, input, scale, azp=None):
if azp is None:
opcheck(torch.ops._C.static_scaled_int8_quant,
(output, input, scale, None))
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None))
else:
opcheck(torch.ops._C.static_scaled_int8_quant,
(output, input, scale, azp))
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
scale = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
scale = torch.empty(
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
)
if symmetric:
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
(output, input, scale, None))
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None))
else:
azp = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.int32)
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
(output, input, scale, azp))
azp = torch.empty(
(input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.int32,
)
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_scaled_int8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
@@ -68,30 +67,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
def test_dynamic_scaled_int8_azp_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
# calculate scale and azp, and adjust the range
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
torch.int32)
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32)
torch_out = ((x / scales).round() + azps).clamp(
int8_traits.min, int8_traits.max).to(torch.int8)
assert torch_out.min() >= int8_traits.min and torch_out.max(
) <= int8_traits.max
torch_out = (
((x / scales).round() + azps)
.clamp(int8_traits.min, int8_traits.max)
.to(torch.int8)
)
assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
if (not torch.allclose(scales_out, scales)):
if not torch.allclose(scales_out, scales):
print(torch.argmax(torch.abs(scales_out - scales)))
torch.testing.assert_close(scales_out, scales)
# big atol to account for rounding errors
@@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
def test_static_scaled_int8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
out1 = (x / scale_arg).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
(x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
)
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
assert scale2 is scale_arg
@@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("scale", SCALE)
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
def test_static_scaled_int8_azp_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
scale: float,
azp: int,
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
((x / scale).round() + azp)
.clamp(int8_traits.min, int8_traits.max)
.to(torch.int8)
)
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
out2, scale2, azp2 = scaled_int8_quant(x,
scale_arg,
azp_arg,
symmetric=False)
out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False)
assert scale2 is scale_arg
assert azp2 is azp_arg
@@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
int32_traits = torch.iinfo(torch.int32)
val = float(int32_traits.max if is_max else int32_traits.min)
x_vals = [[
nextafter(val, inf), val + 1, val, val - 1,
nextafter(val, -inf)
]]
x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]]
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)

View File

@@ -15,15 +15,16 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
query_machete_supported_group_sizes)
query_machete_supported_group_sizes,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
@@ -72,29 +73,38 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
# GPTQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]
),
# AWQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]
),
# # QQQ style
# *(TypeConfig(act_type=torch.int8,
# weight_type=scalar_types.uint4b8,
@@ -133,17 +143,18 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))
def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0
def machete_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"
w_ref, w_q, w_s, w_zp = quantize_weights(
@@ -152,7 +163,8 @@ def machete_quantize_and_pack(atype: torch.dtype,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
ref_zero_points_after_scales=True,
)
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # convert to col major
@@ -163,15 +175,18 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp
def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1
print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
print(
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
)
a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
@@ -186,8 +201,13 @@ def create_test_tensors(shape: tuple[int, int, int],
w = w.to(torch.float16)
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
types.group_zero_type is not None)
a.dtype,
w,
types.weight_type,
types.group_scale_type,
group_size,
types.group_zero_type is not None,
)
if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
@@ -196,35 +216,47 @@ def create_test_tensors(shape: tuple[int, int, int],
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)
w_ch_s = None if types.channel_scale_type is None else\
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)
w_ch_s = (
None
if types.channel_scale_type is None
else rand_data((n,), types.channel_scale_type)
)
w_tok_s = (
None
if types.token_scale_type is None
else rand_data((m,), types.token_scale_type)
)
return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
return Tensors(
w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)
# None stype means scales use the same dtype as a
def machete_mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype
if tensors.w_ch_s is not None:
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0)
).to(output_ref_type)
if tensors.w_tok_s is not None:
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1)
).to(output_ref_type)
output = ops.machete_mm(
a=tensors.a,
@@ -245,23 +277,23 @@ def machete_mm_test_helper(types: TypeConfig,
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if tensors.w_g_zp is not None\
atol = (
1
if tensors.w_g_zp is not None
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
)
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=rtol,
atol=atol)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=rtol, atol=atol
)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
@@ -275,20 +307,20 @@ def test_machete_all_schedules(shape, types: TypeConfig):
tensors = create_test_tensors(shape, types, group_size)
print(f"MNK = {shape}")
for schedule in ops.machete_supported_schedules(
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type):
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type,
):
print(f"Testing schedule {schedule}")
machete_mm_test_helper(types, tensors, group_size, schedule)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
@@ -306,19 +338,22 @@ def test_machete_heuristic(shape, types: TypeConfig):
# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
group_size = 128
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)
@@ -331,29 +366,30 @@ def test_machete_devices(device: str):
# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_subset():
group_size = 128
type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
tensors = create_test_tensors((512, 4096, 4096),
type_config,
group_size,
subset_stride_factor=2)
tensors = create_test_tensors(
(512, 4096, 4096), type_config, group_size, subset_stride_factor=2
)
machete_mm_test_helper(type_config, tensors, group_size)
# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):
def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
@@ -362,8 +398,9 @@ class MacheteLayer(torch.nn.Module):
return ops.machete_mm(a=a, **self.kwargs)
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_cuda_graph():
m, n, k = 512, 4096, 4096
@@ -375,7 +412,8 @@ def test_machete_cuda_graph():
zero_points = False
w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, b, wtype, stype, group_size, zero_points)
a.dtype, b, wtype, stype, group_size, zero_points
)
# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""
import pytest
import torch
@@ -11,24 +12,44 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types)
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
MarlinWorkspace,
awq_marlin_quantize,
get_weight_perm,
marlin_quantize,
marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
awq_pack,
gptq_pack,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
from vllm.scalar_type import scalar_types
ACT_ORDER_OPTS = [False, True]
@@ -56,24 +77,27 @@ DTYPES = [torch.float16, torch.bfloat16]
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
torch.abs(output_ref)
)
def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
@@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)
# Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
@@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)
opcheck(torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_k = k_chunk * k_factor
@@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
b_weight = rand_data((size_k, size_n))
# Quantize
w_ref, q_w, s, zp = quantize_weights(b_weight,
quant_type,
group_size,
zero_points=True)
w_ref, q_w, s, zp = quantize_weights(
b_weight, quant_type, group_size, zero_points=True
)
# Pack to AWQ format
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)
# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)
opcheck(torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
)
# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
@@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors, act_order, is_k_full, use_atomic_add,
use_fp32_reduce, dtype):
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
@@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return
if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size
)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size
)
marlin_s2 = None
g_idx = None
@@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
@@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
b_weight, quant_type, group_size
)
g_idx = None
sort_indices = None
marlin_s2 = None
@@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)
marlin_zp = None
marlin_s2 = None
workspace = marlin_make_workspace_new(w_ref.device)
opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
use_atomic_add,
use_fp32_reduce,
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = ops.gptq_marlin_gemm(
a_input,
@@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
def marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
):
return ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
@@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
b_weight, quant_type, group_size
)
workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
workspace_24 = MarlinWorkspace(
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
)
output_ref = torch.matmul(a_input, w_24_ref)
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_24_gemm,
(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)
output = marlin_24_gemm_tester(
a_input,
@@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@@ -386,22 +478,22 @@ def test_hqq_marlin_gemm(
a_input = rand_data((size_m, size_k))
dev = a_input.device
b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))
gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)
sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
dev
)
marlin_s = marlin_permute_scales(
scale.transpose(1, 0), size_k, size_n, group_size
).to(dev)
marlin_zp = marlin_permute_scales(
zero.transpose(1, 0), size_k, size_n, group_size
).to(dev)
g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
@@ -433,8 +525,7 @@ def test_hqq_marlin_gemm(
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat
output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))
torch.cuda.synchronize()
@@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input():
big_m = size_m * 2
big_k = size_k * 2
a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
b_weight = rand_data((size_k, size_n))
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
@@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m):
size_k, size_n = 1024, 2048
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
b_bias = rand_data((size_n, )) * 10
b_bias = rand_data((size_n,)) * 10
marlin_bias = marlin_permute_bias(b_bias)
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)
marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)

View File

@@ -8,15 +8,27 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
(90, 128), (150, 128), (150, 48), (90, 80)]
PAD_SHAPES = [
(90, 64),
(150, 64),
(128, 48),
(128, 80),
(150, 80),
(90, 48),
(90, 128),
(150, 128),
(150, 48),
(90, 80),
]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -31,7 +43,22 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
BLOCK_SIZE = 16
@@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale):
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
@@ -131,7 +157,7 @@ def test_quantize_to_fp4(
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
dtype = torch.float16
current_platform.seed_everything(42)
torch.set_default_device('cuda:0')
torch.set_default_device("cuda:0")
m, n = pad_shape

View File

@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
@@ -19,26 +20,31 @@ PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]
def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device):
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert (m_k == n_k)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -60,25 +66,34 @@ def test_nvfp4_gemm(
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,
device)
out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, alpha, dtype)
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
)
out = ops.cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)
torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

View File

@@ -13,15 +13,15 @@ from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
scale_ue8m0: bool, group_size: int):
def test_per_token_group_quant_fp8(
shape, column_major: bool, scale_ue8m0: bool, group_size: int
):
device = "cuda"
torch.manual_seed(42)
num_tokens, hidden_dim = shape
x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
# cuda path
out_q, scale = fp8_utils.per_token_group_quant_fp8(
@@ -53,8 +53,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
torch.manual_seed(42)
num_tokens, hidden_dim = shape
x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8
# cuda path
out_q, scale = int8_utils.per_token_group_quant_int8(

View File

@@ -63,12 +63,11 @@ SEEDS = [0]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
#TODO: Zero-centering the inputs causes errors for LLMM1!
# TODO: Zero-centering the inputs causes errors for LLMM1!
# Without that the numbers quickly saturate, and may
# be giving false matches.
A = torch.rand(n, k, dtype=dtype, device="cuda")
@@ -83,14 +82,13 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
@@ -101,16 +99,15 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -121,16 +118,15 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5
ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -143,7 +139,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
@@ -153,13 +150,10 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count())
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())
assert torch.allclose(out, ref_out, rtol=0.01)
@@ -169,25 +163,24 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - .5) * xavier
B = (torch.rand(m, k, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count(), BIAS)
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
)
assert torch.allclose(out, ref_out, rtol=0.01)

View File

@@ -3,16 +3,20 @@
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)
FP4_DTYPE = torch.uint8
FP8_DTYPE = current_platform.fp8_dtype()
@@ -30,24 +34,24 @@ def test_silu_mul_nvfp4_quant(
shape: tuple[int, int],
) -> None:
current_platform.seed_everything(42)
device = 'cuda:0'
device = "cuda:0"
torch.set_default_device(device)
x = torch.randn(shape, dtype=dtype)
# ref op
ref_output = SiluAndMul().forward_native(x)
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(ref_output).max().to(torch.float32))
ref_output_quant, ref_block_scale = scaled_fp4_quant(
ref_output, ref_global_scale)
ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
ref_output
).max().to(torch.float32)
ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale)
# fused op
fused_output_quant = torch.empty_like(ref_output_quant)
fused_block_scale = torch.empty_like(ref_block_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
fused_block_scale, x,
ref_global_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(
fused_output_quant, fused_block_scale, x, ref_global_scale
)
# check dtype
assert ref_output_quant.dtype == FP4_DTYPE
@@ -59,17 +63,14 @@ def test_silu_mul_nvfp4_quant(
assert ref_block_scale.shape == fused_block_scale.shape
# check dequantized output
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
ref_block_scale,
ref_global_scale, dtype,
device)
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
fused_block_scale,
ref_global_scale, dtype,
device)
ref_output_dequant = dequantize_nvfp4_to_dtype(
ref_output_quant, ref_block_scale, ref_global_scale, dtype, device
)
fused_output_dequant = dequantize_nvfp4_to_dtype(
fused_output_quant, fused_block_scale, ref_global_scale, dtype, device
)
atol, rtol = 3e-1, 3e-1
torch.testing.assert_close(ref_output_dequant,
fused_output_dequant,
atol=atol,
rtol=rtol)
torch.testing.assert_close(
ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol
)

View File

@@ -4,6 +4,7 @@
Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
"""
import importlib
from typing import Optional
@@ -15,17 +16,19 @@ from vllm.platforms import current_platform
device = "cuda"
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
"vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm"
)
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
def torch_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def torch_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
out = scale_b.T * out
@@ -44,20 +47,22 @@ def get_8bit_types():
# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(
vllm_runner, example_prompts, model_path, max_tokens, num_logprobs
):
dtype = "bfloat16"
with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs)
MNK_FACTORS = [
@@ -76,10 +81,10 @@ MNK_FACTORS = [
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
use_scalar_scale_b, use_bias):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
).is_floating_point()
def test_scaled_mm(
M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()
current_platform.seed_everything(0)
@@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
#
# So, the values here are kept small enough to avoid this situation.
if is_floating_point_type(in_dtype):
a = (0.25 * torch.rand(
(M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand(
(K, N), dtype=torch.float32, device=device)).to(in_dtype)
a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
else:
a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
@@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
bias = None
if use_bias:
bias = torch.rand((N, ), device=device, dtype=out_dtype)
bias = torch.rand((N,), device=device, dtype=out_dtype)
c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)