[FEAT] [ROCm] Add AITER int8 scaled gemm kernel (#15433)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
TJian
2025-03-29 18:33:56 +08:00
committed by GitHub
parent 73aa7041bf
commit 4965ec42d2
4 changed files with 202 additions and 5 deletions

View File

@@ -20,6 +20,23 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported)
from vllm.platforms import current_platform
# AITER only supports per-channel-per-channel INT8 gemm
# and per-tensor-per-tensor INT8 GEMM.
# It does not support mix precision MM and mix quantization scheme.
ROCM_AITER_SUPPORTED_INT8_MODEL = [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
]
# TritonScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
"nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
"nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
with vllm_runner(model_path, enforce_eager=True) as llm:
def check_model(model):
@@ -123,6 +145,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
"use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_logprobs(
hf_runner,
vllm_runner,
@@ -130,7 +154,21 @@ def test_compressed_tensors_w8a8_logprobs(
model_path,
max_tokens,
num_logprobs,
use_aiter,
monkeypatch,
):
if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
if use_aiter:
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
pytest.skip(
f"Skip model {model_path} as it is not support by aiter.")
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
dtype = "bfloat16"
# skip language translation prompt for the static per tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
name_1="vllm",
)
if current_platform.is_rocm():
torch.cuda.synchronize()
def test_compressed_tensors_no_enforce_eager(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
),
],
)
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
@pytest.mark.parametrize(
"use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_dynamic_per_token(
vllm_runner,
model_args,
use_aiter,
monkeypatch,
):
model_path, strategy = model_args
if current_platform.is_rocm(
) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")
if use_aiter:
if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
pytest.skip(
f"Skip model {model_path} as it is not support by aiter.")
# this will enable VLLM_ROCM_USE_AITER_LINEAR
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
with vllm_runner(model_path, dtype=torch.float16) as llm:
def check_model(model):
@@ -207,6 +267,8 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
],
)
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert output
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def test_compressed_tensors_fp8(vllm_runner):
if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0
@@ -281,6 +345,8 @@ def test_compressed_tensors_fp8(vllm_runner):
assert output
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_kv_cache(vllm_runner):
model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
not current_platform.is_cuda()
or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
@pytest.mark.skipif(
not current_platform.has_device_capability(90),
not current_platform.is_cuda()
or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(