Add tree attention backend for v1 (part 1) (#20401)
Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
This commit is contained in:
@@ -214,6 +214,26 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
||||
return self.build(common_prefix_len=0,
|
||||
common_attn_metadata=common_attn_metadata)
|
||||
|
||||
def build_for_drafting(
    self,
    common_attn_metadata: CommonAttentionMetadata,
    draft_index: int,
) -> M:
    """
    Construct the attention metadata used while running the draft model.

    This default implementation simply delegates to ``build`` with a
    common prefix length of 0 and ``fast_build=True``.

    Args:
        common_attn_metadata: The common attention metadata, forwarded
            to ``build`` unchanged.
        draft_index: The index of the current draft operation.
            For chain speculation this is the draft attempt for the
            i-th token; for tree-based attention it is the draft
            attempt for the i-th level of the token tree.
            This default implementation does not read it.

    Returns:
        Whatever ``build`` returns for the arguments above.
    """
    # NOTE: draft_index is intentionally not forwarded; it is presumably
    # meant for overriding builders (e.g. tree attention) -- confirm.
    drafting_metadata = self.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
        fast_build=True,
    )
    return drafting_metadata
|
||||
|
||||
def use_cascade_attention(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
|
||||
Reference in New Issue
Block a user