[mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Carl Y
2026-04-02 18:16:11 -07:00
committed by GitHub
parent ee3cf45739
commit 1f5ec2889c
12 changed files with 928 additions and 17 deletions

View File

@@ -18,8 +18,11 @@ from .common import (
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
FLASHMLA_SPARSE_ATTN,
TRITON_ATTN,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
deepseek_v32_fp4,
gpt_oss_20b,
llama3_8b,
llama3_8b_fp4,
@@ -37,7 +40,13 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
# qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
[
llama3_8b_fp8,
llama4_scout_fp8,
qwen3_a3b_fp8,
deepseek_coder_v2_lite_fp8,
deepseek_v3_fp8,
],
)
@pytest.mark.parametrize(
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
@@ -104,9 +113,12 @@ def test_tp2_ar_rms_fp8_fusions(
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_fp4, llama4_scout_fp4],
[llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
)
@pytest.mark.parametrize(
"attn_backend",
[FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
)
@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)