[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson authored 2025-07-17 00:44:25 -04:00 (committed by GitHub)
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions


@@ -7,15 +7,15 @@ from typing import TYPE_CHECKING, Optional
 import torch
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
-from vllm.v1.kv_cache_interface import MambaSpec
-from vllm.v1.worker.block_table import BlockTable
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (
+    AttentionMetadataBuilder, CommonAttentionMetadata,
+    reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
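The import and constructor changes above capture the core of the refactor: metadata builders are now constructed from the KV-cache spec, the VllmConfig, and a device rather than from the GPUModelRunner, and build() gains a fast_build flag. A minimal sketch of the new shape as implied by this diff (the base-class internals in vllm.v1.attention.backends.utils are not shown here, so treat this as illustrative rather than the exact implementation):

# Illustrative sketch of the refactored builder interface implied by this
# diff; anything not visible in the diff itself is an assumption.
import torch
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import AttentionSpec


class ExampleMetadataBuilder:

    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                 device: torch.device):
        # No runner or BlockTable is captured; the builder is self-contained.
        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config
        self.device = device

    def build(self, common_prefix_len, common_attn_metadata,
              fast_build: bool = False):
        # Per-step state (query_start_loc, seq_lens, block_table_tensor,
        # num_computed_tokens_cpu, ...) now arrives via common_attn_metadata.
        raise NotImplementedError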
@@ -87,80 +87,24 @@ class Mamba2AttentionMetadata:
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):

-    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
-                 block_table: BlockTable):
-        self.runner = runner
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
-        self.block_table = block_table
-        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
-        )
+        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
-        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
-        # FlashInferMetadataBuilder. Should be refactored later to avoid code
-        # duplication of these 3 functions.
-        # We now want to reorder the batch so that the "decode" requests are
-        # at the front and the "prefill" requests are at the back, using the
-        # least amount of swaps possible. (NOTE: for now we loosely use
-        # "decode" to mean requests where attention is likely memory-bound
-        # and "prefill" to mean requests where attention is likely
-        # compute-bound, TODO(lucas): figure out a better naming here)
-        decodes = []
-        prefills = []
-        num_decode_tokens = 0
-        num_prefill_tokens = 0
+        return reorder_batch_to_split_decodes_and_prefills(input_batch,
+                                                           scheduler_output,
+                                                           decode_threshold=1)
-        for i, req_id in enumerate(input_batch.req_ids):
-            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
-            # for now treat 1 scheduled token as "decode" even if it's not,
-            # we should update this to something like < 8 in the future but
-            # currently the decode run only supports num_tokens = 1
-            if num_tokens == 1:
-                decodes.append(i)
-                num_decode_tokens += num_tokens
-            else:
-                prefills.append(i)
-                num_prefill_tokens += num_tokens
-
-        # We hope that this is fairly minimal since decodes
-        # should be around for a number of iterations so hopefully they are
-        # relatively stationary (and new requests are generally appended to
-        # the persistent batch so they should already be at the back).
-        # To achieve this we loop over the decodes in descending order and
-        # the prefills in ascending order. We swap decodes from the "back",
-        # i.e. past where the last decode should be in the reordered batch,
-        # with prefills from the front of the batch.
-        # `decodes` and `prefills` are already in ascending order just based
-        # on the above loop.
-        num_decodes = len(decodes)
-        num_prefills = len(prefills)
-        modified_batch = False
-
-        for i in range(1, min(num_decodes, num_prefills) + 1):
-            # If the decode is at the "back" of the batch, we can swap it
-            # with the prefill closest to the front of the batch
-            decode_idx = decodes[num_decodes - i]
-            if decode_idx < num_decodes:
-                break
-            input_batch.swap_states(prefills[i - 1], decode_idx)
-            modified_batch = True
-
-        # Save for next `build` call
-        # TODO(lucas): this is a bit of a hack, we should probably have a
-        # better way of doing this
-        self._num_decodes = num_decodes
-        self._num_prefills = num_prefills
-        self._num_decode_tokens = num_decode_tokens
-        self._num_prefill_tokens = num_prefill_tokens
-        return modified_batch
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> Mamba2AttentionMetadata:
         num_reqs = common_attn_metadata.num_reqs
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
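The hunk above deletes the per-backend reordering loop in favor of the shared reorder_batch_to_split_decodes_and_prefills helper. As a rough sketch of what that helper has to accomplish, reconstructed from the removed code only (the real helper in vllm.v1.attention.backends.utils may differ in detail, hence the _sketch suffix):

# Sketch reconstructed from the deleted reorder_batch body; not the actual
# vLLM helper implementation.
def reorder_batch_to_split_decodes_and_prefills_sketch(
        input_batch, scheduler_output, decode_threshold: int = 1) -> bool:
    decodes, prefills = [], []
    for i, req_id in enumerate(input_batch.req_ids):
        num_tokens = scheduler_output.num_scheduled_tokens[req_id]
        if num_tokens <= decode_threshold:
            decodes.append(i)
        else:
            prefills.append(i)

    # Swap trailing decodes with leading prefills until every decode sits in
    # the first len(decodes) slots of the batch, minimizing swaps.
    num_decodes = len(decodes)
    modified_batch = False
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            break
        input_batch.swap_states(prefills[i - 1], decode_idx)
        modified_batch = True
    return modified_batch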
@@ -172,29 +116,31 @@ class Mamba2AttentionMetadataBuilder(
         has_initial_states = None
         prep_initial_states = False

-        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
+        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))

         # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
-        if self._num_prefills > 0:
+        if num_prefills > 0:
             #[batch,]
             has_initial_states_cpu = (
-                self.runner.input_batch.
-                num_computed_tokens_cpu_tensor[num_reqs -
-                                               self._num_prefills:num_reqs]
-                > 0)
+                common_attn_metadata.
+                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
             has_initial_states = has_initial_states_cpu.to(
                 query_start_loc.device)

             query_start_loc_p = common_attn_metadata.query_start_loc[
-                -self._num_prefills - 1:] - self._num_decode_tokens
+                -num_prefills - 1:] - num_decode_tokens

-            seq_idx = torch.repeat_interleave(
-                torch.arange(self._num_prefills,
-                             dtype=torch.int32,
-                             device=query_start_loc_p.device),
-                query_start_loc_p.diff(),
-                output_size=self._num_prefill_tokens)
+            seq_idx = torch.repeat_interleave(torch.arange(
+                num_prefills,
+                dtype=torch.int32,
+                device=query_start_loc_p.device),
+                query_start_loc_p.diff(),
+                output_size=num_prefill_tokens)
             seq_idx.unsqueeze_(0)

             # We compute metadata for chunked prefill once at the top level
@@ -204,13 +150,13 @@ class Mamba2AttentionMetadataBuilder(
                 chunk_indices, chunk_offsets = (
                     _query_start_loc_to_chunk_indices_offsets(
                         query_start_loc_p, self.chunk_size,
-                        self._num_prefill_tokens))
+                        num_prefill_tokens))

         attn_metadata = Mamba2AttentionMetadata(
-            num_prefills=self._num_prefills,
-            num_prefill_tokens=self._num_prefill_tokens,
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
             has_initial_states=has_initial_states,
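For reference, the split_decodes_and_prefills call introduced in the build() hunk returns the four counts that reorder_batch previously stashed on the builder. A hedged sketch of how such counts can be derived from CommonAttentionMetadata, assuming the batch was already reordered so decodes come first (the actual helper lives in vllm.v1.attention.backends.utils and may be implemented differently):

# Illustrative only: derives (num_decodes, num_prefills, num_decode_tokens,
# num_prefill_tokens) from an already-reordered batch; not the vLLM helper.
def split_decodes_and_prefills_sketch(common_attn_metadata,
                                      decode_threshold: int = 1):
    query_start_loc = common_attn_metadata.query_start_loc
    num_reqs = common_attn_metadata.num_reqs

    # Tokens scheduled for each request in this step.
    tokens_per_req = (query_start_loc[1:num_reqs + 1] -
                      query_start_loc[:num_reqs])

    # With decodes reordered to the front, the decode section is the leading
    # run of requests whose scheduled token count is at or below the threshold.
    num_decodes = int((tokens_per_req <= decode_threshold).sum().item())
    num_prefills = num_reqs - num_decodes
    num_decode_tokens = int(tokens_per_req[:num_decodes].sum().item())
    num_prefill_tokens = int(tokens_per_req[num_decodes:].sum().item())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens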