[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -19,6 +19,8 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
@@ -202,7 +204,8 @@ class FlashInferMetadata:
|
||||
|
||||
class FlashInferMetadataBuilder:
|
||||
|
||||
def __init__(self, runner: GPUModelRunner):
|
||||
def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec,
|
||||
block_table: BlockTable):
|
||||
self.runner = runner
|
||||
self._workspace_buffer = None
|
||||
self._prefill_wrapper = None # Wrapper for prefill/append
|
||||
@@ -213,6 +216,8 @@ class FlashInferMetadataBuilder:
|
||||
self.global_hyperparameters: Optional[PerLayerParameters] = None
|
||||
|
||||
self.vllm_config = get_current_vllm_config()
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_table = block_table
|
||||
|
||||
def reorder_batch(self, input_batch: InputBatch,
|
||||
scheduler_output: SchedulerOutput) -> bool:
|
||||
@@ -400,13 +405,12 @@ class FlashInferMetadataBuilder:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
assert (self._num_decode_tokens +
|
||||
self._num_prefill_tokens == num_actual_tokens)
|
||||
page_size = self.runner.block_size
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
device = self.runner.device
|
||||
qo_indptr = common_attn_metadata.query_start_loc
|
||||
seq_lens = common_attn_metadata.seq_lens
|
||||
block_table = (
|
||||
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
block_table_tensor = self.block_table.get_device_tensor()[:num_reqs]
|
||||
slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
self.runner.device, non_blocking=True).long()
|
||||
|
||||
block_table_bounds = (seq_lens + page_size - 1) // page_size
|
||||
@@ -422,12 +426,13 @@ class FlashInferMetadataBuilder:
|
||||
shared_kv_page_indptr = torch.tensor([0, num_common_kv_blocks],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
shared_kv_page_indices = block_table[0, :num_common_kv_blocks]
|
||||
shared_kv_page_indices = block_table_tensor[
|
||||
0, :num_common_kv_blocks]
|
||||
shared_kv_last_page_len = torch.tensor([page_size],
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
# Remove the blocks of the shared prefix from all requests.
|
||||
block_table = block_table[:, num_common_kv_blocks:]
|
||||
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:]
|
||||
block_table_bounds -= num_common_kv_blocks
|
||||
else:
|
||||
shared_qo_indptr = None
|
||||
@@ -435,11 +440,11 @@ class FlashInferMetadataBuilder:
|
||||
shared_kv_page_indices = None
|
||||
shared_kv_last_page_len = None
|
||||
|
||||
mask = (torch.arange(block_table.size(1),
|
||||
dtype=block_table.dtype,
|
||||
device=block_table.device).unsqueeze(0)
|
||||
mask = (torch.arange(block_table_tensor.size(1),
|
||||
dtype=block_table_tensor.dtype,
|
||||
device=block_table_tensor.device).unsqueeze(0)
|
||||
< block_table_bounds.unsqueeze(1))
|
||||
paged_kv_indices = block_table[mask]
|
||||
paged_kv_indices = block_table_tensor[mask]
|
||||
|
||||
paged_kv_indptr = torch.cat([
|
||||
torch.zeros(1,
|
||||
@@ -459,10 +464,10 @@ class FlashInferMetadataBuilder:
|
||||
paged_kv_indices=paged_kv_indices,
|
||||
paged_kv_last_page_len=paged_kv_last_page_len,
|
||||
num_qo_heads=self.runner.num_query_heads,
|
||||
num_kv_heads=self.runner.num_kv_heads,
|
||||
head_dim=self.runner.head_size,
|
||||
num_kv_heads=self.kv_cache_spec.num_kv_heads,
|
||||
head_dim=self.kv_cache_spec.head_size,
|
||||
page_size=page_size,
|
||||
data_type=self.runner.kv_cache_dtype,
|
||||
data_type=self.kv_cache_spec.dtype,
|
||||
q_data_type=self.runner.dtype,
|
||||
slot_mapping=slot_mapping,
|
||||
num_decodes=self._num_decodes,
|
||||
@@ -481,7 +486,7 @@ class FlashInferMetadataBuilder:
|
||||
return attn_metadata
|
||||
|
||||
def use_cascade_attention(self, *args, **kwargs) -> bool:
|
||||
if self.runner.kv_cache_dtype != self.runner.model_config.dtype:
|
||||
if self.kv_cache_spec.dtype != self.runner.model_config.dtype:
|
||||
# TODO: The cascade wrapper currently does not support setting
|
||||
# kv cache dtype to something different from query dtype.
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user