[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
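Note: the snippet below is an illustrative reference only, not the fused kernel added by this PR. It sketches, under assumed shapes and the usual 128-wide groups, what "RMSNorm + group (block) FP8 quant" computes, and how a non-contiguous input (a strided slice of a larger buffer) can arise. The real op and its tests live in vLLM and are not reproduced here.

```python
import torch

def rms_norm_group_quant_ref(x, weight, group_size=128, eps=1e-6):
    # Reference RMSNorm in fp32, then per-(token, group) FP8 quantization.
    x32 = x.float()
    normed = x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps) * weight.float()
    t, h = normed.shape
    groups = normed.reshape(t, h // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = groups.abs().amax(-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.reshape(t, h), scales.reshape(t, h // group_size)

# A non-contiguous activation, e.g. a row-strided view of a larger buffer.
full = torch.randn(8, 512, dtype=torch.bfloat16)
x = full[::2]                                # shape (4, 512), not contiguous
weight = torch.ones(512, dtype=torch.bfloat16)
assert not x.is_contiguous()
q, s = rms_norm_group_quant_ref(x, weight)   # group_size=128 -> scales shape (4, 4)
```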
@@ -72,6 +72,16 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
     rocm_aiter_ops.refresh_env_variables()
 
+    # Filter here to reduce code duplication
+    requires_mla = "deepseek" in model_name.lower()
+    is_mla = "mla" in attn_backend.backend.name.lower()
+
+    if requires_mla != is_mla:
+        pytest.skip(
+            f"Incompatible model '{model_name}' and "
+            f"attention backend '{attn_backend.backend.name}'"
+        )
+
     # Disable compile cache to make sure custom passes run.
     # Otherwise, we can't verify fusion happened through the logs.
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
 
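The comment above is the key to how these end-to-end tests assert fusion: with the compile cache disabled, the custom passes always run and report their work in the logs, which the harness then scans. A rough sketch of that idea follows; the regex and log format here are made up for illustration, not the actual messages vLLM's fusion passes emit.

```python
import re

def count_reported_fusions(log_text: str) -> int:
    # Hypothetical message format used only for illustration; the real fusion
    # passes log their own (different) summary lines.
    return sum(int(n) for n in re.findall(r"replaced (\d+) pattern", log_text))
```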
@@ -44,6 +44,20 @@ ROCM_AITER_UNIFIED_ATTN = pytest.param(
     ),
 )
 
+FLASHINFER_MLA_ATTN = pytest.param(
+    AttentionBackendCase(backend=AttentionBackendEnum.FLASHINFER_MLA),
+    id="FLASHINFER_MLA",
+    marks=pytest.mark.skipif(
+        not is_blackwell() or not has_flashinfer(),
+        reason="FI backend requires Blackwell and FlashInfer",
+    ),
+)
+
+TRITON_MLA_ATTN = pytest.param(
+    AttentionBackendCase(backend=AttentionBackendEnum.TRITON_MLA),
+    id="TRITON_MLA",
+)
+
 # Models
 llama3_8b = ModelFusionInfo(
     model_name="meta-llama/Llama-3.1-8B-Instruct",
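For context, the new MLA backend cases follow the same pattern as the existing ones: each case is a pytest.param wrapping an AttentionBackendCase, optionally carrying a skipif mark so the case is always collected but only executed where the hardware/runtime supports it. A self-contained sketch of that pattern, with placeholder names rather than the suite's real helpers:

```python
import pytest

def blackwell_and_flashinfer_available() -> bool:  # placeholder capability probe
    return False

FLASHINFER_MLA_CASE = pytest.param(
    "FLASHINFER_MLA",
    id="FLASHINFER_MLA",
    marks=pytest.mark.skipif(
        not blackwell_and_flashinfer_available(),
        reason="FI backend requires Blackwell and FlashInfer",
    ),
)
TRITON_MLA_CASE = pytest.param("TRITON_MLA", id="TRITON_MLA")

@pytest.mark.parametrize("backend", [FLASHINFER_MLA_CASE, TRITON_MLA_CASE])
def test_backend_selected(backend):
    assert backend.endswith("MLA")
```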
@@ -126,3 +140,25 @@ qwen3_a3b_fp8 = ModelFusionInfo(
         async_tp=n_layers * 2,
     ),
 )
+
+deepseek_v3_fp8 = ModelFusionInfo(
+    model_name="deepseek-ai/DeepSeek-V3",
+    matches=lambda n_layers: Matches(
+        # 3 per dense layer (first 3):
+        # - input_rms + qkv_proj
+        # - q_a_layernorm + q_b_proj (inside MLA wrapper)
+        # - post_attn_layernorm + MLP
+        # 2 per MoE layer (remaining) due to MoE wrapping
+        rms_quant_fusion=n_layers * 2 + min(3, n_layers),  # add for 3 dense layers
+        # TODO silu+block quant
+        # act_quant_fusion=min(3, n_layers),  # dense layers only
+        act_quant_fusion=0,
+        # MLA attn + quant not supported yet:
+        # https://github.com/vllm-project/vllm/issues/35792
+        attn_quant_fusion=0,
+        ar_rms_fusion=n_layers * 2 + 1,
+        # TODO
+        # sequence_parallel=n_layers * 2 + 1,
+        # async_tp=n_layers * 2,
+    ),
+)
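A quick sanity check of the match arithmetic in the comments above (3 fusions per dense layer, 2 per MoE layer, with the first 3 layers being dense in DeepSeek-V3 style models):

```python
n_layers = 6                       # layer count used by the TP1 test below
dense = min(3, n_layers)           # first 3 layers are dense, the rest are MoE
expected = dense * 3 + (n_layers - dense) * 2         # 9 + 6 = 15
assert expected == n_layers * 2 + min(3, n_layers)    # == rms_quant_fusion above
```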
@@ -17,9 +17,12 @@ from .common import (
 )
 from .models import (
     FLASHINFER_ATTN,
+    FLASHINFER_MLA_ATTN,
     ROCM_AITER_UNIFIED_ATTN,
     ROCM_ATTN,
     TRITON_ATTN,
+    TRITON_MLA_ATTN,
+    deepseek_v3_fp8,
     llama3_8b_fp4,
     llama3_8b_fp8,
     llama4_scout_fp4,
@@ -33,6 +36,9 @@ from .models import (
     [
         (*llama3_8b_fp8, False),
         (*qwen3_a3b_fp8, False),
+        (*qwen3_a3b_fp8, True),
+        (*deepseek_v3_fp8, False),
+        (*deepseek_v3_fp8, True),
         pytest.param(
             *llama4_scout_fp8,
             False,
@@ -41,13 +47,6 @@ from .models import (
                 reason="Llama4 Scout FP8 only supported on CUDA",
             ),
         ),
-        pytest.param(
-            *qwen3_a3b_fp8,
-            True,
-            marks=pytest.mark.skipif(
-                not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
-            ),
-        ),
     ],
 )
 @pytest.mark.parametrize(
@@ -57,6 +56,8 @@ from .models import (
         FLASHINFER_ATTN,
         ROCM_ATTN,
         ROCM_AITER_UNIFIED_ATTN,
+        FLASHINFER_MLA_ATTN,
+        TRITON_MLA_ATTN,
     ],
 )
 @pytest.mark.parametrize("n_layers", [6])
@@ -75,6 +76,9 @@ def test_tp1_fp8_fusions(
     run_e2e_fusion_test,
     monkeypatch,
 ):
+    if use_deepgemm and not current_platform.is_cuda():
+        pytest.skip("DeepGemm only supported on CUDA")
+
     if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
         # Flashinfer block FP8 GEMM has internal quantization, so it can't
         # be fused with other ops.
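The comment above gives the reason the fusion expectations change when FlashInfer's block-scale FP8 GEMM handles quantization itself: a GEMM that quantizes its input internally leaves no standalone quant op in the compiled graph, so there is nothing for the RMSNorm+quant pattern matcher to fuse with. A toy illustration of the two graph shapes, with placeholder callables rather than real vLLM ops:

```python
def forward_internal_quant(x, rms_norm, gemm_with_internal_quant):
    h = rms_norm(x)                      # activation stays in bf16
    return gemm_with_internal_quant(h)   # quantization hidden inside the kernel

def forward_external_quant(x, rms_norm, quant_fp8, gemm_fp8):
    h = rms_norm(x)
    q, scale = quant_fp8(h)              # separate quant node -> fusable with rms_norm
    return gemm_fp8(q, scale)
```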
@@ -86,7 +90,8 @@ def test_tp1_fp8_fusions(
 
     matches = matches_fn(n_layers)
 
-    if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
+    block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
+    if block_fp8 and "-quant_fp8" in custom_ops:
         # This is why config forces +quant_fp8 by default
         pytest.skip("native QuantFP8 matching not supported for group quant")
 
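The block_fp8 skip above comes down to scale layout: the skip message says native QuantFP8 matching does not cover group quant, and the practical difference is that per-token quantization yields one scale per token while group (block) quantization needs one scale per (token, group). A shape-only illustration, assuming the usual 128-wide groups:

```python
import torch

tokens, hidden, group_size = 16, 4096, 128
per_token_scales = torch.empty(tokens, 1)                 # what per-token quant produces
group_scales = torch.empty(tokens, hidden // group_size)  # what group (block) quant needs
assert group_scales.shape == (16, 32)
```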
@@ -17,7 +17,9 @@ from .common import (
 )
 from .models import (
     FLASHINFER_ATTN,
+    FLASHINFER_MLA_ATTN,
     TRITON_ATTN,
+    deepseek_v3_fp8,
     llama3_8b,
     llama3_8b_fp4,
     llama3_8b_fp8,
@@ -33,10 +35,12 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
-    # qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported
-    [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8],
+    # qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
+    [llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
 )
-@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
+)
 @pytest.mark.parametrize("n_layers", [4])
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
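One note on the decorators above: the attention-backend parametrize now covers FLASHINFER_MLA_ATTN as well, and, as with the other decorators, stacked parametrize calls multiply into a full cross product of model, backend, custom_ops, and graph-partition cases. Tiny illustration with placeholder values:

```python
import pytest

@pytest.mark.parametrize("model", ["a", "b"])
@pytest.mark.parametrize("backend", ["x", "y", "z"])
def test_cross_product(model, backend):
    pass  # collected 2 * 3 = 6 times
```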
@@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions(
 ):
     matches = matches_fn(n_layers)
 
-    if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
+    block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
+    if block_fp8 and "-quant_fp8" in custom_ops:
         # This is why config forces +quant_fp8 by default
         pytest.skip("native QuantFP8 matching not supported for group quant")
 