[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -17,9 +17,12 @@ from .common import (
|
||||
)
|
||||
from .models import (
|
||||
FLASHINFER_ATTN,
|
||||
FLASHINFER_MLA_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
ROCM_ATTN,
|
||||
TRITON_ATTN,
|
||||
TRITON_MLA_ATTN,
|
||||
deepseek_v3_fp8,
|
||||
llama3_8b_fp4,
|
||||
llama3_8b_fp8,
|
||||
llama4_scout_fp4,
|
||||
@@ -33,6 +36,9 @@ from .models import (
|
||||
[
|
||||
(*llama3_8b_fp8, False),
|
||||
(*qwen3_a3b_fp8, False),
|
||||
(*qwen3_a3b_fp8, True),
|
||||
(*deepseek_v3_fp8, False),
|
||||
(*deepseek_v3_fp8, True),
|
||||
pytest.param(
|
||||
*llama4_scout_fp8,
|
||||
False,
|
||||
@@ -41,13 +47,6 @@ from .models import (
|
||||
reason="Llama4 Scout FP8 only supported on CUDA",
|
||||
),
|
||||
),
|
||||
pytest.param(
|
||||
*qwen3_a3b_fp8,
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="DeepGemm only supported on CUDA"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
@@ -57,6 +56,8 @@ from .models import (
|
||||
FLASHINFER_ATTN,
|
||||
ROCM_ATTN,
|
||||
ROCM_AITER_UNIFIED_ATTN,
|
||||
FLASHINFER_MLA_ATTN,
|
||||
TRITON_MLA_ATTN,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("n_layers", [6])
|
||||
@@ -75,6 +76,9 @@ def test_tp1_fp8_fusions(
|
||||
run_e2e_fusion_test,
|
||||
monkeypatch,
|
||||
):
|
||||
if use_deepgemm and not current_platform.is_cuda():
|
||||
pytest.skip("DeepGemm only supported on CUDA")
|
||||
|
||||
if use_deepgemm and is_flashinfer_fp8_blockscale_gemm_supported():
|
||||
# Flashinfer block FP8 GEMM has internal quantization, so it can't
|
||||
# be fused with other ops.
|
||||
@@ -86,7 +90,8 @@ def test_tp1_fp8_fusions(
|
||||
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
|
||||
block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
|
||||
if block_fp8 and "-quant_fp8" in custom_ops:
|
||||
# This is why config forces +quant_fp8 by default
|
||||
pytest.skip("native QuantFP8 matching not supported for group quant")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user