[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2026-03-11 13:56:55 -04:00
committed by GitHub
parent a1a3523a56
commit 9556af87d5
9 changed files with 219 additions and 87 deletions

View File

@@ -17,9 +17,12 @@ from .common import (
)
from .models import (
FLASHINFER_ATTN,
FLASHINFER_MLA_ATTN,
ROCM_AITER_UNIFIED_ATTN,
ROCM_ATTN,
TRITON_ATTN,
TRITON_MLA_ATTN,
deepseek_v3_fp8,
llama3_8b_fp4,
llama3_8b_fp8,
llama4_scout_fp4,
@@ -33,6 +36,9 @@ from .models import (
[
(*llama3_8b_fp8, False),
(*qwen3_a3b_fp8, False),
(*qwen3_a3b_fp8, True),
(*deepseek_v3_fp8, False),
(*deepseek_v3_fp8, True),
pytest.param(
*llama4_scout_fp8,
False,
@@ -41,13 +47,6 @@ from .models import (
reason="Llama4 Scout FP8 only supported on CUDA",
),
),
pytest.param(
*qwen3_a3b_fp8,
True,
marks=pytest.mark.skipif(
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
),
),
],
)
@pytest.mark.parametrize(
@@ -57,6 +56,8 @@ from .models import (
FLASHINFER_ATTN,
ROCM_ATTN,
ROCM_AITER_UNIFIED_ATTN,
FLASHINFER_MLA_ATTN,
TRITON_MLA_ATTN,
],
)
@pytest.mark.parametrize("n_layers", [6])
@@ -75,6 +76,9 @@ def test_tp1_fp8_fusions(
run_e2e_fusion_test,
monkeypatch,
):
if use_deepgemm and not current_platform.is_cuda():
pytest.skip("DeepGemm only supported on CUDA")
if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
# Flashinfer block FP8 GEMM has internal quantization, so it can't
# be fused with other ops.
@@ -86,7 +90,8 @@ def test_tp1_fp8_fusions(
matches = matches_fn(n_layers)
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
if block_fp8 and "-quant_fp8" in custom_ops:
# This is why config forces +quant_fp8 by default
pytest.skip("native QuantFP8 matching not supported for group quant")