[ROCM] Optimize ROCM_AITER_FA spec decode eagle performance (#34541)
Signed-off-by: jennyyyyzhen <yzhen@hmc.edu>
@@ -396,8 +396,7 @@ class AiterFlashAttentionMetadata:
 class AiterFlashAttentionMetadataBuilder(
     AttentionMetadataBuilder[AiterFlashAttentionMetadata]
 ):
-    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-    reorder_batch_threshold: int = 1
+    _cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
 
     def __init__(
         self,
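Note on the first hunk: promoting the CUDA graph support level from UNIFORM_SINGLE_TOKEN_DECODE to UNIFORM_BATCH lets graph capture cover decode steps in which every request carries the same query length greater than one, which is exactly what EAGLE speculative decoding produces (one target token plus k draft tokens per request). A minimal sketch of what uniformity buys, with illustrative names rather than vLLM internals: a fixed per-request query length makes query_start_loc a constant-stride ramp that can be built entirely on the GPU.

    import torch

    num_reqs, q_len = 8, 3  # e.g. k = 2 draft tokens -> 3 query tokens per request
    query_start_loc = torch.arange(
        0, (num_reqs + 1) * q_len, q_len, device="cuda", dtype=torch.int32
    )
    # tensor([0, 3, 6, 9, 12, 15, 18, 21, 24]) -- constant stride of q_len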
@@ -422,6 +421,7 @@ class AiterFlashAttentionMetadataBuilder(
         # populated on first build() call.
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
+        self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
 
         sliding_window_configs: set[tuple[int, int] | None] = set()
         layers = get_layers_from_vllm_config(self.vllm_config, Attention)
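The second hunk replaces the hard-coded class attribute reorder_batch_threshold: int = 1 with a call to the base-class helper, passing supports_spec_as_decode=True. A hedged sketch of the intent (the real logic lives in vLLM's AttentionMetadataBuilder and may differ in detail): when speculative tokens can be treated as decode tokens, the threshold grows so that an EAGLE step of 1 + num_speculative_tokens query tokens still takes the decode path.

    def init_reorder_batch_threshold(
        base_threshold: int,
        supports_spec_as_decode: bool,
        num_speculative_tokens: int | None,
    ) -> int:
        # Hypothetical stand-in for the base-class helper, for illustration only.
        if supports_spec_as_decode and num_speculative_tokens is not None:
            # Treat "1 target + k draft tokens" as a decode-sized request.
            return base_threshold + num_speculative_tokens
        return base_threshold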
@@ -466,6 +466,7 @@ class AiterFlashAttentionMetadataBuilder(
         common_attn_metadata: CommonAttentionMetadata,
         fast_build: bool = False,
     ) -> "AiterFlashAttentionMetadata":
+        assert self.reorder_batch_threshold is not None
         split_ret = split_decodes_prefills_and_extends(
             common_attn_metadata,
             decode_threshold=self.reorder_batch_threshold,
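The new assert narrows reorder_batch_threshold from Optional[int] to int before it is handed to split_decodes_prefills_and_extends() as decode_threshold, since the threshold is now assigned at construction time by the helper above rather than declared as a class attribute. For orientation, a sketch of the classification rule such a split applies (illustrative, assuming the batch is already reordered so decode requests come first):

    import torch

    def count_decodes(query_start_loc: torch.Tensor, decode_threshold: int) -> int:
        # A request counts as a decode when its query length fits the threshold.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        # Note: this host-side count forces a device->host sync, which is
        # exactly what build_for_drafting() below is designed to avoid.
        return int((query_lens <= decode_threshold).sum())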
@@ -677,6 +678,53 @@ class AiterFlashAttentionMetadataBuilder(
         )
         return attn_metadata
 
+    def build_for_drafting(
+        self,
+        common_attn_metadata: CommonAttentionMetadata,
+        draft_index: int,
+    ) -> AiterFlashAttentionMetadata:
+        """
+        Build attention metadata for draft model without CPU-GPU sync.
+
+        During EAGLE drafting all requests are uniform decodes, so we can
+        skip split_decodes_prefills_and_extends() and avoid all .cpu() /
+        .item() calls that would otherwise break CUDA graph capture.
+        """
+        num_reqs = common_attn_metadata.num_reqs
+        num_tokens = common_attn_metadata.num_actual_tokens
+
+        decode_metadata = AiterFlashAttentionDecodeMetadata(
+            max_query_len=common_attn_metadata.max_query_len,
+            min_query_len=common_attn_metadata.max_query_len,  # uniform batch
+            max_seq_len=common_attn_metadata.max_seq_len,
+            query_start_loc=common_attn_metadata.query_start_loc,
+        )
+
+        return AiterFlashAttentionMetadata(
+            num_actual_tokens=num_tokens,
+            num_actual_kv_tokens=0,  # not used in unified_attention path
+            max_query_len=common_attn_metadata.max_query_len,
+            query_start_loc=common_attn_metadata.query_start_loc,
+            max_seq_len=common_attn_metadata.max_seq_len,
+            seq_lens=common_attn_metadata.seq_lens,
+            block_table=common_attn_metadata.block_table_tensor,
+            slot_mapping=common_attn_metadata.slot_mapping,
+            num_decodes=num_reqs,
+            num_decode_tokens=num_tokens,
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_extends=0,
+            num_extend_tokens=0,
+            decode_metadata=decode_metadata,
+            prefill_metadata=None,
+            extend_metadata=None,
+            use_cascade=False,
+            common_prefix_len=0,
+            total_tokens=self.total_tokens,
+            k_scale=self.scale,
+            v_scale=self.scale,
+        )
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return False
 
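The heart of the change is build_for_drafting(): every field is taken directly from CommonAttentionMetadata tensors or already-known host scalars, so no device-to-host transfer occurs while the drafting loop is captured into a CUDA graph (ROCm reuses the torch.cuda graph API on AMD GPUs). A minimal sketch of the failure mode being avoided, illustrative and requiring a CUDA/ROCm device to run:

    import torch

    seq_lens = torch.full((8,), 128, device="cuda", dtype=torch.int32)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        advanced = seq_lens + 1                 # pure device work: capturable
        # max_len = int(seq_lens.max().item())  # device->host sync: raises an
        #                                       # error during stream capture

This is also why min_query_len is set equal to max_query_len in the decode metadata: with a uniform batch the two coincide, so neither needs to be recomputed from per-request query lengths on the host.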