[v1] Pass BlockTable and KVCacheSpec to AttentionMetadataBuilders (#17483)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
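This change moves construction of the attention metadata builder out of the model runner's __init__ and into initialize_kv_cache, where it now receives the first KV-cache group's KVCacheSpec and the InputBatch's BlockTable in addition to a weakref proxy of the runner. The slot-mapping buffers previously owned by the GPU and TPU model runners move into BlockTable, which sizes them via a new max_num_batched_tokens argument threaded through InputBatch.

A minimal sketch of the constructor shape a builder sees after this change; the class and attribute names below are illustrative, not vLLM's actual builder API:

class ExampleMetadataBuilder:
    def __init__(self, runner, kv_cache_spec, block_table):
        # runner is a weakref.proxy to the model runner, so storing it
        # does not create a reference cycle.
        self.runner = runner
        # Per-group cache layout (block size, dtype, ...).
        self.kv_cache_spec = kv_cache_spec
        # The shared BlockTable, which now also owns slot-mapping buffers.
        self.block_table = block_table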
@@ -14,11 +14,13 @@ class BlockTable:
         self,
         max_num_reqs: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.pin_memory = pin_memory
         self.device = device

@@ -36,6 +38,15 @@ class BlockTable:
         self.block_table_np = self.block_table_cpu.numpy()
         self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

+        self.slot_mapping_cpu = torch.zeros(self.max_num_batched_tokens,
+                                            dtype=torch.int64,
+                                            device="cpu",
+                                            pin_memory=self.pin_memory)
+        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
+        self.slot_mapping = torch.zeros(self.max_num_batched_tokens,
+                                        dtype=torch.int64,
+                                        device=self.device)
+
     def append_row(
         self,
         block_ids: list[int],
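BlockTable now keeps three coordinated views of the slot mapping: a pinned CPU tensor, a numpy array aliasing that tensor's storage for cheap elementwise writes, and a device tensor filled by non-blocking copies. A standalone sketch of this pattern, with illustrative sizes:

import torch

max_num_batched_tokens = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pin = torch.cuda.is_available()  # pinning only helps with a CUDA device

slot_mapping_cpu = torch.zeros(max_num_batched_tokens, dtype=torch.int64,
                               device="cpu", pin_memory=pin)
slot_mapping_np = slot_mapping_cpu.numpy()  # aliases the CPU tensor's memory
slot_mapping = torch.zeros(max_num_batched_tokens, dtype=torch.int64,
                           device=device)

slot_mapping_np[:4] = [10, 11, 12, 13]  # write through numpy, no copy
slot_mapping[:4].copy_(slot_mapping_cpu[:4], non_blocking=True)  # async H2D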
@@ -59,6 +59,7 @@ class InputBatch:
         max_num_reqs: int,
         max_model_len: int,
         max_num_blocks_per_req: int,
+        max_num_batched_tokens: int,
         device: torch.device,
         pin_memory: bool,
         vocab_size: int,
@@ -66,6 +67,7 @@ class InputBatch:
         self.max_num_reqs = max_num_reqs
         self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.max_num_batched_tokens = max_num_batched_tokens
         self.device = device
         self.pin_memory = pin_memory
         self.vocab_size = vocab_size
@@ -100,6 +102,7 @@ class InputBatch:
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
             max_num_blocks_per_req=max_num_blocks_per_req,
+            max_num_batched_tokens=max_num_batched_tokens,
             pin_memory=pin_memory,
             device=device,
         )
@@ -150,8 +150,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 f"FA3. Current attention backend is {attn_backend_name}, "
                 f"FlashAttention version is {flash_attn_version}.")

-        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
-            weakref.proxy(self))
         self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

         # Multi-modal data support
@@ -174,6 +172,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
         # self.kv_cache_config: KVCacheConfig
+        # self.attn_metadata_builder: type[AttentionMetadataBuilder]

         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -203,6 +202,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=model_config.get_vocab_size(),
@@ -291,11 +291,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu",
                                          pin_memory=self.pin_memory)
         self.positions_np = self.positions_cpu.numpy()
-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu",
-                                            pin_memory=self.pin_memory)
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device="cpu",
@@ -586,7 +581,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_np[:total_num_scheduled_tokens])

         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
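The np.add above flattens (block_number, block_offset) pairs into linear KV-cache slots: for each scheduled token, slot = block_number * block_size + position % block_size, with block_number taken from the request's row of the block table. A worked example with made-up numbers:

import numpy as np

block_size = 16
positions_np = np.array([0, 1, 15, 16, 17])   # token positions
block_numbers = np.array([3, 3, 3, 7, 7])     # looked up from the block table
block_offsets = positions_np % block_size     # [0, 1, 15, 0, 1]
slot_mapping = np.empty_like(positions_np)
np.add(block_numbers * block_size, block_offsets, out=slot_mapping)
print(slot_mapping)  # [ 48  49  63 112 113]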
@@ -614,12 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
         self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
                                        non_blocking=True)
-        self.slot_mapping[:total_num_scheduled_tokens].copy_(
-            self.slot_mapping_cpu[:total_num_scheduled_tokens],
-            non_blocking=True)

         # Fill unused with -1. Needed for reshape_and_cache
-        self.slot_mapping[total_num_scheduled_tokens:].fill_(-1)
         self.seq_lens[num_reqs:].fill_(0)
         self.query_start_loc[num_reqs + 1:].fill_(-1)

@@ -1821,6 +1813,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

+        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
+            weakref.proxy(self),
+            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
+            self.input_batch.block_table)
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
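The builder has to be created here rather than in __init__ because kv_cache_config (and with it kv_cache_groups[0].kv_cache_spec) only exists once initialize_kv_cache runs. Passing weakref.proxy(self) lets the builder reach back into the runner without a reference cycle; a standalone sketch of that pattern with illustrative class names:

import weakref

class Builder:
    def __init__(self, runner):
        # A proxy, not a strong reference: the runner owns the builder,
        # so a strong back-reference would form a cycle.
        self.runner = runner

class Runner:
    def initialize_kv_cache(self):
        self.builder = Builder(weakref.proxy(self))

runner = Runner()
runner.initialize_kv_cache()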
@@ -179,6 +179,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
+            max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
             pin_memory=self.pin_memory,
             vocab_size=self.vocab_size,
@@ -197,10 +198,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                                          device="cpu")
         self.positions_np = self.positions_cpu.numpy()

-        self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
-                                            dtype=torch.int64,
-                                            device="cpu")
-        self.slot_mapping_np = self.slot_mapping_cpu.numpy()
         self.block_table_cpu = torch.zeros(
             (self.max_num_reqs, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
@@ -533,7 +530,8 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         block_offsets = positions_np % self.block_size
         np.add(block_numbers * self.block_size,
                block_offsets,
-               out=self.slot_mapping_np[:total_num_scheduled_tokens])
+               out=self.input_batch.block_table.
+               slot_mapping_np[:total_num_scheduled_tokens])

         # Prepare the attention metadata.
         self.query_start_loc_np[0] = 0
@@ -557,10 +555,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
         self.position_ids = self.positions_cpu[:
                                                padded_total_num_scheduled_tokens].to(
                                                    self.device)
-        self.slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
-        slot_mapping = self.slot_mapping_cpu[:
-                                             padded_total_num_scheduled_tokens].to(
-                                                 self.device)
+        self.input_batch.block_table.slot_mapping_cpu[
+            total_num_scheduled_tokens:] = _PAD_SLOT_ID
+        slot_mapping = (
+            self.input_batch.block_table.
+            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
+                self.device))
         block_tables = self.block_table_cpu[:self.max_num_reqs]
         block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
             self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
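On TPU the tensor shapes must stay static for XLA, so the slot mapping is padded out to padded_total_num_scheduled_tokens and the tail is filled with _PAD_SLOT_ID before the host-to-device transfer. A toy illustration; the sentinel value below is an assumption for readability, not the constant this commit uses:

import torch

_PAD_SLOT_ID = -1  # assumed sentinel; the real constant lives in the runner
total_num_scheduled_tokens = 5
padded_total_num_scheduled_tokens = 8  # padded up to a fixed bucket size

slot_mapping_cpu = torch.arange(8, dtype=torch.int64)
slot_mapping_cpu[total_num_scheduled_tokens:] = _PAD_SLOT_ID
print(slot_mapping_cpu[:padded_total_num_scheduled_tokens])
# tensor([ 0,  1,  2,  3,  4, -1, -1, -1])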