[torch.compile][Fusion] Fix attention fusion pass removing kv_update op. (#33945)
Signed-off-by: charlifu <charlifu@amd.com>
(cherry picked from commit bb9f97308d)
This commit is contained in:
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
|
||||
PATTERN_TEST_MODELS_FP8 = [
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
BACKENDS = [
|
||||
BACKENDS_FP8 = [
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
)
|
||||
|
||||
kv_cache_dummy_dep_pre_is_none = (
|
||||
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||
)
|
||||
kv_cache_dummy_dep_post_is_none = (
|
||||
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||
)
|
||||
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
|
||||
"The kv_cache_dummy_dep should be consistent before and after fusion"
|
||||
)
|
||||
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
|
||||
Reference in New Issue
Block a user