diff --git a/tests/compile/passes/test_fusion_attn.py b/tests/compile/passes/test_fusion_attn.py index 75d5c42f0..2b29cf605 100644 --- a/tests/compile/passes/test_fusion_attn.py +++ b/tests/compile/passes/test_fusion_attn.py @@ -267,7 +267,7 @@ elif current_platform.is_rocm(): PATTERN_TEST_MODELS_FP8 = [ ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel) ] - BACKENDS = [ + BACKENDS_FP8 = [ AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, AttentionBackendEnum.ROCM_ATTN, AttentionBackendEnum.TRITON_ATTN, @@ -474,6 +474,17 @@ def test_attention_quant_pattern( assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale before fusion" ) + + kv_cache_dummy_dep_pre_is_none = ( + attn_nodes_pre[0].kwargs.get("kv_cache_dummy_dep") is None + ) + kv_cache_dummy_dep_post_is_none = ( + attn_nodes_post[0].kwargs.get("kv_cache_dummy_dep") is None + ) + assert not (kv_cache_dummy_dep_pre_is_none ^ kv_cache_dummy_dep_post_is_none), ( + "The kv_cache_dummy_dep should be consistent before and after fusion" + ) + if quant_key.dtype == FP8_DTYPE: assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, ( "Attention should not have output_block_scale after FP8 fusion" diff --git a/vllm/compilation/passes/fusion/attn_quant_fusion.py b/vllm/compilation/passes/fusion/attn_quant_fusion.py index a104aab6c..bb064f58c 100644 --- a/vllm/compilation/passes/fusion/attn_quant_fusion.py +++ b/vllm/compilation/passes/fusion/attn_quant_fusion.py @@ -142,6 +142,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): v: torch.Tensor, output_attn: torch.Tensor, scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, ) -> torch.Tensor: at1 = auto_functionalized( ATTN_OP, @@ -152,6 +153,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): layer_name=self.layer_name, output_scale=None, output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] @@ -165,6 +167,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): v: torch.Tensor, output_attn: torch.Tensor, scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, ) -> torch.Tensor: # attn output in quant_dtype output_attn = torch.ops.aten.full.default( @@ -182,6 +185,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): layer_name=self.layer_name, output_scale=scale, output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) @@ -191,6 +195,7 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): self.empty(5, self.num_heads, self.head_size), # v self.empty(5, self.num_heads, self.head_size), # attn_output empty_fp32(1, 1), # scale + self.empty(0), # kv_cache_dummy_dep ] pm.register_replacement( @@ -228,6 +233,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): output_quant: torch.Tensor, output_scale: torch.Tensor, input_scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: at1 = auto_functionalized( ATTN_OP, @@ -238,6 +244,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): layer_name=self.layer_name, output_scale=None, output_block_scale=None, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) attn_out_view = RESHAPE_OP( at1[1], [q.shape[0], self.num_heads * self.head_size] @@ -261,6 +268,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): output_quant: torch.Tensor, output_scale: torch.Tensor, input_scale: torch.Tensor, + kv_cache_dummy_dep: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: # attention output in quant_dtype output_attn = torch.ops.aten.full.default( @@ -280,6 +288,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): layer_name=self.layer_name, output_scale=input_scale, output_block_scale=output_scale_view, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) output = RESHAPE_OP(at2[1], [-1, self.num_heads * self.head_size // 2]) return output, at2[2] @@ -294,6 +303,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): 128, round_up(self.num_heads * self.head_size // 16, 4) ), # output_scale empty_fp32(1, 1), # input_scale + self.empty(0), # kv_cache_dummy_dep ] pm.register_replacement(