Add tree attention backend for v1 (part 1) (#20401)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
This commit is contained in:
Giancarlo Delfin
2025-08-03 22:13:26 -07:00
committed by GitHub
parent c2e75b3c11
commit aa7012eb6d
12 changed files with 1098 additions and 25 deletions

View File

@@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
return self.build(common_prefix_len=0,
common_attn_metadata=common_attn_metadata)
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """
    Produce the attention metadata used when running the draft model.

    By default this simply delegates to ``build`` with no common prefix
    and with ``fast_build`` enabled; ``draft_index`` is not consulted by
    this default implementation.

    Args:
        common_attn_metadata: The common attention metadata.
        draft_index: The index of the current draft operation.
            For chain speculation, this is the draft attempt for the
            i-th token. For tree-based attention, it is the draft
            attempt for the i-th level in the tree of tokens.

    Returns:
        Whatever ``build`` returns for the arguments described above.
    """
    # Drafting always uses the same fixed build configuration: no shared
    # prefix and the fast-build path requested.
    drafting_build_args = {
        "common_prefix_len": 0,
        "common_attn_metadata": common_attn_metadata,
        "fast_build": True,
    }
    return self.build(**drafting_build_args)
def use_cascade_attention(
self,
common_prefix_len: int,