[torch.compile][Fusion] Fix attention fusion pass removing kv_udpate op. (#33945)
Signed-off-by: charlifu <charlifu@amd.com>
(cherry picked from commit bb9f97308d)
This commit is contained in:
@@ -267,7 +267,7 @@ elif current_platform.is_rocm():
|
|||||||
PATTERN_TEST_MODELS_FP8 = [
|
PATTERN_TEST_MODELS_FP8 = [
|
||||||
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
|
||||||
]
|
]
|
||||||
BACKENDS = [
|
BACKENDS_FP8 = [
|
||||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
|
||||||
AttentionBackendEnum.ROCM_ATTN,
|
AttentionBackendEnum.ROCM_ATTN,
|
||||||
AttentionBackendEnum.TRITON_ATTN,
|
AttentionBackendEnum.TRITON_ATTN,
|
||||||
@@ -474,6 +474,17 @@ def test_attention_quant_pattern(
|
|||||||
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale before fusion"
|
"Attention should not have output_block_scale before fusion"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
kv_cache_dummy_dep_pre_is_none = (
|
||||||
|
attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||||
|
)
|
||||||
|
kv_cache_dummy_dep_post_is_none = (
|
||||||
|
attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None
|
||||||
|
)
|
||||||
|
assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), (
|
||||||
|
"The kv_cache_dummy_dep should be consistent before and after fusion"
|
||||||
|
)
|
||||||
|
|
||||||
if quant_key.dtype == FP8_DTYPE:
|
if quant_key.dtype == FP8_DTYPE:
|
||||||
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, (
|
||||||
"Attention should not have output_block_scale after FP8 fusion"
|
"Attention should not have output_block_scale after FP8 fusion"
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
output_attn: torch.Tensor,
|
output_attn: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
at1 = auto_functionalized(
|
at1 = auto_functionalized(
|
||||||
ATTN_OP,
|
ATTN_OP,
|
||||||
@@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=None,
|
output_scale=None,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
attn_out_view = RESHAPE_OP(
|
attn_out_view = RESHAPE_OP(
|
||||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||||
@@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
output_attn: torch.Tensor,
|
output_attn: torch.Tensor,
|
||||||
scale: torch.Tensor,
|
scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# attn output in quant_dtype
|
# attn output in quant_dtype
|
||||||
output_attn = torch.ops.aten.full.default(
|
output_attn = torch.ops.aten.full.default(
|
||||||
@@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=scale,
|
output_scale=scale,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
|
||||||
|
|
||||||
@@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
|
|||||||
self.empty(5, self.num_heads, self.head_size), # v
|
self.empty(5, self.num_heads, self.head_size), # v
|
||||||
self.empty(5, self.num_heads, self.head_size), # attn_output
|
self.empty(5, self.num_heads, self.head_size), # attn_output
|
||||||
empty_fp32(1, 1), # scale
|
empty_fp32(1, 1), # scale
|
||||||
|
self.empty(0), # kv_cache_dummy_dep
|
||||||
]
|
]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
@@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
output_quant: torch.Tensor,
|
output_quant: torch.Tensor,
|
||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
input_scale: torch.Tensor,
|
input_scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
at1 = auto_functionalized(
|
at1 = auto_functionalized(
|
||||||
ATTN_OP,
|
ATTN_OP,
|
||||||
@@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=None,
|
output_scale=None,
|
||||||
output_block_scale=None,
|
output_block_scale=None,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
attn_out_view = RESHAPE_OP(
|
attn_out_view = RESHAPE_OP(
|
||||||
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
at1[1], [q.shape[0], self.num_heads * self.head_size]
|
||||||
@@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
output_quant: torch.Tensor,
|
output_quant: torch.Tensor,
|
||||||
output_scale: torch.Tensor,
|
output_scale: torch.Tensor,
|
||||||
input_scale: torch.Tensor,
|
input_scale: torch.Tensor,
|
||||||
|
kv_cache_dummy_dep: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# attention output in quant_dtype
|
# attention output in quant_dtype
|
||||||
output_attn = torch.ops.aten.full.default(
|
output_attn = torch.ops.aten.full.default(
|
||||||
@@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
layer_name=self.layer_name,
|
layer_name=self.layer_name,
|
||||||
output_scale=input_scale,
|
output_scale=input_scale,
|
||||||
output_block_scale=output_scale_view,
|
output_block_scale=output_scale_view,
|
||||||
|
kv_cache_dummy_dep=kv_cache_dummy_dep,
|
||||||
)
|
)
|
||||||
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2])
|
||||||
return output, at2[2]
|
return output, at2[2]
|
||||||
@@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
|
|||||||
128, round_up(self.num_heads * self.head_size // 16, 4)
|
128, round_up(self.num_heads * self.head_size // 16, 4)
|
||||||
), # output_scale
|
), # output_scale
|
||||||
empty_fp32(1, 1), # input_scale
|
empty_fp32(1, 1), # input_scale
|
||||||
|
self.empty(0), # kv_cache_dummy_dep
|
||||||
]
|
]
|
||||||
|
|
||||||
pm.register_replacement(
|
pm.register_replacement(
|
||||||
|
|||||||
Reference in New Issue
Block a user