[Attention] Refactor attention metadata builder interface (#20466)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-07-17 00:44:25 -04:00
committed by GitHub
parent 28a6d5423d
commit 76b494444f
18 changed files with 1441 additions and 772 deletions

View File

@@ -15,22 +15,20 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache
import vllm.envs as envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionType)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata,
PerLayerParameters,
get_kv_cache_layout,
get_per_layer_parameters,
infer_global_hyperparameters)
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters,
get_kv_cache_layout, get_per_layer_parameters,
infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
@@ -226,9 +224,9 @@ class FlashInferMetadata:
class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
block_table: BlockTable):
self.runner = runner
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
device: torch.device):
self.device = device
self._workspace_buffer = None
self._prefill_wrapper = None # Wrapper for prefill/append
self._decode_wrapper = None # Wrapper for decode
@@ -237,75 +235,22 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Global hyperparameters shared by all attention layers
self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = runner.vllm_config
self.vllm_config = vllm_config
self.cache_config = vllm_config.cache_config
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
def reorder_batch(self, input_batch: InputBatch,
scheduler_output: SchedulerOutput) -> bool:
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
return reorder_batch_to_split_decodes_and_prefills(input_batch,
scheduler_output,
decode_threshold=1)
def _get_workspace_buffer(self):
if self._workspace_buffer is None:
self._workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.runner.device)
device=self.device)
return self._workspace_buffer
def _get_prefill_wrapper(self):
@@ -316,10 +261,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
def _get_decode_wrapper(self):
if self._decode_wrapper is None:
num_qo_heads = (self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
num_qo_heads = (
self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config))
num_kv_heads = self.vllm_config.model_config.get_num_kv_heads(
self.vllm_config.parallel_config)
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -334,7 +280,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
2, self._get_workspace_buffer(), get_kv_cache_layout())
return self._cascade_wrapper
def _plan(self, attn_metadata: FlashInferMetadata):
def _plan(self, num_prefills: int, num_decodes: int,
attn_metadata: FlashInferMetadata):
if self.global_hyperparameters is None:
self.global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(self.vllm_config, FlashInferImpl))
@@ -369,16 +316,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
if self._num_prefills > 0:
if num_prefills > 0:
# Decodes are first so prefills start after the last decode
prefill_start = self._num_decodes
prefill_start = num_decodes
attn_metadata.prefill_wrapper = self._get_prefill_wrapper()
assert attn_metadata.qo_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
0] == num_prefills + 1
assert attn_metadata.paged_kv_indptr[prefill_start:].shape[
0] == self._num_prefills + 1
0] == num_prefills + 1
assert attn_metadata.paged_kv_last_page_len[
prefill_start:].shape[0] == self._num_prefills
prefill_start:].shape[0] == num_prefills
# Since prefill_wrapper.run() will be called with
# query[num_decode_tokens:] we need to adjust the qo_indptr
# to be relative to the start of the prefill queries.
@@ -402,17 +349,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_data_type=attn_metadata.kv_data_type,
)
if self._num_decodes > 0:
if num_decodes > 0:
attn_metadata.decode_wrapper = self._get_decode_wrapper()
if not FlashInferBackend.use_trtllm_decode_attention(
self._num_decodes, attn_metadata.max_seq_len,
num_decodes, attn_metadata.max_seq_len,
attn_metadata.kv_data_type, attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads, attn_metadata.head_dim):
attn_metadata.decode_wrapper.plan(
attn_metadata.paged_kv_indptr[:self._num_decodes + 1],
attn_metadata.paged_kv_indptr[:num_decodes + 1],
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_len[:self.
_num_decodes],
attn_metadata.paged_kv_last_page_len[:num_decodes],
attn_metadata.num_qo_heads,
attn_metadata.num_kv_heads,
attn_metadata.head_dim,
@@ -427,22 +373,20 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
kv_data_type=attn_metadata.kv_data_type,
)
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> FlashInferMetadata:
num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata)
assert self._num_decodes + self._num_prefills == num_reqs
assert (self._num_decode_tokens +
self._num_prefill_tokens == num_actual_tokens)
page_size = self.kv_cache_spec.block_size
device = self.runner.device
device = self.device
qo_indptr = common_attn_metadata.query_start_loc
max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
max_seq_len = common_attn_metadata.seq_lens_cpu.max()
seq_lens = common_attn_metadata.seq_lens
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
block_table_tensor = common_attn_metadata.block_table_tensor
block_table_bounds = (seq_lens + page_size - 1) // page_size
@@ -487,7 +431,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len = seq_lens % page_size
paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
page_size, paged_kv_last_page_len)
cache_dtype = self.runner.cache_config.cache_dtype
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
cache_dtype)
@@ -499,17 +443,18 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_indptr=paged_kv_indptr,
paged_kv_indices=paged_kv_indices,
paged_kv_last_page_len=paged_kv_last_page_len,
num_qo_heads=self.runner.num_query_heads,
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
self.vllm_config.parallel_config),
num_kv_heads=self.kv_cache_spec.num_kv_heads,
head_dim=self.kv_cache_spec.head_size,
page_size=page_size,
kv_data_type=kv_cache_dtype,
q_data_type=self.runner.dtype,
slot_mapping=slot_mapping,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
q_data_type=self.vllm_config.model_config.dtype,
slot_mapping=common_attn_metadata.slot_mapping,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
use_cascade=use_cascade,
shared_qo_indptr=shared_qo_indptr,
shared_kv_page_indptr=shared_kv_page_indptr,
@@ -521,12 +466,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
workspace_buffer=self._workspace_buffer,
)
self._plan(attn_metadata)
self._plan(num_prefills, num_decodes, attn_metadata)
return attn_metadata
def use_cascade_attention(self, *args, **kwargs) -> bool:
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype:
# TODO: The cascade wrapper currently does not support setting
# kv cache dtype to something different from query dtype.
return False