[Bugfix] Fix quant RMS norm fusion for quantization with TMA-aligned scales (#33255)
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -50,10 +50,9 @@ def test_tp1_fp8_fusions(
     run_e2e_fusion_test,
     monkeypatch,
 ):
-    if use_deepgemm:
-        # TODO(luka/eliza) DeepGEMM uses different quants, matching not supported
+    if use_deepgemm and is_blackwell():
+        # TODO(luka) DeepGEMM uses different quants, matching not supported
         # - on Blackwell, uses a special quant fp8, currently not supported
-        # - on Hopper, tma-aligned scales inhibit matching (fix WIP)
         pytest.skip("DeepGEMM & quant matching not currently supported")

     matches = matches_fn(n_layers)
@@ -66,7 +65,6 @@ def test_tp1_fp8_fusions(
     model_kwargs["hf_overrides"] = hf_overrides(n_layers)
     model_kwargs["load_format"] = "dummy"
     model_kwargs["max_model_len"] = 1024
-
     compilation_config = dict(
         use_inductor_graph_partition=inductor_graph_partition,
         custom_ops=custom_ops.split(","),
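For context (not part of the diff): DeepGEMM's group-wise FP8 GEMM expects the per-token-group scale tensor in a column-major, TMA-aligned layout, with the token dimension padded up to a multiple of 4 float32 elements (16 bytes). The extra padding/layout ops this introduces appear to be what inhibited the quant RMS norm fusion match on Hopper, per the TODO removed above. Below is a minimal sketch of that layout in plain PyTorch; the helper name is hypothetical and this is not vLLM's or DeepGEMM's actual code.

import torch


def tma_align_scales(scales: torch.Tensor) -> torch.Tensor:
    """Sketch: pad an (M, K // group_size) float32 scale tensor so each
    column starts on a 16-byte boundary and is stored column-major."""
    m, g = scales.shape
    aligned_m = (m + 3) // 4 * 4              # round M up to a multiple of 4
    # Allocate (g, aligned_m) and transpose so the result is column-major:
    # each scale column is contiguous and begins at a 16-byte-aligned offset.
    padded = scales.new_zeros((g, aligned_m)).t()
    padded[:m].copy_(scales)
    return padded[:m]                         # same logical shape, aligned strides


scales = torch.rand(6, 2, dtype=torch.float32)  # e.g. 6 tokens, 2 groups of 128 cols
aligned = tma_align_scales(scales)
print(aligned.shape, aligned.stride())          # torch.Size([6, 2]) (1, 8)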