diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 2ea3c346f..0c1e1b5e0 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -396,8 +396,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
-    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_batch_threshold: int = 1
+    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
     def __init__(
         self,
@@ -422,6 +421,7 @@ class AiterFlashAttentionMetadataBuilder(
         # populated on first build() call.
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
 
         sliding_window_configs: set[tuple[int, int] | None] = set()
         layers = get_layers_from_vllm_config(self.vllm_config, Attention)
@@ -466,6 +466,7 @@ class AiterFlashAttentionMetadataBuilder(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> "AiterFlashAttentionMetadata":
+        assert self.reorder_batch_threshold is not None
         split_ret = split_decodes_prefills_and_extends(
             common_attn_metadata,
             decode_threshold=self.reorder_batch_threshold,
@@ -677,6 +678,53 @@ class AiterFlashAttentionMetadataBuilder(
         )
 
         return attn_metadata
 
+    def build_for_drafting(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        draft_index: int,
+    ) -> AiterFlashAttentionMetadata:
+        """
+        Build attention metadata for draft model without CPU-GPU sync.
+
+        During EAGLE drafting all requests are uniform decodes, so we can
+        skip split_decodes_prefills_and_extends() and avoid all .cpu() /
+        .item() calls that would otherwise break CUDA graph capture.
+        """
+        num_reqs = common_attn_metadata.num_reqs
+        num_tokens = common_attn_metadata.num_actual_tokens
+
+        decode_metadata = AiterFlashAttentionDecodeMetadata(
+            max_query_len=common_attn_metadata.max_query_len,
+            min_query_len=common_attn_metadata.max_query_len,  # uniform batch
+            max_seq_len=common_attn_metadata.max_seq_len,
+            query_start_loc=common_attn_metadata.query_start_loc,
+        )
+
+        return AiterFlashAttentionMetadata(
+            num_actual_tokens=num_tokens,
+            num_actual_kv_tokens=0,  # not used in unified_attention path
+            max_query_len=common_attn_metadata.max_query_len,
+            query_start_loc=common_attn_metadata.query_start_loc,
+            max_seq_len=common_attn_metadata.max_seq_len,
+            seq_lens=common_attn_metadata.seq_lens,
+            block_table=common_attn_metadata.block_table_tensor,
+            slot_mapping=common_attn_metadata.slot_mapping,
+            num_decodes=num_reqs,
+            num_decode_tokens=num_tokens,
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_extends=0,
+            num_extend_tokens=0,
+            decode_metadata=decode_metadata,
+            prefill_metadata=None,
+            extend_metadata=None,
+            use_cascade=False,
+            common_prefix_len=0,
+            total_tokens=self.total_tokens,
+            k_scale=self.scale,
+            v_scale=self.scale,
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return False
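
For context, a minimal sketch (not part of the patch) of how the drafting path is expected to use build_for_drafting(). Only attributes that the patch itself reads or sets are assumed here; the wrapper function draft_attention_metadata below is hypothetical and exists purely to illustrate the call shape and the invariants the new code guarantees.

# Hypothetical illustration of the drafting fast path added above.
# During EAGLE drafting every request is a uniform decode, so the builder
# can skip split_decodes_prefills_and_extends() and the .cpu()/.item()
# synchronization it implies, keeping the step CUDA-graph capturable.
def draft_attention_metadata(builder, common_attn_metadata, draft_index):
    attn_metadata = builder.build_for_drafting(
        common_attn_metadata=common_attn_metadata,
        draft_index=draft_index,
    )
    # Invariants guaranteed by the patch: the whole batch is treated as
    # decode-only, so no prefill/extend metadata is produced and the
    # min/max query lengths coincide.
    assert attn_metadata.num_prefills == 0 and attn_metadata.num_extends == 0
    assert (
        attn_metadata.decode_metadata.min_query_len
        == attn_metadata.decode_metadata.max_query_len
    )
    return attn_metadata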