[mla] Support fused FP8/NVFP4 output quantization in MLA attention (#35792) (#36205)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Carl Y committed 2026-04-02 18:16:11 -07:00 (committed by GitHub)
parent ee3cf45739
commit 1f5ec2889c
12 changed files with 928 additions and 17 deletions
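For context on what "fused output quantization" means here: without the fusion, the attention kernel writes a high-precision (e.g. bf16) output tensor and a separate op then quantizes it to FP8 (or block-scaled NVFP4); with the fusion, the scale/clamp/cast happens in the attention epilogue, so the intermediate never round-trips through global memory. A minimal PyTorch sketch of the unfused FP8 step that fusion eliminates; this is illustrative only, not vLLM's kernel code, and the function name and per-tensor scaling scheme are assumptions:

```python
import torch

# Illustrative only: the standalone FP8 quantization pass that output-quant
# fusion folds into the attention kernel's epilogue.
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_fp8(attn_out: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Unfused path: attention already wrote attn_out in bf16; this extra
    # kernel rescales, clamps to the FP8 representable range, and casts.
    x = attn_out.float() / scale
    return x.clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

# Fused path (what this commit enables for MLA backends): the same
# scale/clamp/cast runs inside the attention kernel, so the bf16
# intermediate is never materialized.
```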


@@ -18,11 +18,14 @@ from .common import (
 from .models import (
     FLASHINFER_ATTN,
     FLASHINFER_MLA_ATTN,
+    FLASHMLA_SPARSE_ATTN,
     ROCM_AITER_UNIFIED_ATTN,
     ROCM_ATTN,
     TRITON_ATTN,
     TRITON_MLA_ATTN,
+    deepseek_coder_v2_lite_fp8,
     deepseek_v3_fp8,
+    deepseek_v32_fp4,
     llama3_8b_fp4,
     llama3_8b_fp8,
     llama4_scout_fp4,
@@ -37,6 +40,7 @@ from .models import (
     (*llama3_8b_fp8, False),
     (*qwen3_a3b_fp8, False),
     (*qwen3_a3b_fp8, True),
+    (*deepseek_coder_v2_lite_fp8, False),
     (*deepseek_v3_fp8, False),
     (*deepseek_v3_fp8, True),
     pytest.param(
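The parameter list above mixes bare tuples (unpacked with `*`) with `pytest.param(...)` entries, which is how individual rows carry marks or ids. A standalone illustration of the pattern, with placeholder names in place of the real model config tuples:

```python
import pytest

# Placeholder config tuple standing in for entries like deepseek_v3_fp8.
demo_model_fp8 = ("demo-model-fp8", {"quantization": "fp8"})

@pytest.mark.parametrize(
    "model_name, model_kwargs, enable_flag",
    [
        (*demo_model_fp8, False),  # bare tuple, unpacked into one test row
        pytest.param(
            *demo_model_fp8,
            True,
            marks=pytest.mark.skip(reason="example of a marked row"),
        ),
    ],
)
def test_demo(model_name, model_kwargs, enable_flag):
    assert isinstance(model_kwargs, dict)
```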
@@ -144,9 +148,12 @@ def test_tp1_fp8_fusions(
 @pytest.mark.parametrize(
     "model_name, matches_fn, model_kwargs, hf_overrides",
-    [llama3_8b_fp4, llama4_scout_fp4],
+    [llama3_8b_fp4, llama4_scout_fp4, deepseek_v32_fp4],
 )
-@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
+@pytest.mark.parametrize(
+    "attn_backend",
+    [FLASHINFER_ATTN, FLASHMLA_SPARSE_ATTN],
+)
 @pytest.mark.parametrize("n_layers", [6])
 @pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
 @pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)