[torch.compile][Fusion] Fix attention fusion pass removing kv_udpate op. (#33945)
Signed-off-by: charlifu <charlifu@amd.com>
(cherry picked from commit bb9f97308d)
This commit is contained in:
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
|
||||
PATTERN_TEST_MODELS_FP8 = [
|
||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||
]
|
||||
BACKENDS = [
|
||||
BACKENDS_FP8 = [
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
|
||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale before fusion"
|
||||
)
|
||||
|
||||
kv_cache_dummy_dep_pre_is_none = (
|
||||
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||
)
|
||||
kv_cache_dummy_dep_post_is_none = (
|
||||
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||
)
|
||||
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
|
||||
"The kv_cache_dummy_dep should be consistent before and after fusion"
|
||||
)
|
||||
|
||||
if quant_key.dtype == FP8_DTYPE:
|
||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||
"Attention should not have output_block_scale after FP8 fusion"
|
||||
|
||||
@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
v: torch.Tensor,
|
||||
output_attn: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# attn output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
layer_name=self.layer_name,
|
||||
output_scale=scale,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||
|
||||
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
||||
self.empty(5, self.num_heads, self.head_size), # v
|
||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||
empty_fp32(1, 1), # scale
|
||||
self.empty(0), # kv_cache_dummy_dep
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
at1 = auto_functionalized(
|
||||
ATTN_OP,
|
||||
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
layer_name=self.layer_name,
|
||||
output_scale=None,
|
||||
output_block_scale=None,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
attn_out_view = RESHAPE_OP(
|
||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
output_quant: torch.Tensor,
|
||||
output_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
kv_cache_dummy_dep: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# attention output in quant_dtype
|
||||
output_attn = torch.ops.aten.full.default(
|
||||
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
layer_name=self.layer_name,
|
||||
output_scale=input_scale,
|
||||
output_block_scale=output_scale_view,
|
||||
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||
)
|
||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||
return output, at2[2]
|
||||
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||
), # output_scale
|
||||
empty_fp32(1, 1), # input_scale
|
||||
self.empty(0), # kv_cache_dummy_dep
|
||||
]
|
||||
|
||||
pm.register_replacement(
|
||||
|
||||
Reference in New Issue
Block a user