[ROCM] Optimize ROCM_AITER_FA spec decode eagle performance (#34541)

Signed-off-by: jennyyyyzhen <yzhen@hmc.edu>
This commit is contained in:
jennyyyyzhen
2026-02-20 20:32:05 -08:00
committed by GitHub
parent 54254f7a61
commit 2aab2bb543

View File

@@ -396,8 +396,7 @@ class AiterFlashAttentionMetadata:
class AiterFlashAttentionMetadataBuilder( class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder[AiterFlashAttentionMetadata] AttentionMetadataBuilder[AiterFlashAttentionMetadata]
): ):
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: int = 1
def __init__( def __init__(
self, self,
@@ -422,6 +421,7 @@ class AiterFlashAttentionMetadataBuilder(
# populated on first build() call. # populated on first build() call.
self.aot_sliding_window: tuple[int, int] | None = None self.aot_sliding_window: tuple[int, int] | None = None
self.total_tokens: int = 0 self.total_tokens: int = 0
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
sliding_window_configs: set[tuple[int, int] | None] = set() sliding_window_configs: set[tuple[int, int] | None] = set()
layers = get_layers_from_vllm_config(self.vllm_config, Attention) layers = get_layers_from_vllm_config(self.vllm_config, Attention)
@@ -466,6 +466,7 @@ class AiterFlashAttentionMetadataBuilder(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False, fast_build: bool = False,
) -> "AiterFlashAttentionMetadata": ) -> "AiterFlashAttentionMetadata":
assert self.reorder_batch_threshold is not None
split_ret = split_decodes_prefills_and_extends( split_ret = split_decodes_prefills_and_extends(
common_attn_metadata, common_attn_metadata,
decode_threshold=self.reorder_batch_threshold, decode_threshold=self.reorder_batch_threshold,
@@ -677,6 +678,53 @@ class AiterFlashAttentionMetadataBuilder(
) )
return attn_metadata return attn_metadata
def build_for_drafting(
self,
common_attn_metadata: CommonAttentionMetadata,
draft_index: int,
) -> AiterFlashAttentionMetadata:
"""
Build attention metadata for draft model without CPU-GPU sync.
During EAGLE drafting all requests are uniform decodes, so we can
skip split_decodes_prefills_and_extends() and avoid all .cpu() /
.item() calls that would otherwise break CUDA graph capture.
"""
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
decode_metadata = AiterFlashAttentionDecodeMetadata(
max_query_len=common_attn_metadata.max_query_len,
min_query_len=common_attn_metadata.max_query_len, # uniform batch
max_seq_len=common_attn_metadata.max_seq_len,
query_start_loc=common_attn_metadata.query_start_loc,
)
return AiterFlashAttentionMetadata(
num_actual_tokens=num_tokens,
num_actual_kv_tokens=0, # not used in unified_attention path
max_query_len=common_attn_metadata.max_query_len,
query_start_loc=common_attn_metadata.query_start_loc,
max_seq_len=common_attn_metadata.max_seq_len,
seq_lens=common_attn_metadata.seq_lens,
block_table=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_reqs,
num_decode_tokens=num_tokens,
num_prefills=0,
num_prefill_tokens=0,
num_extends=0,
num_extend_tokens=0,
decode_metadata=decode_metadata,
prefill_metadata=None,
extend_metadata=None,
use_cascade=False,
common_prefix_len=0,
total_tokens=self.total_tokens,
k_scale=self.scale,
v_scale=self.scale,
)
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
return False return False