[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
ElizaWszola authored 2026-02-18 08:35:04 +01:00, committed by GitHub
commit a88b3be7c4 (parent a49ea5a58f)
12 changed files with 234 additions and 75 deletions
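
For context on the fusion being fixed: the pass matches an RMSNorm op followed by a dynamic FP8 quantization op and rewrites the pair into a single fused kernel. Below is a minimal sketch of that unfused pattern, not vLLM's actual implementation; the helper names are illustrative stand-ins.

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # RMSNorm: scale each row by the reciprocal of its root-mean-square.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight

def quant_fp8_per_token(x: torch.Tensor):
    # Dynamic per-token FP8 quantization: one fp32 scale per row.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / finfo.max
    q = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q, scale

def unfused_pattern(x, w):
    # Two separate ops; the fusion pass rewrites this pair into one kernel.
    return quant_fp8_per_token(rms_norm_ref(x, w))

Per the TODO comment in the diff below, when the quantization step additionally emits TMA-aligned scales (as the DeepGEMM path does), extra layout ops appear between the two steps and the matcher no longer recognizes this shape; that mismatch is what this commit addresses.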


@@ -50,10 +50,9 @@ def test_tp1_fp8_fusions(
     run_e2e_fusion_test,
     monkeypatch,
 ):
-    if use_deepgemm:
-        # TODO(luka/eliza) DeepGEMM uses different quants, matching not supported
+    if use_deepgemm and is_blackwell():
+        # TODO(luka) DeepGEMM uses different quants, matching not supported
+        # - on Blackwell, uses a special quant fp8, currently not supported
+        # - on Hopper, tma-aligned scales inhibit matching (fix WIP)
         pytest.skip("DeepGEMM & quant matching not currently supported")
     matches = matches_fn(n_layers)
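
For context on the "tma-aligned scales" mentioned above: DeepGEMM-style FP8 GEMMs read per-group scales via TMA, which expects a column-major layout padded to a 16-byte boundary. The sketch below is a hypothetical illustration of such an alignment step, not DeepGEMM's actual helper; the function name and padding factor are assumptions.

import torch

def tma_align_scales(scales: torch.Tensor, alignment: int = 4) -> torch.Tensor:
    # Hypothetical sketch: pad per-token group scales (shape [M, G]) so that M
    # is a multiple of `alignment` (4 fp32 values = 16 bytes, a TMA-friendly
    # unit), stored column-major as DeepGEMM-style kernels expect.
    m, g = scales.shape
    padded_m = -(-m // alignment) * alignment    # round M up to the alignment
    aligned = scales.new_zeros(g, padded_m).t()  # column-major [padded_m, G]
    aligned[:m].copy_(scales)
    return aligned[:m]

The extra padding and transpose ops this inserts between quantization and the GEMM are the kind of graph noise that, per the comment above, inhibits pattern matching on Hopper.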
@@ -66,7 +65,6 @@ def test_tp1_fp8_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
         custom_ops=custom_ops.split(","),
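
For reference, a hedged usage sketch of how a compilation config like the one assembled above is passed to vLLM's entry point. The model name is a placeholder, and the example assumes LLM accepts a compilation_config dict, as recent vLLM versions do; the custom_ops entries mirror the test's comma-separated list.

from vllm import LLM

llm = LLM(
    model="dummy/model",  # placeholder model name
    compilation_config={
        "custom_ops": ["+rms_norm", "+quant_fp8"],  # mirrors custom_ops.split(",")
        "use_inductor_graph_partition": False,
    },
)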