[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-05-11 07:12:04 +08:00 (committed by GitHub)
Parent: 4c31218f80
Commit: 950751a987
11 changed files with 132 additions and 68 deletions


@@ -19,6 +19,8 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm.v1.worker.block_table import BlockTable
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
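
For orientation, here is a minimal sketch of the two objects the builder now receives, covering only the attributes this diff actually touches. AttentionSpecSketch and BlockTableSketch are illustrative stand-ins, not the real AttentionSpec and BlockTable definitions from vllm.v1.kv_cache_interface and vllm.v1.worker.block_table, which carry more fields.

from dataclasses import dataclass

import torch


@dataclass
class AttentionSpecSketch:
    block_size: int      # tokens per KV-cache block (FlashInfer page size)
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype   # KV-cache dtype; may differ from the model dtype


class BlockTableSketch:
    """Per-request block-table state; the builder reads these two members."""

    slot_mapping_cpu: torch.Tensor  # per-token slot indices kept on the CPU

    def get_device_tensor(self) -> torch.Tensor:
        """Return the (num_reqs, max_blocks) block-table tensor on the device."""
        ...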
@@ -202,7 +204,8 @@ class FlashInferMetadata:
 class FlashInferMetadataBuilder:
 
-    def __init__(self, runner: GPUModelRunner):
+    def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
+                 block_table: BlockTable):
         self.runner = runner
         self._workspace_buffer = None
         self._prefill_wrapper = None  # Wrapper for prefill/append
@@ -213,6 +216,8 @@ class FlashInferMetadataBuilder:
         self.global_hyperparameters: Optional[PerLayerParameters] = None
 
         self.vllm_config = get_current_vllm_config()
+        self.kv_cache_spec = kv_cache_spec
+        self.block_table = block_table
 
     def reorder_batch(self, input_batch: InputBatch,
                       scheduler_output: SchedulerOutput) -> bool:
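
With the new signature, the caller supplies the spec and block table explicitly instead of having the builder reach into the runner for them. A rough wiring sketch; the variable names (model_runner, attention_spec, group_block_table) are hypothetical:

builder = FlashInferMetadataBuilder(
    runner=model_runner,            # GPUModelRunner owned by the worker
    kv_cache_spec=attention_spec,   # AttentionSpec for this group of layers
    block_table=group_block_table,  # BlockTable tracked by the input batch
)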
@@ -400,13 +405,12 @@ class FlashInferMetadataBuilder:
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
-        page_size = self.runner.block_size
+        page_size = self.kv_cache_spec.block_size
         device = self.runner.device
         qo_indptr = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
-        block_table = (
-            self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
+        block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
+        slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
             self.runner.device, non_blocking=True).long()
 
         block_table_bounds = (seq_lens + page_size - 1) // page_size
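
The block_table_bounds line is a ceiling division: each request occupies ceil(seq_len / page_size) KV-cache blocks, where page_size now comes from the spec rather than the runner. A quick worked example with an assumed page size of 16:

import torch

# Ceil-division example: number of KV-cache blocks each request occupies.
page_size = 16                              # assumed block size, for illustration
seq_lens = torch.tensor([1, 16, 35, 100])   # tokens currently held per request

block_table_bounds = (seq_lens + page_size - 1) // page_size
print(block_table_bounds)  # tensor([1, 1, 3, 7]) == ceil(seq_len / 16)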
@@ -422,12 +426,13 @@ class FlashInferMetadataBuilder:
             shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
                                                  dtype=torch.int32,
                                                  device=device)
-            shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
+            shared_kv_page_indices = block_table_tensor[
+                0, :num_common_kv_blocks]
             shared_kv_last_page_len = torch.tensor([page_size],
                                                    dtype=torch.int32,
                                                    device=device)
             # Remove the blocks of the shared prefix from all requests.
-            block_table = block_table[:, num_common_kv_blocks:]
+            block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
             block_table_bounds -= num_common_kv_blocks
         else:
             shared_qo_indptr = None
@@ -435,11 +440,11 @@ class FlashInferMetadataBuilder:
             shared_kv_page_indices = None
             shared_kv_last_page_len = None
 
-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
-                             device=block_table.device).unsqueeze(0)
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
+                             device=block_table_tensor.device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]
 
         paged_kv_indptr = torch.cat([
             torch.zeros(1,
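
The mask-and-index step flattens the padded, rectangular block table into the ragged paged_kv_indices layout FlashInfer expects, keeping only the first block_table_bounds[i] entries of each row. A toy illustration with made-up block IDs:

import torch

# Toy data: 2 requests, block table padded to 4 columns with zeros.
block_table_tensor = torch.tensor([[3, 7, 0, 0],
                                   [5, 2, 9, 0]])
block_table_bounds = torch.tensor([2, 3])  # blocks actually in use per request

# Same mask construction as in the diff: column index < per-row bound.
mask = (torch.arange(block_table_tensor.size(1),
                     dtype=block_table_tensor.dtype,
                     device=block_table_tensor.device).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table_tensor[mask]
print(paged_kv_indices)  # tensor([3, 7, 5, 2, 9])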
@@ -459,10 +464,10 @@ class FlashInferMetadataBuilder:
             paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len=paged_kv_last_page_len,
             num_qo_heads=self.runner.num_query_heads,
-            num_kv_heads=self.runner.num_kv_heads,
-            head_dim=self.runner.head_size,
+            num_kv_heads=self.kv_cache_spec.num_kv_heads,
+            head_dim=self.kv_cache_spec.head_size,
             page_size=page_size,
-            data_type=self.runner.kv_cache_dtype,
+            data_type=self.kv_cache_spec.dtype,
             q_data_type=self.runner.dtype,
             slot_mapping=slot_mapping,
             num_decodes=self._num_decodes,
@@ -481,7 +486,7 @@ class FlashInferMetadataBuilder:
         return attn_metadata
 
     def use_cascade_attention(self, *args, **kwargs) -> bool:
-        if self.runner.kv_cache_dtype != self.runner.model_config.dtype:
+        if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
             # TODO: The cascade wrapper currently does not support setting
             # kv cache dtype to something different from query dtype.
             return False
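
The final hunk applies the same substitution to the cascade-attention guard: cascade is only considered when the KV-cache dtype matches the model (query) dtype, and that dtype is now read from the spec. A toy illustration of the guard with assumed dtypes:

import torch

# Toy illustration: an fp8-quantized KV cache with bf16 queries disables cascade.
kv_cache_dtype = torch.float8_e4m3fn  # assumed KV-cache dtype
model_dtype = torch.bfloat16          # assumed model/query dtype

if kv_cache_dtype != model_dtype:
    use_cascade = False               # mirrors the early `return False` above
else:
    use_cascade = True                # defer to the usual cascade heuristic
print(use_cascade)  # False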