[torch.compile] Add support for non-contiguous fused RMSNorm + group quant (#36551)
Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
@@ -72,6 +72,16 @@ def run_e2e_fusion_test(monkeypatch, caplog_mp_spawn):
    rocm_aiter_ops.refresh_env_variables()

    # Filter here to reduce code duplication
    requires_mla = "deepseek" in model_name.lower()
    is_mla = "mla" in attn_backend.backend.name.lower()

    if requires_mla != is_mla:
        pytest.skip(
            f"Incompatible model '{model_name}' and "
            f"attention backend '{attn_backend.backend.name}'"
        )

    # Disable compile cache to make sure custom passes run.
    # Otherwise, we can't verify fusion happened through the logs.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
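For context on what the fused kernel is expected to compute, below is a minimal NumPy sketch of the *unfused* reference semantics: RMSNorm over the last dimension followed by symmetric per-group int8 quantization. This is an illustration only — the function names, the group size, and the int8 target are assumptions for the sketch, not the commit's actual kernel, which fuses these steps on the GPU and (per this change) also supports non-contiguous inputs.

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # RMSNorm: scale each row by the reciprocal of its root-mean-square,
    # then apply the learned per-channel weight.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def group_quant_int8(x: np.ndarray, group_size: int = 4):
    # Symmetric quantization with one fp32 scale per contiguous group of
    # `group_size` elements; values are rounded and clipped to int8 range.
    groups = x.reshape(-1, group_size)
    scales = np.max(np.abs(groups), axis=-1, keepdims=True) / 127.0
    scales = np.where(scales == 0, 1.0, scales)  # avoid div-by-zero for all-zero groups
    q = np.clip(np.round(groups / scales), -127, 127).astype(np.int8)
    return q.reshape(x.shape), scales

# Example: normalize then quantize a small 2x8 activation tensor.
x = np.arange(16, dtype=np.float32).reshape(2, 8) - 8.0
w = np.ones(8, dtype=np.float32)
q, s = group_quant_int8(rms_norm(x, w))
```

A fused kernel computes both steps in one pass, avoiding a round trip of the normalized fp activations through memory; the e2e test above verifies via logs that the fusion pass actually fired.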