Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
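The diff below is mechanical: yapf's paren-aligned continuation lines and isort's wrapped imports are rewritten into ruff's Black-compatible style (4-space hanging indents, double quotes, trailing commas, one name per line in parenthesized imports). As a rough illustration of the two styles — not part of the diff; the helper `scaled_matmul` and its arguments are hypothetical placeholders, not code from this commit:

import torch


def scaled_matmul(a, b, scale_a, scale_b, out_dtype=torch.float32):
    # Placeholder standing in for the quantized-matmul helpers touched below.
    return (a @ b * scale_a * scale_b).to(out_dtype)


a = torch.randn(4, 8)
b = torch.randn(8, 4)
scale_a = torch.tensor(0.5)
scale_b = torch.tensor(2.0)

# yapf (before): continuation arguments aligned to the opening parenthesis.
# out = scaled_matmul(a, b, scale_a, scale_b,
#                     out_dtype=torch.bfloat16)

# ruff format (after): Black-style hanging indent with a trailing comma.
out = scaled_matmul(
    a,
    b,
    scale_a,
    scale_b,
    out_dtype=torch.bfloat16,
)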
@@ -8,8 +8,9 @@ from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
                            dtype=torch.float32)
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
@@ -22,12 +23,9 @@ def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    return out[0:m, 0:k]


def dequantize_nvfp4_to_dtype(tensor_fp4,
                              tensor_sf,
                              global_scale,
                              dtype,
                              device,
                              block_size=16):
def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
@@ -69,7 +67,8 @@ def break_fp4_bytes(a, dtype):


def quant_nvfp4_tensor(a: torch.Tensor):
    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.abs(a).max().to(torch.float32))
    a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
        torch.float32
    )
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale

@@ -6,24 +6,25 @@ import torch
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_K_ALIGN, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_AMPERE_N_ALIGN)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    quantize_weights)
    ALLSPARK_AMPERE_K_ALIGN,
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    ALLSPARK_AMPERE_N_ALIGN,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import quantize_weights
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types


def is_gptq_allspark_supported(min_capability: int,
                               max_capability: int) -> bool:
def is_gptq_allspark_supported(min_capability: int, max_capability: int) -> bool:
    if not current_platform.is_cuda():
        return False

    capability = current_platform.get_device_capability()
    assert capability is not None

    return capability.to_int() >= min_capability \
        and capability.to_int() <= max_capability
    return (
        capability.to_int() >= min_capability and capability.to_int() <= max_capability
    )


MNK_FACTORS = [
@@ -43,7 +44,8 @@ HAS_ZP_OPTS = [False, True]

def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))
        torch.abs(output_ref)
    )


def rand_data(shape, dtype=torch.float16):
@@ -52,7 +54,8 @@ def rand_data(shape, dtype=torch.float16):

@pytest.mark.skipif(
    not is_gptq_allspark_supported(80, 89),
    reason="AllSpark Ampere kernel is not supported on this GPU type.")
    reason="AllSpark Ampere kernel is not supported on this GPU type.",
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("has_zp", HAS_ZP_OPTS)
@@ -67,8 +70,9 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):
    weight = rand_data((k, n), dtype=dtype)

    # Quantize (and apply act_order if provided)
    w_ref, qw, s, zp = quantize_weights(weight, scalar_types.uint8b128,
                                        group_size, has_zp)
    w_ref, qw, s, zp = quantize_weights(
        weight, scalar_types.uint8b128, group_size, has_zp
    )

    qw = qw.to(torch.uint8)
    if has_zp:
@@ -79,20 +83,42 @@ def test_gptq_allspark_gemm_ampere(mnk_factors, group_size, has_zp, dtype):

    n_32align = (n + 32 - 1) // 32 * 32

    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
        qw, s, zp, has_zp)
    opcheck(torch.ops._C.rearrange_kn_weight_as_n32k16_order,
            (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n,
             n_32align))
    qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(qw, s, zp, has_zp)
    opcheck(
        torch.ops._C.rearrange_kn_weight_as_n32k16_order,
        (qw, s, zp, has_zp, qw_reorder, s_reorder, zp_reorder, k, n, n_32align),
    )

    opcheck(torch.ops._C.allspark_w8a16_gemm,
            (input, qw_reorder, s_reorder, zp_reorder, n, group_size, sm_count,
             sm_version, ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, has_zp, True),
            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
    output = ops.allspark_w8a16_gemm(input, qw_reorder, s_reorder, zp_reorder,
                                     n, group_size, sm_count, sm_version,
                                     ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
                                     has_zp, True)
    opcheck(
        torch.ops._C.allspark_w8a16_gemm,
        (
            input,
            qw_reorder,
            s_reorder,
            zp_reorder,
            n,
            group_size,
            sm_count,
            sm_version,
            ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp,
            True,
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )
    output = ops.allspark_w8a16_gemm(
        input,
        qw_reorder,
        s_reorder,
        zp_reorder,
        n,
        group_size,
        sm_count,
        sm_version,
        ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
        has_zp,
        True,
    )

    output_ref = torch.matmul(input, w_ref)
    torch.cuda.synchronize()

@@ -8,40 +8,42 @@ from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401


@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_dequantize"),
                    reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
    not hasattr(torch.ops._C, "awq_dequantize"),
    reason="AWQ is not supported on this GPU type.",
)
def test_awq_dequantize_opcheck(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_TRITON_AWQ", "0")
        qweight = torch.randint(-2000000000,
                                2000000000, (8192, 256),
                                device='cuda',
                                dtype=torch.int32)
        scales = torch.rand((64, 2048), device='cuda', dtype=torch.float16)
        zeros = torch.empty((64, 256), device='cuda', dtype=torch.int32)
        qweight = torch.randint(
            -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
        )
        scales = torch.rand((64, 2048), device="cuda", dtype=torch.float16)
        zeros = torch.empty((64, 256), device="cuda", dtype=torch.int32)
        split_k_iters = 0
        thx = 0
        thy = 0
        opcheck(torch.ops._C.awq_dequantize,
                (qweight, scales, zeros, split_k_iters, thx, thy))
        opcheck(
            torch.ops._C.awq_dequantize,
            (qweight, scales, zeros, split_k_iters, thx, thy),
        )


@pytest.mark.skip(reason="Not working; needs investigation.")
@pytest.mark.skipif(not hasattr(torch.ops._C, "awq_gemm"),
                    reason="AWQ is not supported on this GPU type.")
@pytest.mark.skipif(
    not hasattr(torch.ops._C, "awq_gemm"),
    reason="AWQ is not supported on this GPU type.",
)
def test_awq_gemm_opcheck(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_TRITON_AWQ", "0")
        input = torch.rand((2, 8192), device='cuda', dtype=torch.float16)
        qweight = torch.randint(-2000000000,
                                2000000000, (8192, 256),
                                device='cuda',
                                dtype=torch.int32)
        scales = torch.randint(-2000000000,
                               2000000000, (64, 256),
                               device='cuda',
                               dtype=torch.int32)
        qzeros = torch.empty((64, 2048), device='cuda', dtype=torch.float16)
        input = torch.rand((2, 8192), device="cuda", dtype=torch.float16)
        qweight = torch.randint(
            -2000000000, 2000000000, (8192, 256), device="cuda", dtype=torch.int32
        )
        scales = torch.randint(
            -2000000000, 2000000000, (64, 256), device="cuda", dtype=torch.int32
        )
        qzeros = torch.empty((64, 2048), device="cuda", dtype=torch.float16)
        split_k_iters = 8
        opcheck(torch.ops._C.awq_gemm,
                (input, qweight, qzeros, scales, split_k_iters))
        opcheck(torch.ops._C.awq_gemm, (input, qweight, qzeros, scales, split_k_iters))

@@ -4,11 +4,15 @@

Run `pytest tests/kernels/quantization/test_awq_triton.py`.
"""

import pytest
import torch

from vllm.model_executor.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton)
    AWQ_TRITON_SUPPORTED_GROUP_SIZES,
    awq_dequantize_triton,
    awq_gemm_triton,
)
from vllm.platforms import current_platform

device = "cuda"
@@ -33,23 +37,24 @@ def reverse_awq_order(t: torch.Tensor):
# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor,
                         group_size: int) -> torch.Tensor:

def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None],
                                         shifts[None, None, :]).to(torch.int8)
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )

    iweights = iweights.view(iweights.shape[0], -1)

    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
                                      shifts[None, None, :]).to(torch.int8)
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

@@ -70,7 +75,6 @@ def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128])
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
def test_dequantize(qweight_rows, qweight_cols, group_size):

    if group_size == -1:
        group_size = qweight_rows

@@ -84,25 +88,27 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):

    current_platform.seed_everything(0)

    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            dtype=qweight_dtype,
                            device=device)
    scales = torch.rand(scales_rows,
                        scales_cols,
                        dtype=scales_dtype,
                        device=device)
    zeros = torch.randint(0,
                          torch.iinfo(torch.int32).max,
                          (zeros_rows, zeros_cols),
                          dtype=zeros_dtype,
                          device=device)
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_rows, qweight_cols),
        dtype=qweight_dtype,
        device=device,
    )
    scales = torch.rand(scales_rows, scales_cols, dtype=scales_dtype, device=device)
    zeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (zeros_rows, zeros_cols),
        dtype=zeros_dtype,
        device=device,
    )

    iweights_triton = awq_dequantize_triton(qweight, scales, zeros)

    assert (not torch.any(torch.isinf(iweights_triton))
            and not torch.any(torch.isnan(iweights_triton)))
    assert not torch.any(torch.isinf(iweights_triton)) and not torch.any(
        torch.isnan(iweights_triton)
    )

    iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size)

@@ -119,7 +125,6 @@ def test_dequantize(qweight_rows, qweight_cols, group_size):
@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("splitK", [1, 8])
def test_gemm(N, K, M, splitK, group_size):

    if group_size == -1:
        group_size = K

@@ -138,35 +143,29 @@ def test_gemm(N, K, M, splitK, group_size):

    current_platform.seed_everything(0)

    input = torch.rand((input_rows, input_cols),
                       dtype=input_dtype,
                       device=device)
    qweight = torch.randint(0,
                            torch.iinfo(torch.int32).max,
                            (qweight_rows, qweight_cols),
                            device=device)
    qzeros = torch.randint(0,
                           torch.iinfo(torch.int32).max,
                           (qzeros_rows, qzeros_cols),
                           device=device)
    scales = torch.rand((scales_rows, scales_cols),
                        dtype=scales_dtype,
                        device=device)
    input = torch.rand((input_rows, input_cols), dtype=input_dtype, device=device)
    qweight = torch.randint(
        0, torch.iinfo(torch.int32).max, (qweight_rows, qweight_cols), device=device
    )
    qzeros = torch.randint(
        0, torch.iinfo(torch.int32).max, (qzeros_rows, qzeros_cols), device=device
    )
    scales = torch.rand((scales_rows, scales_cols), dtype=scales_dtype, device=device)

    output_triton = awq_gemm_triton(input, qweight, scales, qzeros,
                                    split_k_iters)
    output_triton = awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)

    assert (not torch.any(torch.isinf(output_triton))
            and not torch.any(torch.isnan(output_triton)))
    assert not torch.any(torch.isinf(output_triton)) and not torch.any(
        torch.isnan(output_triton)
    )

    dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros)

    output_torch = torch.matmul(input, dequantized_weights)

    assert (not torch.any(torch.isinf(output_torch))
            and not torch.any(torch.isnan(output_torch)))
    assert not torch.any(torch.isinf(output_torch)) and not torch.any(
        torch.isnan(output_torch)
    )

    torch.testing.assert_close(output_triton.cpu(),
                               output_torch.cpu(),
                               atol=1e-1,
                               rtol=1e-1)
    torch.testing.assert_close(
        output_triton.cpu(), output_torch.cpu(), atol=1e-1, rtol=1e-1
    )

@@ -7,20 +7,26 @@ import itertools
import pytest
import torch

from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
                                       native_w8a8_block_matmul)
from tests.kernels.quant_utils import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_matmul,
)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
    cutlass_scaled_mm,
    per_token_group_quant_fp8,
    w8a8_triton_block_scaled_mm,
)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (fp8_gemm_nt,
                                  get_col_major_tma_aligned_tensor,
                                  per_block_cast_to_fp8)
from vllm.utils.deep_gemm import (
    fp8_gemm_nt,
    get_col_major_tma_aligned_tensor,
    per_block_cast_to_fp8,
)

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
                allow_module_level=True)
    pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True)

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -51,7 +57,8 @@ def setup_cuda():

@pytest.mark.parametrize(
    "num_tokens,d,dtype,group_size,seed",
    itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS))
    itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE, SEEDS),
)
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
    torch.manual_seed(seed)
@@ -60,15 +67,14 @@ def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
    ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
    out, scale = per_token_group_quant_fp8(x, group_size)

    assert torch.allclose(out.to(torch.float32),
                          ref_out.to(torch.float32),
                          rtol=0.15)
    assert torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
    assert torch.allclose(scale, ref_scale)


@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
    torch.manual_seed(seed)
@@ -89,14 +95,12 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                       out_dtype)
    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
                                      out_dtype)
    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001


@@ -127,32 +131,32 @@ def test_w8a8_block_fp8_cutlass_matmul():

    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
    # Hopper requires row-major format for scales
    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(
        90) else Bs
    Bs_cutlass = Bs.T.contiguous() if current_platform.is_device_capability(90) else Bs

    A_fp8, As = per_token_group_quant_fp8(A_fp32,
                                          block_size[1],
                                          column_major_scales=False)
    A_fp8, As = per_token_group_quant_fp8(
        A_fp32, block_size[1], column_major_scales=False
    )
    # CUTLASS uses column-major format for scales
    A_fp8_cutlass, As_cutlass = per_token_group_quant_fp8(
        A_fp32, block_size[1], column_major_scales=True)
        A_fp32, block_size[1], column_major_scales=True
    )

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                       out_dtype)
    out = cutlass_scaled_mm(A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass,
                            block_size, out_dtype)
    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    out = cutlass_scaled_mm(
        A_fp8_cutlass, B_fp8, As_cutlass, Bs_cutlass, block_size, out_dtype
    )

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001


@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@pytest.mark.skipif(not has_deep_gemm(),
                    reason="DeepGemm kernels not available.")
    itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS),
)
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    # only aligned sizes
@@ -172,20 +176,20 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
    As = As_fp8.to(torch.float32)
    Bs = Bs_fp8.to(torch.float32)

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                       out_dtype)
    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    # Transpose earlier so that the testing will not trigger transposing kernels
    As_fp8 = get_col_major_tma_aligned_tensor(As_fp8)

    out = torch.zeros((M, N), device='cuda', dtype=out_dtype)
    out = torch.zeros((M, N), device="cuda", dtype=out_dtype)

    assert As_fp8.shape == (M, (K + 127) //
                            128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
    assert As_fp8.shape == (M, (K + 127) // 128), (
        f"{As_fp8.shape} != {(M, (K + 127) // 128)}"
    )

    fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001

@@ -10,12 +10,12 @@ import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    w8a8_block_int8_matmul)
    w8a8_block_int8_matmul,
)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -36,8 +36,10 @@ def setup_cuda():
    torch.set_default_device("cuda")


@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
                         itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
@pytest.mark.parametrize(
    "M,N,K,block_size,out_dtype,seed",
    itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
    torch.manual_seed(seed)
@@ -58,11 +60,10 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                       out_dtype)
    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
    out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.001

@@ -11,12 +11,11 @@ import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
    sparse_cutlass_supported,
)
from vllm.platforms import current_platform

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
@@ -40,9 +39,7 @@ def prune_to_2_4(tensor):

    # Create binary mask
    mask = torch.zeros_like(reshaped)
    mask.scatter_(dim=1,
                  index=indices,
                  src=torch.ones_like(indices, dtype=mask.dtype))
    mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))

    # Apply mask and reshape back
    pruned = reshaped * mask
@@ -55,32 +52,31 @@ def prune_to_2_4(tensor):

# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
                                         b_compressed: torch.Tensor,
                                         b_metadata: torch.Tensor):

def check_compress_decompress_invariance(
    dtype: torch.dtype,
    b: torch.Tensor,
    b_compressed: torch.Tensor,
    b_metadata: torch.Tensor,
):
    # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
    # same dtype as its inputs. This line addresses that constraint while
    # arbitrarily using bfloat16 for the int8/fp8 cases.
    out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16

    eye = torch.eye(b.shape[0], device='cuda', dtype=dtype)
    eye_scale = torch.ones(1, device='cuda', dtype=torch.float32)
    b_decomp = ops.cutlass_scaled_sparse_mm(eye,
                                            b_compressed,
                                            b_metadata,
                                            eye_scale,
                                            eye_scale,
                                            out_dtype=out_dtype)
    eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
    eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
    b_decomp = ops.cutlass_scaled_sparse_mm(
        eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
    )

    torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)


def make_rand_sparse_tensors(
        dtype: torch.dtype, m: int, n: int, k: int
    dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    a = torch.randn((m, k), device='cuda')
    b = torch.randn((n, k), device='cuda').t()
    a = torch.randn((m, k), device="cuda")
    b = torch.randn((n, k), device="cuda").t()

    if dtype == torch.int8:
        # ensure A and B aren't all zeros after rounding
@@ -107,32 +103,25 @@ def make_rand_sparse_tensors(
    return b_compressed, e, a, b


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():

    big_m = 1024
    m, n, k = 512, 512, 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
                                                     big_m, n, k)
    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
    a = whole_a[0:m, 0:k]
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)
    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
    )
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

@@ -161,105 +150,87 @@ MNK_FACTORS = [


# Test working with a subset of A and B for sparse matmul
@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: type[torch.dtype],
                             use_bias: bool):

def test_cutlass_sparse_gemm(
    m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None
    bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=dtype,
                                       bias=bias)
    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
    )

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=dtype,
                                  bias=bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    out_dtype = torch.bfloat16

    bias = torch.rand(
        (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
    bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)
    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):

def test_cutlass_sparse_int8_gemm(
    m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    out_dtype = torch.bfloat16

    bias = torch.rand(
        (n, ), device="cuda", dtype=out_dtype) * 10 if use_bias else None
    bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)
    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)

@@ -4,6 +4,7 @@

Run `pytest tests/kernels/quantization/test_cutlass_scaled_mm.py`.
"""

import random

import pytest
@@ -36,9 +37,7 @@ MNK_FACTORS = [
    (512, 24576, 128),
]

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# -1 means full extent in that dimension
TENSORWISE_GROUP_SHAPE = (-1, -1)
@@ -60,18 +59,19 @@ def group_scale_helper(shape, group_shape):
def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    group_shape = group_scale_helper(shape, group_shape)
    return tuple(
        cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            a_scale_group_shape: tuple,
                            b_scale_group_shape: tuple,
                            use_bias: bool,
                            out_dtype: type[torch.dtype] = torch.bfloat16,
                            device: str = "cuda"):
def cutlass_fp8_gemm_helper(
    m: int,
    n: int,
    k: int,
    a_scale_group_shape: tuple,
    b_scale_group_shape: tuple,
    use_bias: bool,
    out_dtype: type[torch.dtype] = torch.bfloat16,
    device: str = "cuda",
):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_fp8(torch.randn((m, k), device=device))
@@ -80,8 +80,8 @@ def cutlass_fp8_gemm_helper(m: int,
    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
    scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
    scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
    scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

    # make scales M-major for blockwise quant, doesn't affect 1D scales
    scale_a = scale_a.t().contiguous().t()
@@ -89,7 +89,7 @@ def cutlass_fp8_gemm_helper(m: int,
    scale_b = scale_b.t().contiguous().t()

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
    else:
        bias = None

@@ -98,18 +98,19 @@ def cutlass_fp8_gemm_helper(m: int,

    torch.testing.assert_close(out, baseline, rtol=5e-1, atol=1.5e-1)

    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))
    opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             a_scale_group_shape: tuple,
                             b_scale_group_shape: tuple,
                             use_bias: bool,
                             out_dtype: type[torch.dtype] = torch.bfloat16,
                             device: str = "cuda"):
def cutlass_int8_gemm_helper(
    m: int,
    n: int,
    k: int,
    a_scale_group_shape: tuple,
    b_scale_group_shape: tuple,
    use_bias: bool,
    out_dtype: type[torch.dtype] = torch.bfloat16,
    device: str = "cuda",
):
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_int8(torch.randn((m, k), device=device) * 5)
@@ -118,11 +119,11 @@ def cutlass_int8_gemm_helper(m: int,
    a_scales_shape = scale_shape(a.shape, a_scale_group_shape)
    b_scales_shape = scale_shape(b.shape, b_scale_group_shape)

    scale_a = (torch.randn(a_scales_shape, device=device, dtype=torch.float32))
    scale_b = (torch.randn(b_scales_shape, device=device, dtype=torch.float32))
    scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
    scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)

    if use_bias:
        bias = torch.rand((n, ), device=device, dtype=out_dtype) * 10
        bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
    else:
        bias = None

@@ -131,145 +132,192 @@ def cutlass_int8_gemm_helper(m: int,

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)

    opcheck(torch.ops._C.cutlass_scaled_mm,
            (out, a, b, scale_a, scale_b, bias))
    opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias))


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, a_scale_group_shape,
                          b_scale_group_shape, use_bias: bool):
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                            use_bias)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
                         [((1, 128), (128, 128))])
@pytest.mark.parametrize(
    "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
                                          a_scale_group_shape,
                                          b_scale_group_shape, use_bias: bool):
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    if k % b_scale_group_shape[0] != 0 or n % b_scale_group_shape[1] != 0:
        return
    if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
        return
    if m % 4 != 0 and current_platform.has_device_capability(100):
        return
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                            use_bias)
    cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)


@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
                           b_scale_group_shape, use_bias: bool):
    cutlass_int8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape,
                             use_bias)
def test_cutlass_int8_gemm(
    m: int, n: int, k: int, a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    cutlass_int8_gemm_helper(
        m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias
    )


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
                                        b_scale_group_shape,
                                        out_dtype: type[torch.dtype],
                                        use_bias: bool):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             a_scale_group_shape,
                             b_scale_group_shape,
                             use_bias,
                             out_dtype=out_dtype)
def test_cutlass_int8_gemm_output_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    cutlass_int8_gemm_helper(
        512,
        512,
        512,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
                                       b_scale_group_shape,
                                       out_dtype: type[torch.dtype],
                                       use_bias: bool):
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            a_scale_group_shape,
                            b_scale_group_shape,
                            use_bias,
                            out_dtype=out_dtype)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_output_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    cutlass_fp8_gemm_helper(
        512,
        512,
        512,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize("a_scale_group_shape,b_scale_group_shape",
                         [((1, 128), (128, 128))])
@pytest.mark.parametrize(
    "a_scale_group_shape,b_scale_group_shape", [((1, 128), (128, 128))]
)
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [False])
@pytest.mark.skipif(not current_platform.has_device_capability(90),
                    reason="FP8 blockwise is not supported on this GPU type.")
def test_cutlass_fp8_blockwise_scale_gemm_dtype(a_scale_group_shape,
                                                b_scale_group_shape,
                                                out_dtype: type[torch.dtype],
                                                use_bias: bool):
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            a_scale_group_shape,
                            b_scale_group_shape,
                            use_bias,
                            out_dtype=out_dtype)
@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="FP8 blockwise is not supported on this GPU type.",
)
def test_cutlass_fp8_blockwise_scale_gemm_dtype(
    a_scale_group_shape,
    b_scale_group_shape,
    out_dtype: type[torch.dtype],
    use_bias: bool,
):
    cutlass_fp8_gemm_helper(
        512,
        512,
        512,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        out_dtype=out_dtype,
    )


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
                                  use_bias: bool, device: str):
    cutlass_fp8_gemm_helper(512, 512, 512, a_scale_group_shape,
                            b_scale_group_shape, use_bias, torch.bfloat16,
                            device)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_devices(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
    cutlass_fp8_gemm_helper(
        512,
        512,
        512,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        torch.bfloat16,
        device,
    )


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
                                   use_bias: bool, device: str):
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             a_scale_group_shape,
                             b_scale_group_shape,
                             use_bias,
                             out_dtype=torch.bfloat16,
                             device=device)
def test_cutlass_int8_gemm_devices(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool, device: str
):
    cutlass_int8_gemm_helper(
        512,
        512,
        512,
        a_scale_group_shape,
        b_scale_group_shape,
        use_bias,
        out_dtype=torch.bfloat16,
        device=device,
    )


# For the following two tests:
@@ -277,32 +325,42 @@ def test_cutlass_int8_gemm_devices(a_scale_group_shape, b_scale_group_shape,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
                                  use_bias: bool):
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
def test_cutlass_fp8_gemm_m_sweep(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, a_scale_group_shape,
                                    b_scale_group_shape, use_bias)
            cutlass_fp8_gemm_helper(
                m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
            )


@pytest.mark.parametrize("a_scale_group_shape",
                         [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize("b_scale_group_shape",
                         [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE])
@pytest.mark.parametrize(
    "a_scale_group_shape", [PER_TOKEN_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize(
    "b_scale_group_shape", [PER_OUT_CH_GROUP_SHAPE, TENSORWISE_GROUP_SHAPE]
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
                                   use_bias: bool):
def test_cutlass_int8_gemm_m_sweep(
    a_scale_group_shape, b_scale_group_shape, use_bias: bool
):
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, a_scale_group_shape,
                                     b_scale_group_shape, use_bias)
            cutlass_int8_gemm_helper(
                m, nk, nk, a_scale_group_shape, b_scale_group_shape, use_bias
            )


@pytest.mark.parametrize("m", [32, 64, 128])
@@ -310,8 +368,7 @@ def test_cutlass_int8_gemm_m_sweep(a_scale_group_shape, b_scale_group_shape,
@pytest.mark.parametrize("k", [64, 128, 256])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.skip
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
                                    out_dtype: torch.dtype):
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, out_dtype: torch.dtype):
    # Currently, the test is failing because folding azp into
    # 16-bit bias loses too much precision
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
@@ -328,7 +385,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,

    b_dq = scale_b * bq_f32

    azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_a = torch.rand((1,), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

@@ -340,18 +397,17 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
    J = torch.ones((1, k), device="cuda", dtype=torch.float32)
    azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
    assert azp_bias.shape == (1, n)
    assert azp_bias[0, :].shape == (n, )
    assert azp_bias[0, :].shape == (n,)

    baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
        (aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
            dtype=out_dtype, device='cuda')
    baseline_q = (
        scale_a.to(device="cpu")
        * scale_b.to(device="cpu")
        * ((aq_i32 + azp_aq_i8).to(device="cpu") @ bq_i32.to(device="cpu"))
    ).to(dtype=out_dtype, device="cuda")

    out = ops.cutlass_scaled_mm(aq_i8,
                                bq_i8,
                                scale_a,
                                scale_b,
                                out_dtype=out_dtype,
                                bias=azp_bias[0, :])
    out = ops.cutlass_scaled_mm(
        aq_i8, bq_i8, scale_a, scale_b, out_dtype=out_dtype, bias=azp_bias[0, :]
    )
    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)

@@ -362,8 +418,9 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("azp_per_token", [True, False])
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
                          use_bias: bool, azp_per_token: bool):
def test_cutlass_int8_azp(
    m: int, n: int, k: int, out_dtype: torch.dtype, use_bias: bool, azp_per_token: bool
):
    m_azp = m if azp_per_token else 1
    scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
@@ -377,16 +434,12 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
    bq_f32 = bq_i8.to(dtype=torch.float32)
    b_dq = scale_b * bq_f32

    azp_a = torch.rand(
        (m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_a = torch.rand((m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
    azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
    azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

    a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
    torch.testing.assert_close(a_dq,
                               scale_a * aq_f32 - azp_a,
                               rtol=1e-4,
                               atol=1e-3)
    torch.testing.assert_close(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)

    if use_bias:
        bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -396,8 +449,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
    baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)

    # int32 mm not supported on CUDA
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
    cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
    a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device="cpu")
    cq = (a_noazp_i32_cpu @ bq_i32.to(device="cpu")).to(device="cuda")
    baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)

    # Hadamard is just the sum of the cols
@@ -406,14 +459,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
    func_bias = bias if use_bias else None

    if azp_per_token:
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_adj_i32, azp_i32,
                                        func_bias)
        out = ops.cutlass_scaled_mm_azp(
            aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_adj_i32, azp_i32, func_bias
        )
    else:
        azp_with_adj_i32 = azp_i32 * azp_adj_i32
        out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
                                        out_dtype, azp_with_adj_i32, None,
                                        func_bias)
        out = ops.cutlass_scaled_mm_azp(
            aq_i8, bq_i8, scale_a, scale_b, out_dtype, azp_with_adj_i32, None, func_bias
        )

    # bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
    # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
@@ -423,13 +476,15 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)

    if azp_per_token:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
                 func_bias))
        opcheck(
            torch.ops._C.cutlass_scaled_mm_azp,
            (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32, func_bias),
        )
    else:
        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
                 func_bias))
        opcheck(
            torch.ops._C.cutlass_scaled_mm_azp,
            (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None, func_bias),
        )


# Test working with a subset of A and B
@@ -445,23 +500,14 @@ def test_cutlass_subset():
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)
    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):

    def __init__(self, b, scale_a, scale_b, out_dtype):
        super().__init__()
        self.b = b
@@ -470,8 +516,9 @@ class CutlassLayer(torch.nn.Module):
        self.out_dtype = out_dtype

    def forward(self, a):
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)
        return ops.cutlass_scaled_mm(
            a, self.b, self.scale_a, self.scale_b, self.out_dtype
        )


@pytest.mark.parametrize("per_act_token", [True, False])
@@ -485,10 +532,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    m_a_scales = m if per_act_token else 1
||||
n_b_scales = n if per_out_ch else 1
|
||||
|
||||
scale_a = (torch.randn(
|
||||
(m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
|
||||
scale_b = (torch.randn(
|
||||
(1, n_b_scales), device="cuda", dtype=torch.float32) / 10)
|
||||
scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32) / 10
|
||||
scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32) / 10
|
||||
|
||||
# Construct a trivial model with a single layer that calls a CUTLASS kernel
|
||||
model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)
|
||||
@@ -502,13 +547,14 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
|
||||
out.zero_()
|
||||
g.replay()
|
||||
|
||||
baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
|
||||
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
|
||||
baseline = torch.mm(
|
||||
scale_a * a.to(dtype=torch.float32), scale_b * b.to(dtype=torch.float32)
|
||||
).to(torch.bfloat16)
|
||||
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
|
||||
|
||||
|
||||
def test_cutlass_support_opcheck():
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability, ))
|
||||
opcheck(torch.ops._C.cutlass_scaled_mm_supports_fp8, (capability,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [8, 64])
|
||||
@@ -517,11 +563,13 @@ def test_cutlass_support_opcheck():
|
||||
@pytest.mark.parametrize("use_bias", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
per_out_ch: bool, use_bias: bool):
|
||||
|
||||
current_platform.get_device_capability()
|
||||
),
|
||||
reason="Grouped gemm is not supported on this GPU type.",
|
||||
)
|
||||
def test_cutlass_fp8_group_gemm(
|
||||
num_experts: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
|
||||
):
|
||||
# Device and dtype setup
|
||||
device = "cuda"
|
||||
out_dtype = torch.half
|
||||
@@ -533,13 +581,9 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
b_scales_tensors = []
|
||||
baseline_tensors = []
|
||||
|
||||
expert_offsets = torch.zeros((num_experts + 1),
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int64)
|
||||
|
||||
problem_sizes = torch.zeros((num_experts, 3),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
|
||||
|
||||
if not per_act_token:
|
||||
one_scale_a = torch.randn((1, 1), device=device, dtype=torch.float32)
|
||||
@@ -566,75 +610,76 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool,
|
||||
b_tensors.append(b_g)
|
||||
|
||||
# Set up A/B scales
|
||||
scale_b = torch.randn((1, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)
|
||||
b_scales_tensors.append(scale_b)
|
||||
|
||||
if per_act_token:
|
||||
scale_a = torch.randn((m_a_scales, 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
|
||||
a_scales_tensors.append(scale_a)
|
||||
else:
|
||||
scale_a = one_scale_a
|
||||
|
||||
# Compute baseline result for this group
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype,
|
||||
None)
|
||||
baseline_g = baseline_scaled_mm(a_g, b_g, scale_a, scale_b, out_dtype, None)
|
||||
baseline_tensors.append(baseline_g)
|
||||
|
||||
a_tensors_stacked = torch.empty((expert_offsets[num_experts], k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
b_tensors_stacked = torch.empty((num_experts, n_g, k_g),
|
||||
device=device,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
a_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
b_tensors_stacked = torch.empty(
|
||||
(num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
a_tensors_stacked[expert_offsets[g]:expert_offsets[g +
|
||||
1]] = a_tensors[g]
|
||||
a_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
|
||||
b_tensors_stacked[g] = b_tensors[g].t()
|
||||
b_tensors_stacked = b_tensors_stacked.transpose(1, 2)
|
||||
|
||||
if per_act_token:
|
||||
a_scales_tensors_stacked = torch.empty(
|
||||
(expert_offsets[num_experts], 1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
(expert_offsets[num_experts], 1), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
a_scales_tensors_stacked[
|
||||
expert_offsets[g]:expert_offsets[g + 1]] = a_scales_tensors[g]
|
||||
a_scales_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]] = (
|
||||
a_scales_tensors[g]
|
||||
)
|
||||
else:
|
||||
a_scales_tensors_stacked = one_scale_a
|
||||
|
||||
b_scales_tensors_stacked = torch.empty((num_experts, n_b_scales),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
b_scales_tensors_stacked = torch.empty(
|
||||
(num_experts, n_b_scales), device=device, dtype=torch.float32
|
||||
)
|
||||
for g in range(num_experts):
|
||||
b_scales_tensors_stacked[g] = b_scales_tensors[g]
|
||||
|
||||
out_tensors_stacked = torch.zeros((expert_offsets[num_experts], n_g),
|
||||
device=device,
|
||||
dtype=out_dtype)
|
||||
out_tensors_stacked = torch.zeros(
|
||||
(expert_offsets[num_experts], n_g), device=device, dtype=out_dtype
|
||||
)
|
||||
|
||||
ab_strides = torch.full((num_experts, ),
|
||||
a_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
c_strides = torch.full((num_experts, ),
|
||||
out_tensors_stacked.stride(0),
|
||||
device="cuda",
|
||||
dtype=torch.int64)
|
||||
ab_strides = torch.full(
|
||||
(num_experts,), a_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
c_strides = torch.full(
|
||||
(num_experts,), out_tensors_stacked.stride(0), device="cuda", dtype=torch.int64
|
||||
)
|
||||
|
||||
ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked,
|
||||
b_tensors_stacked, a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked, expert_offsets[:-1],
|
||||
problem_sizes, ab_strides, ab_strides, c_strides,
|
||||
per_act_token, per_out_ch)
|
||||
ops.cutlass_moe_mm(
|
||||
out_tensors_stacked,
|
||||
a_tensors_stacked,
|
||||
b_tensors_stacked,
|
||||
a_scales_tensors_stacked,
|
||||
b_scales_tensors_stacked,
|
||||
expert_offsets[:-1],
|
||||
problem_sizes,
|
||||
ab_strides,
|
||||
ab_strides,
|
||||
c_strides,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
)
|
||||
|
||||
# Validate each group's result against the baseline
|
||||
for g in range(num_experts):
|
||||
baseline = baseline_tensors[g]
|
||||
c = out_tensors_stacked[expert_offsets[g]:expert_offsets[g + 1]]
|
||||
c = out_tensors_stacked[expert_offsets[g] : expert_offsets[g + 1]]
|
||||
torch.testing.assert_close(c, baseline, rtol=1e-2, atol=5e-4)
|
||||
|
||||
@@ -13,7 +13,9 @@ import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows, quantize_weights)
    pack_rows,
    quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

@@ -24,16 +26,33 @@ from vllm.scalar_type import ScalarType, scalar_types
# have kernels and some kernels support multiple quantization methods.
IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9

MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672),
              (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096),
              (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096),
              (1024, 4096, 8192), (1024, 8192, 4096)]
MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

# TODO(czhu): get supported schedules from fn
SCHEDULES = [
    '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1',
    '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1',
    '128x256_1x1x1', '128x256_2x1x1'
    "128x16_1x1x1",
    "256x16_1x1x1",
    "128x32_1x1x1",
    "256x32_1x1x1",
    "128x64_1x1x1",
    "256x64_1x1x1",
    "128x128_1x1x1",
    "256x128_1x1x1",
    "128x256_1x1x1",
    "128x256_2x1x1",
]


@@ -60,19 +79,23 @@ class Tensors:

# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints,
# Ch Scales Type, Tok Scales Type)
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
                      Optional[torch.dtype], bool]
TestTypeTuple = tuple[
    list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
    *(
        TypeConfig(act_type=torch.float8_e4m3fn,
                   weight_type=w_type,
                   output_type=o_type,
                   group_scale_type=torch.float8_e4m3fn,
                   channel_scale_type=torch.float32,
                   token_scale_type=torch.float32)
        TypeConfig(
            act_type=torch.float8_e4m3fn,
            weight_type=w_type,
            output_type=o_type,
            group_scale_type=torch.float8_e4m3fn,
            channel_scale_type=torch.float32,
            token_scale_type=torch.float32,
        )
        for w_type in [scalar_types.int4]
        # TODO(czhu): fp16 out type
        for o_type in [torch.bfloat16]),
        for o_type in [torch.bfloat16]
    ),
]

# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
@@ -86,26 +109,28 @@ IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)
# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return tensor.clamp(min=finfo.min,
                        max=finfo.max).to(dtype=torch.float8_e4m3fn)
    return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)


def cutlass_quantize_and_pack(atype: torch.dtype,
                              w: torch.Tensor,
                              wtype: ScalarType,
                              stype: Optional[torch.dtype],
                              group_size: Optional[int],
                              zero_points: bool = False):
def cutlass_quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: Optional[torch.dtype],
    group_size: Optional[int],
    zero_points: bool = False,
):
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(w,
                                             wtype,
                                             group_size=group_size,
                                             zero_points=zero_points)
    w_ref, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # since scales are cast to fp8, we need to compute w_ref this way
    w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to(
        torch.float32).repeat_interleave(group_size, dim=0)).to(atype)
    w_ref = (
        (w_q).to(torch.float32)
        * w_s.to(atype).to(torch.float32).repeat_interleave(group_size, dim=0)
    ).to(atype)

    # bit mask prevents sign extending int4 when packing
    w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
@@ -117,12 +142,14 @@ def cutlass_quantize_and_pack(atype: torch.dtype,
    return w_ref, w_q_packed, w_s_packed, w_zp


def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
                        group_size: Optional[int]) -> Tensors:
def create_test_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int]
) -> Tensors:
    m, n, k = shape

    print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
          group_size)
    print(
        "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
    )

    a = to_fp8(torch.randn((m, k), device="cuda"))
    w = to_fp8(torch.randn((k, n), device="cuda"))
@@ -133,30 +160,34 @@ def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig,
    w = w.to(torch.float16)

    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, w, types.weight_type, types.group_scale_type, group_size,
        False)
        a.dtype, w, types.weight_type, types.group_scale_type, group_size, False
    )

    a_ref = a.to(torch.float32)
    w_ref = w_ref.to(torch.float32)

    # for the practical use case we need per-tok scales for fp8 activations
    w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type)
    w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
    # weights are already per-group quantized, use placeholder here
    w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type)
    w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)

    return Tensors(w_ref=w_ref,
                   a_ref=a_ref,
                   a=a,
                   w_q=w_q_packed,
                   w_g_s=w_s,
                   w_ch_s=w_ch_s,
                   w_tok_s=w_tok_s)
    return Tensors(
        w_ref=w_ref,
        a_ref=a_ref,
        a=a,
        w_q=w_q_packed,
        w_g_s=w_s,
        w_ch_s=w_ch_s,
        w_tok_s=w_tok_s,
    )


def mm_test_helper(types: TypeConfig,
                   tensors: Tensors,
                   group_size: Optional[int] = None,
                   schedule: Optional[str] = None):
def mm_test_helper(
    types: TypeConfig,
    tensors: Tensors,
    group_size: Optional[int] = None,
    schedule: Optional[str] = None,
):
    # CUTLASS upstream uses fp8 with fastaccum as reference
    # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
    output_ref = torch._scaled_mm(
@@ -165,7 +196,8 @@ def mm_test_helper(types: TypeConfig,
        tensors.w_tok_s.unsqueeze(1),
        tensors.w_ch_s.unsqueeze(0),
        out_dtype=types.output_type,
        use_fast_accum=True)
        use_fast_accum=True,
    )

    output = ops.cutlass_w4a8_mm(
        a=tensors.a,
@@ -179,17 +211,15 @@ def mm_test_helper(types: TypeConfig,
    print(output)
    print(output_ref)

    torch.testing.assert_close(output,
                               output_ref.to(output.dtype),
                               rtol=1e-3,
                               atol=1e-3)
    torch.testing.assert_close(
        output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
    )


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
                         MNK_SHAPES,
                         ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
@@ -201,7 +231,6 @@ def test_cutlass_w4a8(shape, types: TypeConfig, schedule):

# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs
@@ -210,8 +239,9 @@ class W4A8Layer(torch.nn.Module):
        return ops.cutlass_w4a8_mm(a=a, **self.kwargs)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
                    reason="CUTLASS W4A8 is not supported on this GPU type.")
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
def test_w4a8_cuda_graph():
    m, n, k = 512, 4096, 4096

@@ -224,10 +254,11 @@ def test_w4a8_cuda_graph():
    zero_points = False

    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points)
        a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points
    )

    w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32)
    w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32)
    w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
    w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)

    # Construct a trivial model with a single layer that calls the kernel
    model = W4A8Layer(
@@ -244,7 +275,8 @@ def test_w4a8_cuda_graph():
        w_tok_s.unsqueeze(1),
        w_ch_s.unsqueeze(0),
        out_dtype=torch.bfloat16,
        use_fast_accum=True)
        use_fast_accum=True,
    )

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()

@@ -2,8 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
                         convert_swizzled_to_linear, dequantize_nvfp4_to_dtype)
from nvfp4_utils import (
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_MAX,
    convert_swizzled_to_linear,
    dequantize_nvfp4_to_dtype,
)

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
@@ -41,18 +45,12 @@ def get_ref_results(
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert m_k == n_k
    a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                           a_sf,
                                           a_global_scale,
                                           dtype=dtype,
                                           device=device,
                                           block_size=block_size)
    b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
                                           b_sf,
                                           b_global_scale,
                                           dtype=dtype,
                                           device=device,
                                           block_size=block_size)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    b_in_dtype = dequantize_nvfp4_to_dtype(
        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
    )
    return torch.matmul(a_in_dtype, b_in_dtype.t())


@@ -72,8 +70,7 @@ def test_flashinfer_nvfp4_gemm(
    autotune: bool,
) -> None:
    if backend == "trtllm" and dtype == torch.float16:
        pytest.skip(
            "Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")
        pytest.skip("Only torch.bfloat16 is supported for TRTLLM FP4 GEMM operations")

    current_platform.seed_everything(seed)
    m, n, packed_k = shape
@@ -82,10 +79,12 @@ def test_flashinfer_nvfp4_gemm(
    a_dtype = torch.randn((m, k), dtype=dtype, device=device)
    b_dtype = torch.randn((n, k), dtype=dtype, device=device)

    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                      torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    alpha = 1.0 / (a_global_scale * b_global_scale)
    # ops.scaled_fp4_quant returns swizzled scales, while weights
    # from checkpoints are in linear scales.
@@ -113,14 +112,18 @@ def test_flashinfer_nvfp4_gemm(

    if backend == "trtllm":
        epilogue_tile_m = 128
        b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8),
                                            epilogue_tile_m)
        b_fp4 = flashinfer.shuffle_matrix_a(b_fp4.view(torch.uint8), epilogue_tile_m)

        b_scale_interleaved = convert_swizzled_to_linear(
            b_scale_interleaved, n, k, block_size)
        b_scale_interleaved = (flashinfer.shuffle_matrix_sf_a(
            b_scale_interleaved.view(torch.uint8), epilogue_tile_m).reshape(
                b_scale_interleaved.shape).view(torch.float8_e4m3fn))
            b_scale_interleaved, n, k, block_size
        )
        b_scale_interleaved = (
            flashinfer.shuffle_matrix_sf_a(
                b_scale_interleaved.view(torch.uint8), epilogue_tile_m
            )
            .reshape(b_scale_interleaved.shape)
            .view(torch.float8_e4m3fn)
        )

    with flashinfer.autotune(autotune):
        out = flashinfer_scaled_fp4_mm(
@@ -133,7 +136,4 @@ def test_flashinfer_nvfp4_gemm(
            backend=backend,
        )

    torch.testing.assert_close(out,
                               expected_out.to(dtype=dtype),
                               atol=1e-1,
                               rtol=1e-1)
    torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

@@ -9,8 +9,7 @@ from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm

if not current_platform.has_device_capability(100):
    pytest.skip(
        reason=
        "Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
        reason="Flashinfer FP8 gemms requires compute capability of 10.0 or above.",
        allow_module_level=True,
    )

@@ -53,7 +52,7 @@ def test_flashinfer_fp8_gemm(
    ).to(dtype=dtype)

    if use_bias:
        bias = torch.randn((n, ), dtype=dtype, device=device)
        bias = torch.randn((n,), dtype=dtype, device=device)
        expected_out = expected_out + bias
    else:
        bias = None

@@ -5,9 +5,11 @@ import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)
from tests.kernels.quant_utils import (
    FP8_DTYPE,
    ref_dynamic_per_tensor_fp8_quant,
    ref_dynamic_per_token_quant,
)
from tests.kernels.utils import opcheck
from vllm.platforms import current_platform

@@ -18,23 +20,25 @@ SCALE_UBS = [True, False]
SEEDS = [0]


def opcheck_fp8_quant(output,
                      input,
                      scale=None,
                      scale_ub=None,
                      use_per_token_if_dynamic=False):
def opcheck_fp8_quant(
    output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False
):
    if scale is not None:
        opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale))
    elif use_per_token_if_dynamic:
        scale = torch.empty((input.shape[0], 1),
                            device=input.device,
                            dtype=torch.float32)
        opcheck(torch.ops._C.dynamic_per_token_scaled_fp8_quant,
                (output, input, scale, scale_ub))
        scale = torch.empty(
            (input.shape[0], 1), device=input.device, dtype=torch.float32
        )
        opcheck(
            torch.ops._C.dynamic_per_token_scaled_fp8_quant,
            (output, input, scale, scale_ub),
        )
    else:
        scale = torch.empty((input.numel() // input.shape[-1], 1),
                            device=input.device,
                            dtype=torch.float32)
        scale = torch.empty(
            (input.numel() // input.shape[-1], 1),
            device=input.device,
            dtype=torch.float32,
        )
        opcheck(torch.ops._C.dynamic_scaled_fp8_quant, (output, input, scale))


@@ -44,30 +48,29 @@ def opcheck_fp8_quant(output,
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                                     dtype: torch.dtype, scale_ub: bool,
                                     seed: int) -> None:
def test_dynamic_per_token_fp8_quant(
    num_tokens: int, hidden_size: int, dtype: torch.dtype, scale_ub: bool, seed: int
) -> None:
    current_platform.seed_everything(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                   device="cuda") + 1e-6  # avoid nans
    x = (
        torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") + 1e-6
    )  # avoid nans

    scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
        if scale_ub else None
    scale_ub = (
        torch.mean(x).to(dtype=torch.float32, device="cuda") if scale_ub else None
    )
    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
    ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                               scale_ub=scale_ub,
                                               use_per_token_if_dynamic=True)
    ops_out, ops_scales = ops.scaled_fp8_quant(
        x, scale_ub=scale_ub, use_per_token_if_dynamic=True
    )

    torch.testing.assert_close(ref_scales, ops_scales)
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))
    torch.testing.assert_close(
        ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
    )

    opcheck_fp8_quant(ops_out,
                      x,
                      None,
                      scale_ub,
                      use_per_token_if_dynamic=True)
    opcheck_fp8_quant(ops_out, x, None, scale_ub, use_per_token_if_dynamic=True)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -75,8 +78,9 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
                                      dtype: torch.dtype, seed: int) -> None:
def test_dynamic_per_tensor_fp8_quant(
    num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
) -> None:
    current_platform.seed_everything(seed)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
@@ -85,8 +89,9 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
    ops_out, ops_scale = ops.scaled_fp8_quant(x)

    torch.testing.assert_close(ref_scale, ops_scale)
    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
                               ops_out.to(dtype=torch.float32))
    torch.testing.assert_close(
        ref_out.to(dtype=torch.float32), ops_out.to(dtype=torch.float32)
    )

    opcheck_fp8_quant(ops_out, x)


@@ -6,8 +6,7 @@ import pytest
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform


@@ -18,13 +17,14 @@ from vllm.platforms import current_platform
        (64, 1024, 64),  # Medium
        (128, 2048, 128),  # Large
        (8, 513, 64),  # Non-divisible (native only)
    ])
    ],
)
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
                                      group_size: int, seed: int,
                                      use_ue8m0: bool) -> None:
def test_quantfp8_group_functionality(
    batch_size: int, hidden_dim: int, group_size: int, seed: int, use_ue8m0: bool
) -> None:
    """Test QuantFP8 group quantization with various configurations.

    Tests both CUDA and native implementations, column-major scales,
@@ -32,16 +32,17 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
    """
    current_platform.seed_everything(seed)

    x = torch.randn(
        (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    x = torch.randn((batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    expected_num_groups = (hidden_dim + group_size - 1) // group_size
    is_divisible = hidden_dim % group_size == 0

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False,
                        use_ue8m0=use_ue8m0)
    quant_op = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=False,
        use_ue8m0=use_ue8m0,
    )

    # 1. Test native implementation (always available)
    x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -49,10 +50,12 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
    assert scales_native.shape == (batch_size, expected_num_groups)

    # 2. Test column-major scales configuration
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True,
                            use_ue8m0=use_ue8m0)
    quant_op_col = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=True,
        use_ue8m0=use_ue8m0,
    )
    _, scales_col = quant_op_col.forward_native(x.clone())
    assert scales_col.shape == (batch_size, expected_num_groups)
    assert scales_col.stride(0) == 1
@@ -86,41 +89,48 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:

    # Test with 3D input
    batch1, batch2, hidden_dim = 4, 8, 1024
    x_3d = torch.randn(
        (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8
    x_3d = (
        torch.randn((batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda")
        * 8
    )

    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False,
                        use_ue8m0=use_ue8m0)
    quant_op = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=False,
        use_ue8m0=use_ue8m0,
    )

    x_quant, scales = quant_op.forward_native(x_3d.clone())
    assert x_quant.shape == x_3d.shape
    assert scales.shape == (batch1, batch2, hidden_dim // group_size)

    # Test column_major_scales with multi-dim
    quant_op_col = QuantFP8(static=False,
                            group_shape=group_shape,
                            column_major_scales=True,
                            use_ue8m0=use_ue8m0)
    quant_op_col = QuantFP8(
        static=False,
        group_shape=group_shape,
        column_major_scales=True,
        use_ue8m0=use_ue8m0,
    )
    _, scales_col = quant_op_col.forward_native(x_3d.clone())
    assert scales_col.shape == (batch1, batch2, hidden_dim // group_size)

    # Test with 4D input
    batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256
    x_4d = torch.randn((batch1, batch2, batch3, hidden_dim),
                       dtype=torch.bfloat16,
                       device="cuda") * 8
    x_4d = (
        torch.randn(
            (batch1, batch2, batch3, hidden_dim), dtype=torch.bfloat16, device="cuda"
        )
        * 8
    )

    x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone())
    assert x_quant_4d.shape == x_4d.shape
    assert scales_4d.shape == (batch1, batch2, batch3,
                               hidden_dim // group_size)
    assert scales_4d.shape == (batch1, batch2, batch3, hidden_dim // group_size)

    _, scales_4d_col = quant_op_col.forward_native(x_4d.clone())
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size,
                                   batch3)
    assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, batch3)


@pytest.mark.parametrize("seed", [42])
@@ -132,30 +142,24 @@ def test_quantfp8_group_edge_cases(seed: int) -> None:
    group_size = 64

    # Test with single group (group_size >= hidden_dim)
    x_small = torch.randn(
        (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    x_small = torch.randn((batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8
    group_shape = GroupShape(1, group_size)
    quant_op = QuantFP8(static=False,
                        group_shape=group_shape,
                        column_major_scales=False)
    quant_op = QuantFP8(
        static=False, group_shape=group_shape, column_major_scales=False
    )

    x_quant_small, scales_small = quant_op.forward_native(x_small.clone())
    assert x_quant_small.shape == x_small.shape
    assert scales_small.shape == (batch_size, 1)

    # Test with zero inputs
    x_zero = torch.zeros((batch_size, 256),
                         dtype=torch.bfloat16,
                         device="cuda")
    x_zero = torch.zeros((batch_size, 256), dtype=torch.bfloat16, device="cuda")
    x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone())
    assert x_quant_zero.shape == x_zero.shape
    assert (scales_zero > 0).all(), "Scales should be clamped to minimum"

    # Test very large values
    x_large = torch.full((batch_size, 256),
                         1000.0,
                         dtype=torch.bfloat16,
                         device="cuda")
    x_large = torch.full((batch_size, 256), 1000.0, dtype=torch.bfloat16, device="cuda")
    x_quant_large, scales_large = quant_op.forward_native(x_large.clone())
    assert x_quant_large.shape == x_large.shape
    # FP8 max is typically 448 or 224, so scales should be > 1

@@ -13,33 +13,42 @@ from vllm import _custom_ops as ops  # noqa: F401
def test_ggml_opcheck(quant_type):
    block_size, type_size = gguf.GGML_QUANT_SIZES[quant_type]
    shape = [256, 1152]
    qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
    qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
    m = qweight.shape[0]
    n = qweight.shape[1] // type_size * block_size
    opcheck(torch.ops._C.ggml_dequantize,
            (qweight, quant_type, m, n, torch.float16))
    opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n, torch.float16))

    x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
    opcheck(torch.ops._C.ggml_mul_mat_a8,
            (qweight, x, quant_type, qweight.shape[0]))
    opcheck(torch.ops._C.ggml_mul_mat_vec_a8,
            (qweight, x, quant_type, qweight.shape[0]))
    x = torch.rand((m, 512), device="cuda", dtype=torch.float16)
    opcheck(torch.ops._C.ggml_mul_mat_a8, (qweight, x, quant_type, qweight.shape[0]))
    opcheck(
        torch.ops._C.ggml_mul_mat_vec_a8, (qweight, x, quant_type, qweight.shape[0])
    )

    shape = [256, 1024, 336]
    qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
    x = torch.rand((1, 1024), device='cuda', dtype=torch.float16)
    sorted_token_ids = torch.arange(776, device='cuda')
    expert_ids = torch.randint(0, 256, (194, ), device='cuda')
    num_tokens_post_padded = torch.tensor([1],
                                          dtype=torch.int64,
                                          device='cuda')
    qweight = torch.randint(0, 100, shape, device="cuda", dtype=torch.uint8)
    x = torch.rand((1, 1024), device="cuda", dtype=torch.float16)
    sorted_token_ids = torch.arange(776, device="cuda")
    expert_ids = torch.randint(0, 256, (194,), device="cuda")
    num_tokens_post_padded = torch.tensor([1], dtype=torch.int64, device="cuda")

    opcheck(torch.ops._C.ggml_moe_a8,
            (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded,
             quant_type, qweight.shape[0], 1, x.shape[0]))
    opcheck(
        torch.ops._C.ggml_moe_a8,
        (
            x,
            qweight,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            quant_type,
            qweight.shape[0],
            1,
            x.shape[0],
        ),
    )

    topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32)
    topk_ids = torch.zeros((1, 1), device="cuda", dtype=torch.int32)

    opcheck(
        torch.ops._C.ggml_moe_a8_vec,
        (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]))
        (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0]),
    )

@@ -18,8 +18,8 @@ GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")


def get_gguf_sample_tensors(
        hidden_size: int,
        quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
    hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
    sample_dir = GGUF_SAMPLE
    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
    sample_file = Path(sample_dir) / filename
@@ -27,8 +27,8 @@ def get_gguf_sample_tensors(


def get_gguf_MoE_tensors(
        hidden_size: int,
        quant_type: GGMLQuantizationType) -> list[ReaderTensor]:
    hidden_size: int, quant_type: GGMLQuantizationType
) -> list[ReaderTensor]:
    sample_dir = GGUF_SAMPLE_MOE
    filename = f"Quant_{quant_type.name}_{hidden_size}.gguf"
    sample_file = Path(sample_dir) / filename
@@ -68,17 +68,20 @@ QUANT_TYPES = [
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype,
                    quant_type: GGMLQuantizationType):
def test_dequantize(
    hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType
):
    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    for tensor in tensors:
        shape_str = tensor.name.split("_")[-1]
        shape = map(int, shape_str.split("x"))

        ref_output = torch.tensor(dequantize(tensor.data, quant_type),
                                  device="cuda").to(dtype)
        output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
                                     quant_type, *list(shape), dtype)
        ref_output = torch.tensor(
            dequantize(tensor.data, quant_type), device="cuda"
        ).to(dtype)
        output = ops.ggml_dequantize(
            torch.tensor(tensor.data, device="cuda"), quant_type, *list(shape), dtype
        )

        torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)

@@ -87,20 +90,21 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_mmvq(hidden_size: int, dtype: torch.dtype,
              quant_type: GGMLQuantizationType):
def test_mmvq(hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType):
    current_platform.seed_everything(0)

    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((1, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        weight = torch.tensor(dequantize(tensor.data, quant_type),
                              device="cuda").to(dtype)
        weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
            dtype
        )
        ref_output = x @ weight.T

        qweight = torch.tensor(tensor.data, device="cuda")
        output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type,
                                         qweight.shape[0]).to(dtype)
        output = ops.ggml_mul_mat_vec_a8(qweight, x, quant_type, qweight.shape[0]).to(
            dtype
        )

        torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)

@@ -121,17 +125,23 @@ def test_mmvq(hidden_size: int, dtype: torch.dtype,
        GGMLQuantizationType.Q4_0,
        GGMLQuantizationType.Q5_0,
        GGMLQuantizationType.Q8_0,
    ])
    ],
)
@torch.inference_mode()
def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
             quant_type: GGMLQuantizationType):
def test_mmq(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    quant_type: GGMLQuantizationType,
):
    current_platform.seed_everything(0)

    tensors = get_gguf_sample_tensors(hidden_size, quant_type)
    x = torch.rand((num_tokens, hidden_size), dtype=dtype, device="cuda")
    for tensor in tensors:
        weight = torch.tensor(dequantize(tensor.data, quant_type),
                              device="cuda").to(dtype)
        weight = torch.tensor(dequantize(tensor.data, quant_type), device="cuda").to(
            dtype
        )
        ref_output = x @ weight.T

        qweight = torch.tensor(tensor.data, device="cuda")
@@ -141,10 +151,9 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
        # bfloat16 tends to accumulate and can greatly inflate rtol
        # since outputs are also very close to 0
        rtols = {torch.half: 1e-1, torch.bfloat16: 1e4, torch.float: 2e1}
        torch.testing.assert_close(output,
                                   ref_output,
                                   atol=atols[dtype],
                                   rtol=rtols[dtype])
        torch.testing.assert_close(
            output, ref_output, atol=atols[dtype], rtol=rtols[dtype]
        )


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -153,35 +162,46 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode()
def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
             quant_type: GGMLQuantizationType, top_k: int):
def test_moe(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    quant_type: GGMLQuantizationType,
    top_k: int,
):
    current_platform.seed_everything(0)
    H, E = 1024, 256

    x = torch.rand((num_tokens, H), dtype=dtype, device="cuda")

    topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype)
    topk_ids = torch.randint(0,
                             E, (num_tokens, top_k),
                             device="cuda",
                             dtype=torch.int32)
    topk_ids = torch.randint(
        0, E, (num_tokens, top_k), device="cuda", dtype=torch.int32
    )

    tensors = get_gguf_MoE_tensors(hidden_size, quant_type)

    w13 = tensors[0]
    w2 = tensors[1]

    w13_dequant = torch.tensor(dequantize(w13.data, quant_type),
                               device="cuda").to(dtype)
    w13_dequant = torch.tensor(dequantize(w13.data, quant_type), device="cuda").to(
        dtype
    )

    w2_dequant = torch.tensor(dequantize(w2.data, quant_type),
                              device="cuda").to(dtype)
    w2_dequant = torch.tensor(dequantize(w2.data, quant_type), device="cuda").to(dtype)

    output = _fused_moe_gguf(x, torch.tensor(w13.data, device="cuda"),
                             torch.tensor(w2.data,
                                          device="cuda"), topk_weights,
                             topk_ids, quant_type, quant_type, "silu")
    output = _fused_moe_gguf(
        x,
        torch.tensor(w13.data, device="cuda"),
        torch.tensor(w2.data, device="cuda"),
        topk_weights,
        topk_ids,
        quant_type,
        quant_type,
        "silu",
    )

    ref_output = fused_experts(x, w13_dequant, w2_dequant, topk_weights,
                               topk_ids).reshape(output.shape)
    ref_output = fused_experts(
        x, w13_dequant, w2_dequant, topk_weights, topk_ids
    ).reshape(output.shape)
    torch.testing.assert_close(output, ref_output, atol=1, rtol=1e-1)

@@ -8,25 +8,22 @@ from vllm import _custom_ops as ops  # noqa: F401


def test_gptq_shuffle_opcheck():
    weight = torch.randint(-2000000,
                           2000000, (1792, 4096),
                           device='cuda',
                           dtype=torch.int32)
    perm = torch.empty((0, ), device='cuda', dtype=torch.int32)
    weight = torch.randint(
        -2000000, 2000000, (1792, 4096), device="cuda", dtype=torch.int32
    )
    perm = torch.empty((0,), device="cuda", dtype=torch.int32)
    bit = 4
    opcheck(torch.ops._C.gptq_shuffle, (weight, perm, bit))


def test_gptq_gemm_opcheck():
    a = torch.rand((240, 4096), device='cuda', dtype=torch.float16)
    weight = torch.randint(-2000000,
                           2000000, (512, 6144),
                           device='cuda',
                           dtype=torch.int32)
    zeros = torch.zeros((32, 768), device='cuda', dtype=torch.int32)
    scales = torch.rand((32, 6144), device='cuda', dtype=torch.float16)
    idx = torch.empty((0, ), device='cuda', dtype=torch.int32)
    a = torch.rand((240, 4096), device="cuda", dtype=torch.float16)
    weight = torch.randint(
        -2000000, 2000000, (512, 6144), device="cuda", dtype=torch.int32
    )
    zeros = torch.zeros((32, 768), device="cuda", dtype=torch.int32)
    scales = torch.rand((32, 6144), device="cuda", dtype=torch.float16)
    idx = torch.empty((0,), device="cuda", dtype=torch.int32)
    use_exllama = True
    bit = 4
    opcheck(torch.ops._C.gptq_gemm,
            (a, weight, zeros, scales, idx, use_exllama, bit))
    opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit))

@@ -15,7 +15,8 @@ from vllm import _custom_ops as ops
def test_hadacore(batch_size, hidden_dim, dtype=torch.bfloat16, device="cuda"):
    x = torch.eye(hidden_dim, dtype=dtype, device=device)
    hadamard = deterministic_hadamard_matrix(
        hidden_dim, dtype=torch.float64, device="cuda") / math.sqrt(hidden_dim)
        hidden_dim, dtype=torch.float64, device="cuda"
    ) / math.sqrt(hidden_dim)

    y = ops.hadacore_transform(x.clone())
    y_true = (x.to(hadamard.dtype) @ hadamard.T).to(y.dtype)

@@ -11,12 +11,12 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8)
    per_token_quant_int8,
)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True)


def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
@@ -26,14 +26,13 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    assert B.ndim == 2 and B.is_contiguous(
    ), "B must be a 2D contiguous tensor"
    assert B.ndim == 2 and B.is_contiguous(), "B must be a 2D contiguous tensor"

    # Reshape input
    M = A.numel() // A.shape[-1]
    B = B.t()  # Transpose weight matrix
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (K, )
    origin_C_shape = A.shape[:-1] + (K,)
    A = A.reshape(M, N)

    # As is per-token [M, 1], Bs is per-column [1, K]
@@ -43,8 +42,7 @@ def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
    return C.reshape(origin_C_shape).to(output_dtype)


def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
                              topk_ids):
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight, topk_ids):
    """This function performs fused moe with per-column int8 quantization
    using native torch."""

@@ -66,25 +64,22 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk, topk_weight,
        mask = topk_ids == i
        if mask.sum():
            # First MLP layer: note that a_s is now per-token
            inter_out = native_w8a8_per_token_matmul(a_q[mask],
                                                     w1[i],
                                                     a_s[mask],
                                                     w1_s[i],
                                                     output_dtype=a.dtype)
            inter_out = native_w8a8_per_token_matmul(
                a_q[mask], w1[i], a_s[mask], w1_s[i], output_dtype=a.dtype
            )
            # Activation function
            act_out = SiluAndMul().forward_native(inter_out)
            # Quantize activation output with per-token
            act_out_q, act_out_s = per_token_quant_int8(act_out)

            # Second MLP layer
            out[mask] = native_w8a8_per_token_matmul(act_out_q,
                                                     w2[i],
                                                     act_out_s,
                                                     w2_s[i],
                                                     output_dtype=a.dtype)
            out[mask] = native_w8a8_per_token_matmul(
                act_out_q, w2[i], act_out_s, w2_s[i], output_dtype=a.dtype
            )
    # Apply routing weights and sum
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
    return (
        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
    ).sum(dim=1)


@pytest.fixture(autouse=True, scope="module")
@@ -102,8 +97,10 @@ TOP_KS = [2, 6]
SEEDS = [0]


@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
                         itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
@pytest.mark.parametrize(
    "M, N, K, E, topk, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS),
)
@torch.inference_mode()
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    torch.manual_seed(seed)
@@ -130,8 +127,9 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(score, topk)

    ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, topk,
                                        topk_weights, topk_ids)
    ref_out = torch_w8a8_per_column_moe(
        a, w1, w2, w1_s, w2_s, topk, topk_weights, topk_ids
    )

    quant_config = FusedMoEQuantConfig.make(
        torch.int8,
@@ -151,7 +149,7 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
    )

    # Check results
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    rel_diff = torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))
    ) / torch.mean(torch.abs(ref_out.to(torch.float32)))
    assert rel_diff < 0.05

@@ -18,26 +18,24 @@ SCALE = [0.1, 2.1]
|
||||
|
||||
def opcheck_int8_quant_static(output, input, scale, azp=None):
|
||||
if azp is None:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
scale = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
scale = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
|
||||
)
|
||||
if symmetric:
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, None))
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, None))
|
||||
else:
|
||||
azp = torch.empty((input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant,
|
||||
(output, input, scale, azp))
|
||||
azp = torch.empty(
|
||||
(input.numel() // input.shape[-1], 1),
|
||||
device=input.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale, azp))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@@ -45,8 +43,9 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
def test_dynamic_scaled_int8_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
|
||||
@@ -68,30 +67,31 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
def test_dynamic_scaled_int8_azp_quant(
|
||||
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") * 1000 - 300
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300
|
||||
|
||||
x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
|
||||
x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
|
||||
|
||||
# calculate scale and azp, and adjust the range
|
||||
scales = (x_token_max - x_token_min) / torch.tensor(255.0)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
|
||||
torch.int32)
|
||||
azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(torch.int32)
|
||||
|
||||
torch_out = ((x / scales).round() + azps).clamp(
|
||||
int8_traits.min, int8_traits.max).to(torch.int8)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max(
|
||||
) <= int8_traits.max
|
||||
torch_out = (
|
||||
((x / scales).round() + azps)
|
||||
.clamp(int8_traits.min, int8_traits.max)
|
||||
.to(torch.int8)
|
||||
)
|
||||
assert torch_out.min() >= int8_traits.min and torch_out.max() <= int8_traits.max
|
||||
|
||||
ops_out, scales_out, azp_out = scaled_int8_quant(x, symmetric=False)
|
||||
|
||||
if (not torch.allclose(scales_out, scales)):
|
||||
if not torch.allclose(scales_out, scales):
|
||||
print(torch.argmax(torch.abs(scales_out - scales)))
|
||||
torch.testing.assert_close(scales_out, scales)
|
||||
# big atol to account for rounding errors
|
||||
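
The asymmetric scale/azp arithmetic the test above checks can be illustrated with a small standalone sketch; the input values here are illustrative, not taken from the test:

# Hedged sketch of per-token asymmetric int8 quantization, as exercised above.
import torch

x = torch.tensor([[-300.0, 0.0, 700.0]])
x_max = x.max(dim=1, keepdim=True).values
x_min = x.min(dim=1, keepdim=True).values

# Spread [x_min, x_max] over the 255 representable steps of int8.
scale = (x_max - x_min) / 255.0  # ~3.922 here
# The zero point shifts the grid so that x_min lands near -128.
azp = torch.round(-128.0 - x_min / scale).to(torch.int32)  # ~-51 here
q = ((x / scale).round() + azp).clamp(-128, 127).to(torch.int8)
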
@@ -108,17 +108,18 @@ def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
def test_static_scaled_int8_quant(
num_tokens: int, hidden_size: int, dtype: torch.dtype, seed: int, scale: float
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")

out1 = (x / scale_arg).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
(x / scale_arg).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
)
out2, scale2, _ = scaled_int8_quant(x, scale_arg)
assert scale2 is scale_arg

@@ -135,24 +136,28 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
@pytest.mark.parametrize("scale", SCALE)
@pytest.mark.parametrize("azp", [-255, 54])
@torch.inference_mode()
def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float, azp: int) -> None:
def test_static_scaled_int8_azp_quant(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
scale: float,
azp: int,
) -> None:
current_platform.seed_everything(seed)
int8_traits = torch.iinfo(torch.int8)

x = torch.rand(num_tokens, hidden_size, dtype=dtype,
device="cuda") * 1000 - 300
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000 - 300

out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out1 = (
((x / scale).round() + azp)
.clamp(int8_traits.min, int8_traits.max)
.to(torch.int8)
)
scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")

out2, scale2, azp2 = scaled_int8_quant(x,
scale_arg,
azp_arg,
symmetric=False)
out2, scale2, azp2 = scaled_int8_quant(x, scale_arg, azp_arg, symmetric=False)
assert scale2 is scale_arg
assert azp2 is azp_arg

@@ -172,10 +177,7 @@ def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
int32_traits = torch.iinfo(torch.int32)
val = float(int32_traits.max if is_max else int32_traits.min)

x_vals = [[
nextafter(val, inf), val + 1, val, val - 1,
nextafter(val, -inf)
]]
x_vals = [[nextafter(val, inf), val + 1, val, val - 1, nextafter(val, -inf)]]
x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")

# The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)

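One reading of that comment: the division result is first saturated to the int32 range, then the int8 cast clamps again, so values beyond either bound clip rather than wrap. A minimal reference sketch under that assumption (not the kernel itself):

# Hedged sketch of the saturating-cast behaviour the test probes.
import torch

int32_traits = torch.iinfo(torch.int32)
v = torch.tensor([float(int32_traits.max) * 2.0])
# Saturate to int32 first, mirroring the inner cast<int32>...
inner = v.clamp(int32_traits.min, int32_traits.max).to(torch.int32)
# ...then clamp again to the int8 range for the outer cast<int8>.
q = inner.clamp(-128, 127).to(torch.int8)  # tensor([127])
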
@@ -15,15 +15,16 @@ import torch
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
query_machete_supported_group_sizes)
query_machete_supported_group_sizes,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_rows, quantize_weights)
pack_rows,
quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel
# unit tests to a common utility function. Currently the use of
@@ -72,29 +73,38 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype],
Optional[torch.dtype], bool]
TestTypeTuple = tuple[
list[torch.dtype], ScalarType, Optional[torch.dtype], Optional[torch.dtype], bool
]
TEST_TYPES = [
# GPTQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
for a_type in [torch.float16, torch.bfloat16]
),
# AWQ style
*(TypeConfig(act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]),
*(
TypeConfig(
act_type=a_type,
weight_type=w_type,
output_type=None,
group_scale_type=a_type,
group_zero_type=a_type,
channel_scale_type=None,
token_scale_type=None,
)
for w_type in [scalar_types.uint4, scalar_types.uint8]
for a_type in [torch.float16, torch.bfloat16]
),
# # QQQ style
# *(TypeConfig(act_type=torch.int8,
# weight_type=scalar_types.uint4b8,
@@ -133,17 +143,18 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return zps if zps is None else -1 * s * (zps.to(s.dtype))


def group_size_valid(shape: tuple[int, int, int],
group_size: Optional[int]) -> bool:
def group_size_valid(shape: tuple[int, int, int], group_size: Optional[int]) -> bool:
return group_size is None or group_size == -1 or shape[2] % group_size == 0


def machete_quantize_and_pack(atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False):
def machete_quantize_and_pack(
atype: torch.dtype,
w: torch.Tensor,
wtype: ScalarType,
stype: Optional[torch.dtype],
group_size: Optional[int],
zero_points: bool = False,
):
assert wtype.is_integer(), "TODO: support floating point weights"

w_ref, w_q, w_s, w_zp = quantize_weights(
@@ -152,7 +163,8 @@ def machete_quantize_and_pack(atype: torch.dtype,
group_size=group_size,
zero_points=zero_points,
# to match how the kernel applies zps
ref_zero_points_after_scales=True)
ref_zero_points_after_scales=True,
)

w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # convert to col major
@@ -163,15 +175,18 @@ def machete_quantize_and_pack(atype: torch.dtype,
return w_ref, w_q_machete, w_s, w_zp


def create_test_tensors(shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None) -> Tensors:
def create_test_tensors(
shape: tuple[int, int, int],
types: TypeConfig,
group_size: Optional[int],
subset_stride_factor: Optional[int] = None,
) -> Tensors:
m, n, k = shape
factor = subset_stride_factor or 1

print("create_test_tensors, shape:", shape, "types:", types, "group_size:",
group_size)
print(
"create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
)

a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)
@@ -186,8 +201,13 @@ def create_test_tensors(shape: tuple[int, int, int],
w = w.to(torch.float16)

w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, w, types.weight_type, types.group_scale_type, group_size,
types.group_zero_type is not None)
a.dtype,
w,
types.weight_type,
types.group_scale_type,
group_size,
types.group_zero_type is not None,
)

if not a.dtype.is_floating_point:
aiinfo = torch.iinfo(a.dtype)
@@ -196,35 +216,47 @@ def create_test_tensors(shape: tuple[int, int, int],
a_ref = a.to(torch.float32)
w_ref = w_ref.to(torch.float32)

w_ch_s = None if types.channel_scale_type is None else\
rand_data((n,), types.channel_scale_type)
w_tok_s = None if types.token_scale_type is None else\
rand_data((m,), types.token_scale_type)
w_ch_s = (
None
if types.channel_scale_type is None
else rand_data((n,), types.channel_scale_type)
)
w_tok_s = (
None
if types.token_scale_type is None
else rand_data((m,), types.token_scale_type)
)

return Tensors(w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s)
return Tensors(
w_ref=w_ref,
a_ref=a_ref,
a=a,
w_q=w_q_packed,
w_g_s=w_s,
w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
w_ch_s=w_ch_s,
w_tok_s=w_tok_s,
)


# None stype means scales use the same dtype as a
def machete_mm_test_helper(types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None):
def machete_mm_test_helper(
types: TypeConfig,
tensors: Tensors,
group_size: Optional[int] = None,
schedule: Optional[str] = None,
):
output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
output_ref_type = output_ref.dtype

if tensors.w_ch_s is not None:
output_ref = (output_ref.to(tensors.w_ch_s.dtype) *
tensors.w_ch_s.unsqueeze(0)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0)
).to(output_ref_type)
if tensors.w_tok_s is not None:
output_ref = (output_ref.to(tensors.w_tok_s.dtype) *
tensors.w_tok_s.unsqueeze(1)).to(output_ref_type)
output_ref = (
output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1)
).to(output_ref_type)

output = ops.machete_mm(
a=tensors.a,
@@ -245,23 +277,23 @@ def machete_mm_test_helper(types: TypeConfig,
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
# zeropoints (after scales) causes noise around 0
atol = 1 if tensors.w_g_zp is not None\
atol = (
1
if tensors.w_g_zp is not None
else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
)
rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
torch.testing.assert_close(output,
output_ref.to(output.dtype),
rtol=rtol,
atol=atol)
torch.testing.assert_close(
output, output_ref.to(output.dtype), rtol=rtol, atol=atol
)


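The tolerance rule above can be read as follows: rounding error in the float reference grows with the reduction dimension K, so atol scales like sqrt(K) but is capped at 1. Illustrative arithmetic only (the K values here are chosen for the example, not taken from the test):

# 5e-2 * sqrt(256) = 0.8; 5e-2 * sqrt(1024) = 1.6, capped at 1.
import math

for k in (256, 1024):
    print(k, min(5e-2 * math.sqrt(k), 1))
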
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):

group_sizes: list[Optional[int]] = []
if types.group_scale_type is None:
group_sizes = [None]
@@ -275,20 +307,20 @@ def test_machete_all_schedules(shape, types: TypeConfig):
tensors = create_test_tensors(shape, types, group_size)
print(f"MNK = {shape}")
for schedule in ops.machete_supported_schedules(
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type):
types.act_type,
types.weight_type,
group_scales_type=types.group_scale_type,
group_zeros_type=types.group_scale_type,
out_type=types.output_type,
):
print(f"Testing schedule {schedule}")
machete_mm_test_helper(types, tensors, group_size, schedule)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.parametrize("shape",
MNK_SHAPES,
ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
group_sizes: list[Optional[int]] = []
@@ -306,19 +338,22 @@ def test_machete_heuristic(shape, types: TypeConfig):


# Test working on other devices
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
group_size = 128

type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)

tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)

@@ -331,29 +366,30 @@ def test_machete_devices(device: str):


# Test working with a subset of A and B
@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_subset():
group_size = 128

type_config = TypeConfig(act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None)
type_config = TypeConfig(
act_type=torch.float16,
weight_type=scalar_types.uint4b8,
output_type=None,
group_scale_type=torch.float16,
group_zero_type=None,
channel_scale_type=None,
token_scale_type=None,
)

tensors = create_test_tensors((512, 4096, 4096),
type_config,
group_size,
subset_stride_factor=2)
tensors = create_test_tensors(
(512, 4096, 4096), type_config, group_size, subset_stride_factor=2
)
machete_mm_test_helper(type_config, tensors, group_size)


# Test to make sure cuda graphs work
class MacheteLayer(torch.nn.Module):

def __init__(self, **kwargs):
super().__init__()
self.kwargs = kwargs
@@ -362,8 +398,9 @@ class MacheteLayer(torch.nn.Module):
return ops.machete_mm(a=a, **self.kwargs)


@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU,
reason="Machete is not supported on this GPU type.")
@pytest.mark.skipif(
not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_cuda_graph():
m, n, k = 512, 4096, 4096

@@ -375,7 +412,8 @@ def test_machete_cuda_graph():
zero_points = False

w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
a.dtype, b, wtype, stype, group_size, zero_points)
a.dtype, b, wtype, stype, group_size, zero_points
)

# Construct a trivial model with a single layer that calls a machete kernel
model = MacheteLayer(

@@ -4,6 +4,7 @@

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

import pytest
import torch

@@ -11,24 +12,44 @@ from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx,
marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales,
query_marlin_supported_quant_types)
MARLIN_SUPPORTED_GROUP_SIZES,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like)
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
rand_marlin_weight_mxfp4_like,
rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
marlin_quant_fp8_torch)
marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize,
marlin_weights)
MarlinWorkspace,
awq_marlin_quantize,
get_weight_perm,
marlin_quantize,
marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
marlin_24_quantize)
marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
awq_pack, gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights)
awq_pack,
gptq_pack,
gptq_quantize_weights,
quantize_weights,
sort_weights,
)
from vllm.scalar_type import scalar_types

ACT_ORDER_OPTS = [False, True]
@@ -56,24 +77,27 @@ DTYPES = [torch.float16, torch.bfloat16]

def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
torch.abs(output_ref)
)


def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
act_order, mnk_factors):
def test_gptq_marlin_repack(
k_chunk, n_chunk, quant_type, group_size, act_order, mnk_factors
):
m_factor, n_factor, k_factor = mnk_factors

size_k = k_chunk * k_factor
@@ -96,7 +120,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,

# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)

# Pack to GPTQ format
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
@@ -109,11 +134,14 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,

# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)

opcheck(torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.gptq_marlin_repack,
(q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits),
)

# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -128,16 +156,16 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type",
query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_k = k_chunk * k_factor
@@ -152,21 +180,22 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
b_weight = rand_data((size_k, size_n))

# Quantize
w_ref, q_w, s, zp = quantize_weights(b_weight,
quant_type,
group_size,
zero_points=True)
w_ref, q_w, s, zp = quantize_weights(
b_weight, quant_type, group_size, zero_points=True
)

# Pack to AWQ format
q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

# Pack to Marlin format
weight_perm = get_weight_perm(quant_type.size_bits)
marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
weight_perm)
marlin_q_w_1 = marlin_weights(
q_w, size_k, size_n, quant_type.size_bits, weight_perm
)

opcheck(torch.ops._C.awq_marlin_repack,
(q_w_awq, size_k, size_n, quant_type.size_bits))
opcheck(
torch.ops._C.awq_marlin_repack, (q_w_awq, size_k, size_n, quant_type.size_bits)
)

# Run Marlin repack GPU kernel
marlin_q_w_2 = ops.awq_marlin_repack(
@@ -180,23 +209,34 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types())
@pytest.mark.parametrize(
"group_size",
set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES))
"group_size", set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)
)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_k_full", K_FULL_OPTS)
@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors, act_order, is_k_full, use_atomic_add,
use_fp32_reduce, dtype):
def test_gptq_marlin_gemm(
k_chunk,
n_chunk,
quant_type,
group_size,
mnk_factors,
act_order,
is_k_full,
use_atomic_add,
use_fp32_reduce,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

@@ -225,11 +265,13 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return

if group_size == 16:
w_ref, marlin_q_w, marlin_s, marlin_s2 = \
rand_marlin_weight_nvfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
b_weight.T, group_size
)
else:
w_ref, marlin_q_w, marlin_s = \
rand_marlin_weight_mxfp4_like(b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
b_weight.T, group_size
)
marlin_s2 = None

g_idx = None
@@ -240,8 +282,7 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
return
if act_order:
return
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
b_weight.T, group_size)
w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b_weight.T, group_size)
g_idx = None
sort_indices = None
marlin_zp = None
@@ -250,7 +291,8 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
b_weight, quant_type, group_size)
b_weight, quant_type, group_size
)
g_idx = None
sort_indices = None
marlin_s2 = None
@@ -258,18 +300,37 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,
if group_size == 16:
return
w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, act_order)
b_weight, quant_type, group_size, act_order
)
marlin_zp = None
marlin_s2 = None

workspace = marlin_make_workspace_new(w_ref.device)

opcheck(torch.ops._C.gptq_marlin_gemm,
(a_input, None, marlin_q_w, None, marlin_s, marlin_s2, marlin_zp,
g_idx, sort_indices, workspace, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add,
use_fp32_reduce, False),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_gemm,
(
a_input,
None,
marlin_q_w,
None,
marlin_s,
marlin_s2,
marlin_zp,
g_idx,
sort_indices,
workspace,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
is_k_full,
use_atomic_add,
use_fp32_reduce,
False,
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

output = ops.gptq_marlin_gemm(
a_input,
@@ -302,23 +363,40 @@ def test_gptq_marlin_gemm(k_chunk, n_chunk, quant_type, group_size,

# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
def marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
):
return ops.gptq_marlin_24_gemm(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
scratch,
quant_type,
size_m,
size_n,
size_k,
)


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
mnk_factors):
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
m_factor, n_factor, k_factor = mnk_factors

size_m = m_factor
@@ -328,19 +406,31 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))

(w_24_ref, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s) = marlin_24_quantize(b_weight, quant_type, group_size)
(w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
b_weight, quant_type, group_size
)

workspace_24 = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_MAX_PARALLEL)
workspace_24 = MarlinWorkspace(
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
)

output_ref = torch.matmul(a_input, w_24_ref)

opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
opcheck(
torch.ops._C.gptq_marlin_24_gemm,
(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
marlin_24_s,
workspace_24.scratch,
quant_type.id,
a_input.shape[0],
b_weight.shape[1],
a_input.shape[1],
),
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
)

output = marlin_24_gemm_tester(
a_input,
@@ -361,8 +451,10 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert max_diff < 0.04


@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.skipif(
not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@@ -386,22 +478,22 @@ def test_hqq_marlin_gemm(
a_input = rand_data((size_m, size_k))
dev = a_input.device

b_weight = torch.randint(0,
10, (size_n, size_k),
dtype=torch.uint8,
device=dev)
b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
scale = rand_data((size_n, size_k // group_size))
zero = rand_data((size_n, size_k // group_size))

gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

sort_indices = torch.empty(0, dtype=torch.int, device=dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n,
4).to(dev)
marlin_s = marlin_permute_scales(scale.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_zp = marlin_permute_scales(zero.transpose(1, 0), size_k, size_n,
group_size).to(dev)
marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
dev
)
marlin_s = marlin_permute_scales(
scale.transpose(1, 0), size_k, size_n, group_size
).to(dev)
marlin_zp = marlin_permute_scales(
zero.transpose(1, 0), size_k, size_n, group_size
).to(dev)

g_idx = marlin_make_empty_g_idx(dev)
g_idx_sort_indices = marlin_make_empty_g_idx(dev)
@@ -433,8 +525,7 @@ def test_hqq_marlin_gemm(
s_flat = scale.reshape(-1, 1)
dequant = (b_flat - zp_flat) * s_flat

output_ref = torch.matmul(a_input,
dequant.reshape(b_weight.shape).transpose(1, 0))
output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

torch.cuda.synchronize()

@@ -451,11 +542,12 @@ def test_marlin_gemm_subset_input():
big_m = size_m * 2
big_k = size_k * 2

a_input = rand_data((big_m, big_k))[8:size_m + 8, 8:size_k + 8]
a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
b_weight = rand_data((size_k, size_n))

w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)

marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)
@@ -497,12 +589,13 @@ def test_marlin_gemm_with_bias(size_m):
size_k, size_n = 1024, 2048
a_input = rand_data((size_m, size_k))
b_weight = rand_data((size_k, size_n))
b_bias = rand_data((size_n, )) * 10
b_bias = rand_data((size_n,)) * 10

marlin_bias = marlin_permute_bias(b_bias)

w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
b_weight, quant_type, group_size, False)
b_weight, quant_type, group_size, False
)

marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device)

@@ -8,15 +8,27 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)

DTYPES = [torch.float16, torch.bfloat16]
SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
(90, 128), (150, 128), (150, 48), (90, 80)]
PAD_SHAPES = [
(90, 64),
(150, 64),
(128, 48),
(128, 80),
(150, 80),
(90, 48),
(90, 128),
(150, 128),
(150, 48),
(90, 80),
]
SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
@@ -31,7 +43,22 @@ FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# 0001 -> 0.5
# 0000 -> 0
E2M1_TO_FLOAT32 = [
0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
0.0,
0.5,
1.0,
1.5,
2.0,
3.0,
4.0,
6.0,
0.0,
-0.5,
-1.0,
-1.5,
-2.0,
-3.0,
-4.0,
-6.0,
]
BLOCK_SIZE = 16

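For reference, the table above maps each 4-bit e2m1 code (1 sign, 2 exponent, 1 mantissa bit) to its float value, so decoding a packed byte reduces to a table lookup. A hedged sketch, assuming low-nibble-first packing (the packing order and the sample byte here are illustrative):

# Decode one packed fp4 byte via the E2M1 lookup table above.
import torch

e2m1_to_float32 = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)
packed = torch.tensor([0x91], dtype=torch.uint8)
low = (packed & 0xF).long()   # 0b0001 -> 0.5
high = (packed >> 4).long()   # 0b1001 -> -0.5 (sign bit set)
values = e2m1_to_float32[torch.stack([low, high], dim=-1)]
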
@@ -74,8 +101,7 @@ def ref_nvfp4_quant(x, global_scale):
assert x.ndim == 2
m, n = x.shape
x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
vec_max = torch.max(torch.abs(x), dim=-1,
keepdim=True)[0].to(torch.float32)
vec_max = torch.max(torch.abs(x), dim=-1, keepdim=True)[0].to(torch.float32)
scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
@@ -131,7 +157,7 @@ def test_quantize_to_fp4(
def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
dtype = torch.float16
current_platform.seed_everything(42)
torch.set_default_device('cuda:0')
torch.set_default_device("cuda:0")

m, n = pad_shape


@@ -2,15 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, dequantize_nvfp4_to_dtype

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)

DTYPES = [torch.float16, torch.bfloat16]
# m, n, k
@@ -19,26 +20,31 @@ PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
SHAPES.extend(PAD_SHAPES)

SEEDS = [42]
CUDA_DEVICES = ['cuda:0']
CUDA_DEVICES = ["cuda:0"]


def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
m, n, dtype, block_size, device):
def get_ref_results(
a_fp4,
b_fp4,
a_sf,
b_sf,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
):
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert (m_k == n_k)
a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
a_sf,
a_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4,
b_sf,
b_global_scale,
dtype=dtype,
device=device,
block_size=block_size)
assert m_k == n_k
a_in_dtype = dequantize_nvfp4_to_dtype(
a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
)
b_in_dtype = dequantize_nvfp4_to_dtype(
b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
)
return torch.matmul(a_in_dtype, b_in_dtype.t())


@@ -60,25 +66,34 @@ def test_nvfp4_gemm(
a_dtype = torch.randn((m, k), dtype=dtype, device=device)
b_dtype = torch.randn((n, k), dtype=dtype, device=device)

a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
alpha = 1. / (a_global_scale * b_global_scale)
a_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
).to(torch.float32)
b_global_scale = (
(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)
# ops.scaled_fp4_quant returns swizzled scales, while weights
# from checkpoints are in linear scales.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)

# get_ref_results unswizzles the scales internally.
expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, a_global_scale,
b_global_scale, m, n, dtype, block_size,
device)
out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
b_scale_interleaved, alpha, dtype)
expected_out = get_ref_results(
a_fp4,
b_fp4,
a_scale_interleaved,
b_scale_interleaved,
a_global_scale,
b_global_scale,
m,
n,
dtype,
block_size,
device,
)
out = ops.cutlass_scaled_fp4_mm(
a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
)

torch.testing.assert_close(out,
expected_out.to(dtype=dtype),
atol=1e-1,
rtol=1e-1)
torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)

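One way to read the global-scale formula in this test: block scales are stored in fp8-e4m3 (max 448) and elements in fp4-e2m1 (max 6), so scaling by (448 * 6) / amax maps the largest activation onto the edge of the jointly representable range. Illustrative arithmetic with a value chosen for the example:

# An amax of 13.44 gives a global scale of 2688 / 13.44 = 200.0.
import torch

amax = torch.tensor(13.44)
global_scale = (448.0 * 6.0) / amax
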
@@ -13,15 +13,15 @@ from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
@pytest.mark.parametrize("scale_ue8m0", [False, True])
@pytest.mark.parametrize("group_size", [64, 128])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_per_token_group_quant_fp8(shape, column_major: bool,
scale_ue8m0: bool, group_size: int):
def test_per_token_group_quant_fp8(
shape, column_major: bool, scale_ue8m0: bool, group_size: int
):
device = "cuda"

torch.manual_seed(42)
num_tokens, hidden_dim = shape

x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8

# cuda path
out_q, scale = fp8_utils.per_token_group_quant_fp8(
@@ -53,8 +53,7 @@ def test_per_token_group_quant_int8(shape, group_size: int):
torch.manual_seed(42)
num_tokens, hidden_dim = shape

x = (torch.randn(
(num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8)
x = torch.randn((num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8

# cuda path
out_q, scale = int8_utils.per_token_group_quant_int8(

@@ -63,12 +63,11 @@ SEEDS = [0]
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
@torch.inference_mode()
def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
torch.manual_seed(seed)
#TODO: Zero-centering the inputs causes errors for LLMM1!
# TODO: Zero-centering the inputs causes errors for LLMM1!
# Without that the numbers quickly saturate, and may
# be giving false matches.
A = torch.rand(n, k, dtype=dtype, device="cuda")
@@ -83,14 +82,13 @@ def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()

A = torch.rand(n, k, dtype=dtype, device="cuda") - .5
B = torch.rand(m, k, dtype=dtype, device="cuda") - .5
A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5
B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5

ref_out = torch.nn.functional.linear(A, B)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count)
@@ -101,16 +99,15 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5

ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -121,16 +118,15 @@ def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="only test for rocm")
@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm")
def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)
cu_count = current_platform.get_cu_count()

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, dtype=dtype, device="cuda") - .5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - .5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier
BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5

ref_out = torch.nn.functional.linear(A, B, BIAS)
out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS)
@@ -143,7 +139,8 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)

@@ -153,13 +150,10 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count())
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, current_platform.get_cu_count())

assert torch.allclose(out, ref_out, rtol=0.01)

@@ -169,25 +163,24 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8")
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
torch.manual_seed(seed)

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
A = (torch.rand(n, k, device="cuda") - .5) * xavier
B = (torch.rand(m, k, device="cuda") - .5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - .5
A = (torch.rand(n, k, device="cuda") - 0.5) * xavier
B = (torch.rand(m, k, device="cuda") - 0.5) * xavier
BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5

A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)

ref_out = torch._scaled_mm(A,
B.t(),
out_dtype=dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=BIAS)
out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b,
current_platform.get_cu_count(), BIAS)
ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
)
out = ops.wvSplitKQ(
B, A, dtype, scale_a, scale_b, current_platform.get_cu_count(), BIAS
)

assert torch.allclose(out, ref_out, rtol=0.01)

@@ -3,16 +3,20 @@
import pytest
import torch

from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from tests.kernels.quantization.nvfp4_utils import (
FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype,
)
from vllm._custom_ops import scaled_fp4_quant
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
pytest.skip(
reason="Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True,
)

FP4_DTYPE = torch.uint8
FP8_DTYPE = current_platform.fp8_dtype()
@@ -30,24 +34,24 @@ def test_silu_mul_nvfp4_quant(
shape: tuple[int, int],
) -> None:
current_platform.seed_everything(42)
device = 'cuda:0'
device = "cuda:0"
torch.set_default_device(device)

x = torch.randn(shape, dtype=dtype)

# ref op
ref_output = SiluAndMul().forward_native(x)
ref_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(ref_output).max().to(torch.float32))
ref_output_quant, ref_block_scale = scaled_fp4_quant(
ref_output, ref_global_scale)
ref_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(
ref_output
).max().to(torch.float32)
ref_output_quant, ref_block_scale = scaled_fp4_quant(ref_output, ref_global_scale)

# fused op
fused_output_quant = torch.empty_like(ref_output_quant)
fused_block_scale = torch.empty_like(ref_block_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(fused_output_quant,
fused_block_scale, x,
ref_global_scale)
torch.ops._C.silu_and_mul_nvfp4_quant(
fused_output_quant, fused_block_scale, x, ref_global_scale
)

# check dtype
assert ref_output_quant.dtype == FP4_DTYPE
@@ -59,17 +63,14 @@ def test_silu_mul_nvfp4_quant(
assert ref_block_scale.shape == fused_block_scale.shape

# check dequantized output
ref_output_dequant = dequantize_nvfp4_to_dtype(ref_output_quant,
ref_block_scale,
ref_global_scale, dtype,
device)
fused_output_dequant = dequantize_nvfp4_to_dtype(fused_output_quant,
fused_block_scale,
ref_global_scale, dtype,
device)
ref_output_dequant = dequantize_nvfp4_to_dtype(
ref_output_quant, ref_block_scale, ref_global_scale, dtype, device
)
fused_output_dequant = dequantize_nvfp4_to_dtype(
fused_output_quant, fused_block_scale, ref_global_scale, dtype, device
)

atol, rtol = 3e-1, 3e-1
torch.testing.assert_close(ref_output_dequant,
fused_output_dequant,
atol=atol,
rtol=rtol)
torch.testing.assert_close(
ref_output_dequant, fused_output_dequant, atol=atol, rtol=rtol
)

@@ -4,6 +4,7 @@

Run `pytest tests/kernels/quantization/test_triton_scaled_mm.py`.
"""

import importlib
from typing import Optional

@@ -15,17 +16,19 @@ from vllm.platforms import current_platform
device = "cuda"

triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")
"vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm"
)
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm


def torch_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def torch_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: type[torch.dtype],
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
out = scale_b.T * out
@@ -44,20 +47,22 @@ def get_8bit_types():


# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
@pytest.mark.skipif(not current_platform.is_rocm(), reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(
vllm_runner, example_prompts, model_path, max_tokens, num_logprobs
):
dtype = "bfloat16"

with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs)


MNK_FACTORS = [
@@ -76,10 +81,10 @@ MNK_FACTORS = [
@pytest.mark.parametrize("use_scalar_scale_a", [True, False])
@pytest.mark.parametrize("use_scalar_scale_b", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
use_scalar_scale_b, use_bias):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t
).is_floating_point()
def test_scaled_mm(
M, N, K, in_dtype, out_dtype, use_scalar_scale_a, use_scalar_scale_b, use_bias
):
is_floating_point_type = lambda t: torch.tensor([1, 1], dtype=t).is_floating_point()

current_platform.seed_everything(0)

@@ -93,10 +98,8 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,
#
# So, the values here are kept small enough to avoid this situation.
if is_floating_point_type(in_dtype):
a = (0.25 * torch.rand(
(M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand(
(K, N), dtype=torch.float32, device=device)).to(in_dtype)
a = (0.25 * torch.rand((M, K), dtype=torch.float32, device=device)).to(in_dtype)
b = (0.25 * torch.rand((K, N), dtype=torch.float32, device=device)).to(in_dtype)
else:
a = torch.randint(-32, 32, (M, K), dtype=in_dtype, device=device)
b = torch.randint(-32, 32, (K, N), dtype=in_dtype, device=device)
@@ -113,7 +116,7 @@ def test_scaled_mm(M, N, K, in_dtype, out_dtype, use_scalar_scale_a,

bias = None
if use_bias:
bias = torch.rand((N, ), device=device, dtype=out_dtype)
bias = torch.rand((N,), device=device, dtype=out_dtype)

c_check = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)