[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,9 @@ from .common import (
|
||||
)
|
||||
from .models import (
|
||||
FLASHINFER_ATTN,
|
||||
FLASHINFER_MLA_ATTN,
|
||||
TRITON_ATTN,
|
||||
deepseek_v3_fp8,
|
||||
llama3_8b,
|
||||
llama3_8b_fp4,
|
||||
llama3_8b_fp8,
|
||||
@@ -33,10 +35,12 @@ pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only tes
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"model_name, matches_fn, model_kwargs, hf_overrides",
|
||||
# qwen3-fp8 should still fuse AR+rms even though group quant is not yet supported
|
||||
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8],
|
||||
# qwen3 & dsv3 should still fuse AR+rms even though group quant is not yet supported
|
||||
[llama3_8b_fp8, llama4_scout_fp8, qwen3_a3b_fp8, deepseek_v3_fp8],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"attn_backend", [TRITON_ATTN, FLASHINFER_ATTN, FLASHINFER_MLA_ATTN]
|
||||
)
|
||||
@pytest.mark.parametrize("attn_backend", [TRITON_ATTN, FLASHINFER_ATTN])
|
||||
@pytest.mark.parametrize("n_layers", [4])
|
||||
@pytest.mark.parametrize("custom_ops", custom_ops_combos("quant_fp8", "rms_norm"))
|
||||
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
|
||||
@@ -54,7 +58,8 @@ def test_tp2_ar_rms_fp8_fusions(
|
||||
):
|
||||
matches = matches_fn(n_layers)
|
||||
|
||||
if "qwen" in model_name.lower() and "-quant_fp8" in custom_ops:
|
||||
block_fp8 = "qwen" in model_name.lower() or "deepseek" in model_name.lower()
|
||||
if block_fp8 and "-quant_fp8" in custom_ops:
|
||||
# This is why config forces +quant_fp8 by default
|
||||
pytest.skip("native QuantFP8 matching not supported for group quant")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user