[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)

commit 950751a987 (parent 4c31218f80)
Author: Chen Zhang
Date:   2025-05-11 07:12:04 +08:00 (committed by GitHub)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

11 changed files with 132 additions and 68 deletions

vllm/v1/worker/block_table.py

@@ -14,11 +14,13 @@ class BlockTable:
         self,
         max_num_reqs: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device
 
@@ -36,6 +38,15 @@ class BlockTable:
         self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
 
+        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
     def append_row(
         self,
         block_ids: list[int],

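Note: the hunks above move the per-token slot-mapping buffers from the model runners into BlockTable, sized by max_num_batched_tokens (tokens per scheduled batch) rather than per request. Each token's slot is its physical block number times the block size plus its offset within the block. A minimal standalone sketch of the buffer layout and the CPU-to-device flow; block_size and the example values are illustrative, not taken from this commit:

    import numpy as np
    import torch

    max_num_batched_tokens = 8
    block_size = 4  # illustrative; vLLM reads this from the cache config

    # A CPU staging buffer, a zero-copy numpy view for cheap writes, and a
    # device-side buffer to copy into, mirroring the BlockTable hunk above.
    slot_mapping_cpu = torch.zeros(max_num_batched_tokens, dtype=torch.int64)
    slot_mapping_np = slot_mapping_cpu.numpy()
    slot_mapping = torch.zeros(max_num_batched_tokens, dtype=torch.int64)

    # slot = block_number * block_size + block_offset, computed per token.
    block_numbers = np.array([2, 2, 2, 5])  # physical block id per token
    block_offsets = np.array([0, 1, 2, 3])  # offset inside that block
    np.add(block_numbers * block_size, block_offsets, out=slot_mapping_np[:4])

    slot_mapping[:4].copy_(slot_mapping_cpu[:4])  # non_blocking=True on CUDA
    print(slot_mapping_np[:4])  # [ 8  9 10 23]
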
vllm/v1/worker/gpu_input_batch.py

@@ -59,6 +59,7 @@ class InputBatch:
         max_num_reqs: int,
         max_model_len: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         device: torch.device,
         pin_memory: bool,
         vocab_size: int,
@@ -66,6 +67,7 @@ class InputBatch:
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
@@ -100,6 +102,7 @@ class InputBatch:
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
             max_num_blocks_per_req=max_num_blocks_per_req,
+            max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
         )

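Note: InputBatch only threads the new max_num_batched_tokens argument through to BlockTable, so the slot-mapping buffers scale with tokens per batch instead of with requests. A hedged construction sketch; the values are illustrative and the remaining constructor arguments are elided:

    # Illustrative values; in vLLM these come from the scheduler and model config.
    input_batch = InputBatch(
        max_num_reqs=32,
        max_model_len=4096,
        max_num_blocks_per_req=256,
        max_num_batched_tokens=8192,  # new in this commit
        device=torch.device("cuda"),
        pin_memory=True,
        vocab_size=32000,
        # ... remaining arguments unchanged by this commit ...
    )
    # input_batch.block_table.slot_mapping_cpu now has shape (8192,)
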
vllm/v1/worker/gpu_model_runner.py

@@ -150,8 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 f"FA3. Current attention backend is {attn_backend_name}, "
                 f"FlashAttention version is {flash_attn_version}.")
 
-        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-            weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
 
         # Multi-modal data support
@@ -174,6 +172,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
         # self.kv_cache_config: KVCacheConfig
+        # self.attn_metadata_builder: type[AttentionMetadataBuilder]
 
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -203,6 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=model_config.get_vocab_size(),
@@ -291,11 +291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu",
                                          pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
@@ -586,7 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_np[:total_num_scheduled_tokens])
 
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -614,12 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
-        self.slot_mapping[:total_num_scheduled_tokens].copy_(
-            self.slot_mapping_cpu[:total_num_scheduled_tokens],
-            non_blocking=True)
 
         # Fill unused with -1. Needed for reshape_and_cache
-        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)
 
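Note: the device-side copy and the -1 fill for slot_mapping disappear from _prepare_inputs here. Since the attention metadata builder now receives the BlockTable (see the initialize_kv_cache hunk below), the equivalent step can live builder-side. A hedged sketch of what that step looks like; the function name is an assumption, not code from this diff:

    # Hypothetical builder-side helper (name assumed); mirrors the removed lines.
    def copy_slot_mapping(block_table, total_num_scheduled_tokens: int) -> None:
        block_table.slot_mapping[:total_num_scheduled_tokens].copy_(
            block_table.slot_mapping_cpu[:total_num_scheduled_tokens],
            non_blocking=True)
        # Unused entries must be -1 so reshape_and_cache skips them.
        block_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
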
@@ -1821,6 +1813,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
 
+        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+            weakref.proxy(self),
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
+            self.input_batch.block_table)
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each

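Note: builder construction moves from __init__ to initialize_kv_cache, the first point where the KV cache layout is known, and the builder now receives the KVCacheSpec of the first (and currently only) KV cache group plus the shared BlockTable. A hedged sketch of the constructor shape this implies for a backend builder; the class and attribute names here are assumptions inferred from the call site, not copied from the backend sources:

    import weakref

    # Assumed builder shape, inferred from the call site in the hunk above.
    class ExampleAttentionMetadataBuilder:

        def __init__(self, runner, kv_cache_spec, block_table):
            self.runner = runner                # weakref.proxy to the model runner
            self.kv_cache_spec = kv_cache_spec  # block size, dtype, head sizes, ...
            self.block_table = block_table      # owns the slot-mapping buffers

    # Call site, as in the diff:
    #   self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
    #       weakref.proxy(self),
    #       kv_cache_config.kv_cache_groups[0].kv_cache_spec,
    #       self.input_batch.block_table)
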
vllm/v1/worker/tpu_model_runner.py

@@ -179,6 +179,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.vocab_size,
@@ -197,10 +198,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu")
         self.positions_np = self.positions_cpu.numpy()
 
-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu")
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.block_table_cpu = torch.zeros(
             (self.max_num_reqs, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
@@ -533,7 +530,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_np[:total_num_scheduled_tokens])
 
         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -557,10 +555,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.position_ids = self.positions_cpu[:
                                                padded_total_num_scheduled_tokens].to(
                                                    self.device)
-        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = self.slot_mapping_cpu[:
-                                             padded_total_num_scheduled_tokens].to(
-                                                 self.device)
+        self.input_batch.block_table.slot_mapping_cpu[
+            total_num_scheduled_tokens:] = _PAD_SLOT_ID
+        slot_mapping = (
+            self.input_batch.block_table.
+            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
+                self.device))
         block_tables = self.block_table_cpu[:self.max_num_reqs]
         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
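
Note: TPU execution needs static shapes, so the slot mapping is padded up to padded_total_num_scheduled_tokens and the tail is stamped with _PAD_SLOT_ID before the transfer. A small sketch of that padding step; the _PAD_SLOT_ID value here is a stand-in, the real constant is defined in the TPU runner:

    import torch

    _PAD_SLOT_ID = -1  # stand-in value for this sketch
    total_num_scheduled_tokens = 5
    padded_total_num_scheduled_tokens = 8  # rounded up to a static shape

    slot_mapping_cpu = torch.arange(16, dtype=torch.int64)  # stand-in buffer
    slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
    slot_mapping = slot_mapping_cpu[:padded_total_num_scheduled_tokens]  # then .to(device)
    print(slot_mapping)  # tensor([ 0,  1,  2,  3,  4, -1, -1, -1])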