[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
This commit is contained in:
Luka Govedič
2026-03-11 13:56:55 -04:00
committed by GitHub
parent a1a3523a56
commit 9556af87d5
9 changed files with 219 additions and 87 deletions

View File

@@ -72,6 +72,16 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
rocm_aiter_ops.refresh_env_variables()
# Filter here to reduce code duplication
requires_mla = "deepseek" in model_name.lower()
is_mla = "mla" in attn_backend.backend.name.lower()
if requires_mla != is_mla:
pytest.skip(
f"Incompatible model '{model_name}' and "
f"attention backend '{attn_backend.backend.name}'"
)
# Disable compile cache to make sure custom passes run.
# Otherwise, we can't verify fusion happened through the logs.
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")