[ROCM] Optimize ROCM_AITER_FA spec decode eagle performance (#34541)
Signed-off-by: jennyyyyzhen <yzhen@hmc.edu>
This commit is contained in:
@@ -396,8 +396,7 @@ class AiterFlashAttentionMetadata:
|
||||
class AiterFlashAttentionMetadataBuilder(
|
||||
AttentionMetadataBuilder[AiterFlashAttentionMetadata]
|
||||
):
|
||||
_cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
reorder_batch_threshold: int = 1
|
||||
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -422,6 +421,7 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
# populated on first build() call.
|
||||
self.aot_sliding_window: tuple[int, int] | None = None
|
||||
self.total_tokens: int = 0
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
sliding_window_configs: set[tuple[int, int] | None] = set()
|
||||
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
@@ -466,6 +466,7 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> "AiterFlashAttentionMetadata":
|
||||
assert self.reorder_batch_threshold is not None
|
||||
split_ret = split_decodes_prefills_and_extends(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold,
|
||||
@@ -677,6 +678,53 @@ class AiterFlashAttentionMetadataBuilder(
|
||||
)
|
||||
return attn_metadata
|
||||
|
||||
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> AiterFlashAttentionMetadata:
    """Build metadata for EAGLE draft-model attention without host sync.

    Every request in the drafting batch is a uniform decode, which lets us
    bypass split_decodes_prefills_and_extends() entirely; no .cpu() or
    .item() calls are made, so the path stays CUDA-graph capturable.
    """
    cm = common_attn_metadata
    total_requests = cm.num_reqs
    total_draft_tokens = cm.num_actual_tokens

    # The batch is uniform, so the minimum and maximum query lengths
    # coincide.
    decode_info = AiterFlashAttentionDecodeMetadata(
        max_query_len=cm.max_query_len,
        min_query_len=cm.max_query_len,
        max_seq_len=cm.max_seq_len,
        query_start_loc=cm.query_start_loc,
    )

    # Every request is a decode: prefill/extend counts are zero and their
    # metadata slots stay empty.
    return AiterFlashAttentionMetadata(
        num_actual_tokens=total_draft_tokens,
        # num_actual_kv_tokens is not consumed on the unified_attention
        # path, so a placeholder zero is sufficient here.
        num_actual_kv_tokens=0,
        max_query_len=cm.max_query_len,
        query_start_loc=cm.query_start_loc,
        max_seq_len=cm.max_seq_len,
        seq_lens=cm.seq_lens,
        block_table=cm.block_table_tensor,
        slot_mapping=cm.slot_mapping,
        num_decodes=total_requests,
        num_decode_tokens=total_draft_tokens,
        num_prefills=0,
        num_prefill_tokens=0,
        num_extends=0,
        num_extend_tokens=0,
        decode_metadata=decode_info,
        prefill_metadata=None,
        extend_metadata=None,
        use_cascade=False,
        common_prefix_len=0,
        total_tokens=self.total_tokens,
        k_scale=self.scale,
        v_scale=self.scale,
    )
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
    """Report that cascade attention is never used by this backend."""
    return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user