[Performance] Move apply_w8a8_block_fp8_linear to an op class (#24666)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Signed-off-by: ElizaWszola <elizaw.9289@gmail.com>
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
ElizaWszola
2025-09-23 21:03:10 +02:00
committed by GitHub
parent 8c1c81a3de
commit 63400259d0
14 changed files with 341 additions and 201 deletions

View File

@@ -17,8 +17,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -111,34 +109,6 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()
@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)
use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)