[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson
2025-07-17 00:44:25 -04:00
committed by GitHub
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions


@@ -22,6 +22,7 @@ import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__)
_KV_CACHE_LAYOUT_OVERRIDE = None
@@ -32,14 +33,22 @@ class CommonAttentionMetadata:
"""
Per-batch attention metadata, shared across layers and backends.
AttentionMetadataBuilder instances use it to construct per-layer metadata.
For many of the tensors we keep both GPU and CPU versions.
"""
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
seq_lens_cpu: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
num_computed_tokens_cpu: torch.Tensor
"""(batch_size,), the number of computed tokens for each request"""
num_reqs: int
"""Number of requests"""
num_actual_tokens: int
@@ -47,6 +56,14 @@ class CommonAttentionMetadata:
max_query_len: int
"""Longest query in batch"""
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
def __post_init__(self):
# Fill unused entries with -1. Needed for reshape_and_cache in full cuda
# graph mode.
self.slot_mapping[self.num_actual_tokens:].fill_(-1)
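As a quick sanity check of the padding behaviour above, a minimal sketch with made-up sizes (8 padded slots, 6 scheduled tokens); this is not part of the commit:

import torch

num_actual_tokens = 6
slot_mapping = torch.arange(8)
slot_mapping[num_actual_tokens:].fill_(-1)
# slot_mapping is now tensor([0, 1, 2, 3, 4, 5, -1, -1]); per the comment
# above, the trailing -1 entries mark the padded slots for
# reshape_and_cache when running under full CUDA graphs.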
M = TypeVar("M")
@@ -56,11 +73,25 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
full_cudagraph_supported: ClassVar[bool] = False
@abstractmethod
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata) -> M:
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.kv_cache_spec = kv_cache_spec
@abstractmethod
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> M:
"""
Central method that builds attention metadata.
Some builders (MLA) require reorder_batch to be called prior to build.
Args:
common_prefix_len: The length of the common prefix of the batch.
common_attn_metadata: The common attention metadata.
fast_build: The metadata will prioritize speed of building over
speed of execution. Can be used for spec-decode where the
result of a build call may only be used for a few layers/iterations.
"""
raise NotImplementedError
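For illustration only, a hypothetical minimal builder under the new interface might look like the sketch below; DummyMetadata and DummyMetadataBuilder are invented names, and a real backend would derive its own backend-specific tensors in build():

from dataclasses import dataclass

@dataclass
class DummyMetadata:
    num_actual_tokens: int
    slot_mapping: torch.Tensor

class DummyMetadataBuilder(AttentionMetadataBuilder[DummyMetadata]):
    # Inherits __init__(kv_cache_spec, vllm_config, device) from the base.
    full_cudagraph_supported: ClassVar[bool] = False

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> DummyMetadata:
        # fast_build could select a cheaper-to-build (but slower-to-run)
        # code path, e.g. for speculative decoding.
        return DummyMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            slot_mapping=common_attn_metadata.slot_mapping)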
@@ -351,3 +382,108 @@ def make_local_attention_virtual_batches(
return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \
block_table_local
def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1,
) -> tuple[int, int, int, int]:
"""
Assuming a reordered batch, finds the boundary between prefill and decode
requests.
Args:
common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata.
decode_threshold: The maximum query length to be considered a decode.
Returns:
num_decodes: The number of decode requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
"""
max_query_len = common_attn_metadata.max_query_len
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold:
return num_reqs, 0, num_tokens, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1]
is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
first_prefill = is_prefill.int().argmax(dim=-1).item()
assert torch.all(query_lens[first_prefill:] > decode_threshold)
assert torch.all(query_lens[:first_prefill] <= decode_threshold)
num_decodes = first_prefill
num_prefills = num_reqs - num_decodes
num_decode_tokens = query_start_loc[first_prefill].item()
num_prefill_tokens = num_tokens - num_decode_tokens
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
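As a usage sketch (tensor values are made up; in practice query_start_loc_cpu and the other fields come from a CommonAttentionMetadata built by the model runner), the same boundary logic on a toy batch of 4 requests with query lengths [1, 1, 3, 5], already reordered so decodes come first:

import torch

query_start_loc = torch.tensor([0, 1, 2, 5, 10])
query_lens = query_start_loc[1:] - query_start_loc[:-1]
decode_threshold = 1

is_prefill = query_lens > decode_threshold
first_prefill = int(is_prefill.int().argmax())
num_decodes = first_prefill                                         # 2
num_prefills = query_lens.numel() - num_decodes                     # 2
num_decode_tokens = int(query_start_loc[first_prefill])             # 2
num_prefill_tokens = int(query_start_loc[-1]) - num_decode_tokens   # 8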
def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",
decode_threshold: int = 1,
) -> bool:
"""
Reorders the batch to split into prefill and decode requests; places all
requests with <= decode_threshold tokens at the front of the batch.
Returns:
True if the batch was modified, False otherwise.
"""
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the back using the least
# amount of swaps possible. (NOTE for now we loosely use "decode" to mean
# requests where attention is likely memory-bound and "prefill" to mean
# requests where attention is likely compute-bound, TODO(lucas): figure out
# a better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if it's not;
# we should update this to something like < 8 in the future, but
# currently TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens <= decode_threshold:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that the number of swaps is fairly minimal since decodes
# should be around for a number of iterations, so hopefully they are
# relatively stationary (and new requests are generally appended to the
# persistent batch, so they should already be at the back).
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back",
# i.e. past where the last decode should be in the reordered batch, with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# Take the i-th decode from the "back". If it already sits within the
# first num_decodes slots, it (and every earlier decode) is in place and
# we can stop; otherwise swap it with the prefill closest to the front
# of the batch.
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
return modified_batch
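To see the swap strategy in isolation, a sketch that applies the same loop to a plain list of per-request query lengths instead of an InputBatch; reorder_indices is an invented helper, and the list swap stands in for input_batch.swap_states:

def reorder_indices(query_lens: list[int], decode_threshold: int = 1) -> list[int]:
    order = list(range(len(query_lens)))
    decodes = [i for i, n in enumerate(query_lens) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > decode_threshold]
    num_decodes = len(decodes)
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            # This decode (and all earlier ones) already sit in the
            # first num_decodes slots.
            break
        # Swap the trailing decode with the leading unswapped prefill.
        j = prefills[i - 1]
        order[j], order[decode_idx] = order[decode_idx], order[j]
    return order

# reorder_indices([5, 1, 3, 1]) == [3, 1, 2, 0]: a single swap moves the
# two decode requests (original indices 1 and 3) to the front.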