[torch.compile][Fusion] Fix attention fusion pass removing kv_udpate op. (#33945)

Signed-off-by: charlifu <charlifu@amd.com>
(cherry picked from commit bb9f97308d)
This commit is contained in:
Charlie Fu
2026-02-09 15:15:43 -06:00
committed by khluu
parent 5bd63387c3
commit e1e9841631
2 changed files with 22 additions and 1 deletions

View File

@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
PATTERN_TEST_MODELS_FP8 = [
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
]
BACKENDS = [
BACKENDS_FP8 = [
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.TRITON_ATTN,
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale before fusion"
)
kv_cache_dummy_dep_pre_is_none = (
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
)
kv_cache_dummy_dep_post_is_none = (
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
)
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
"The kv_cache_dummy_dep should be consistent before and after fusion"
)
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
"Attention should not have output_block_scale after FP8 fusion"