[v1] Redo "Support multiple KV cache groups in GPU model runner (#17945)" (#18593)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Author: Chen Zhang
Date: 2025-05-24 00:39:47 +08:00
Committed by: GitHub
Parent: 9520a989df
Commit: 6550114c9c
15 changed files with 469 additions and 203 deletions

View File: vllm/v1/worker/block_table.py

@@ -4,6 +4,7 @@ import numpy as np
import torch
from vllm.logger import init_logger
from vllm.utils import cdiv
logger = init_logger(__name__)
@@ -96,3 +97,43 @@ class BlockTable:
def get_numpy_array(self) -> np.ndarray:
"""Returns the numpy array of the block table."""
return self.block_table_np
class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
def __init__(self, max_num_reqs: int, max_model_len: int,
max_num_batched_tokens: int, pin_memory: bool,
device: torch.device, block_size: int) -> None:
self.block_tables = [
BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
max_num_batched_tokens, pin_memory, device)
]
def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.append_row(block_ids[i], row_idx)
def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
for i, block_table in enumerate(self.block_tables):
block_table.add_row(block_ids[i], row_idx)
def move_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.move_row(src, tgt)
def swap_row(self, src: int, tgt: int) -> None:
for block_table in self.block_tables:
block_table.swap_row(src, tgt)
def commit(self, num_reqs: int) -> None:
for block_table in self.block_tables:
block_table.commit(num_reqs)
def clear(self) -> None:
for block_table in self.block_tables:
block_table.clear()
def __getitem__(self, idx: int) -> "BlockTable":
"""Returns the BlockTable for the i-th KV cache group."""
return self.block_tables[idx]
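Editor's note: a minimal standalone sketch of the fan-out pattern used above. The class and method names below (SimpleTable, GroupedTable) are hypothetical stand-ins, not the vLLM classes; the point is only that every mutating call is forwarded to each per-group table, while indexing exposes the table of a single KV cache group.

class SimpleTable:
    def __init__(self) -> None:
        self.rows: dict[int, list[int]] = {}

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows.setdefault(row_idx, []).extend(block_ids)

class GroupedTable:
    def __init__(self, num_groups: int) -> None:
        self.tables = [SimpleTable() for _ in range(num_groups)]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        # block_ids[i] holds the new blocks of KV cache group i.
        for i, table in enumerate(self.tables):
            table.append_row(block_ids[i], row_idx)

    def __getitem__(self, idx: int) -> SimpleTable:
        return self.tables[idx]

grouped = GroupedTable(num_groups=2)
grouped.append_row([[0, 1], [7]], row_idx=0)
assert grouped[0].rows[0] == [0, 1]
assert grouped[1].rows[0] == [7]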

View File: vllm/v1/worker/gpu_input_batch.py

@@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.block_table import MultiGroupBlockTable
_SAMPLING_EPS = 1e-5
@@ -29,7 +29,7 @@ class CachedRequestState:
sampling_params: SamplingParams
generator: Optional[torch.Generator]
block_ids: list[int]
block_ids: list[list[int]]
num_computed_tokens: int
output_token_ids: list[int]
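Editor's note: the type change from list[int] to list[list[int]] means a request's block IDs are now grouped by KV cache group instead of stored flat. A hedged illustration with made-up block numbers:

# Before: one flat list of block IDs per request.
old_block_ids = [3, 4, 5]
# After: one inner list per KV cache group; with a single full-attention
# group the old list is simply wrapped once.
new_block_ids = [[3, 4, 5]]
assert new_block_ids[0] == old_block_ids
# A hypothetical hybrid model with two groups would instead look like:
hybrid_block_ids = [[3, 4, 5], [12, 13]]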
@@ -58,15 +58,14 @@ class InputBatch:
self,
max_num_reqs: int,
max_model_len: int,
max_num_blocks_per_req: int,
max_num_batched_tokens: int,
device: torch.device,
pin_memory: bool,
vocab_size: int,
block_size: int,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
self.max_num_blocks_per_req = max_num_blocks_per_req
self.max_num_batched_tokens = max_num_batched_tokens
self.device = device
self.pin_memory = pin_memory
@@ -99,12 +98,13 @@ class InputBatch:
self.num_computed_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = BlockTable(
self.block_table = MultiGroupBlockTable(
max_num_reqs=max_num_reqs,
max_num_blocks_per_req=max_num_blocks_per_req,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
pin_memory=pin_memory,
device=device,
block_size=block_size,
)
# Sampling-related.

View File: vllm/v1/worker/gpu_model_runner.py

@@ -12,6 +12,8 @@ import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
@@ -32,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LayerBlockType, LazyLoader, cdiv,
check_use_alibi, is_pin_memory_available)
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -51,6 +53,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -103,59 +106,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
# NOTE(woosuk): sliding_window is None for models with interleaved
# attention. Use interleaved_sliding_window instead.
self.sliding_window = model_config.get_sliding_window()
self.interleaved_sliding_window = getattr(
model_config.hf_text_config, "interleaved_sliding_window", None)
self.window_size = (self.sliding_window
or self.interleaved_sliding_window)
self.is_multimodal_model = model_config.is_multimodal_model
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.attention_chunk_size = model_config.attention_chunk_size
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = self.attn_backend.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is {attn_backend_name}, "
f"FlashAttention version is {flash_attn_version}.")
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
# Multi-modal data support
@@ -177,8 +138,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
self.attn_backends: list[type[AttentionBackend]] = []
# self.kv_cache_config: KVCacheConfig
# self.attn_metadata_builder: type[AttentionMetadataBuilder]
# self.input_batch: InputBatch # Persistent batch.
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -207,15 +170,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
vocab_size=self.model_config.get_vocab_size(),
block_size=self.cache_config.block_size,
)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -311,6 +274,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
"""
Update the order of requests in the batch based on the attention
backend's needs. For example, some attention backends (namely MLA) may
want to separate requests based on if the attention computation will be
compute-bound or memory-bound.
Args:
scheduler_output: The scheduler output.
Returns:
True if the batch was reordered, False otherwise.
"""
batch_reordered = self.attn_metadata_builders[0].reorder_batch(
self.input_batch, scheduler_output)
# For models with multiple KV cache groups, the groups should agree on
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
assert not self.attn_metadata_builders[i].reorder_batch(
self.input_batch, scheduler_output)
return batch_reordered
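Editor's note: a toy illustration of the contract enforced above, using stand-in builder classes rather than the vLLM ones. Only the first group's builder may report a reorder; every later group must leave the batch order unchanged.

class KeepsOrder:
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return False  # leaves the batch untouched

class ReordersDecodesFirst:
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        return True  # pretend decode requests were moved to the front

builders = [ReordersDecodesFirst(), KeepsOrder(), KeepsOrder()]
batch_reordered = builders[0].reorder_batch(None, None)
for other in builders[1:]:
    assert not other.reorder_batch(None, None)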
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
"""Update the cached states and the persistent batch with the scheduler
output.
@@ -447,7 +435,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
req_state.block_ids.extend(req_data.new_block_ids)
for i in range(len(self.kv_cache_config.kv_cache_groups)):
req_state.block_ids[i].extend(req_data.new_block_ids[i])
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
@@ -505,11 +494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
batch_reordered = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
batch_reordered = self._may_reorder_batch(scheduler_output)
if batch_changed or batch_reordered:
self.input_batch.refresh_sampling_metadata()
@@ -577,21 +562,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Calculate the slot mapping.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size` here
# because M (max_model_len) is not necessarily divisible by block_size.
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table.
slot_mapping_np[:total_num_scheduled_tokens])
# Calculate the slot mapping for each KV cache group.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
out=block_table.slot_mapping_np[:total_num_scheduled_tokens])
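Editor's note: a self-contained numpy sketch of the per-group slot-mapping arithmetic above, under assumed toy values (two requests, block size 2, at most K = 4 blocks per request). The block table and positions are made up; only the indexing formula mirrors the diff.

import numpy as np

block_size = 2
max_num_blocks_per_req = 4  # K
# Per-request block table (num_reqs x K); values are physical block numbers.
block_table = np.array([[10, 11, 0, 0],
                        [20, 21, 22, 0]], dtype=np.int64)
# One entry per scheduled token: owning request and position within it.
req_indices = np.array([0, 0, 1, 1, 1], dtype=np.int64)
positions = np.array([0, 1, 2, 3, 4], dtype=np.int64)

# Index into the flattened block table, as in the diff above.
block_table_indices = (req_indices * max_num_blocks_per_req +
                       positions // block_size)
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)  # [20 21 42 43 44]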
# Prepare the attention metadata.
self.query_start_loc_np[0] = 0
@@ -633,10 +626,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
attn_metadata: dict[str, FlashAttentionMetadata] = {}
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
# NOTE(Chen): there is exactly one KV cache group that contains all
# attetnion layers in the model for now, so the current logic for
# getting attn_metadata is not related to kv_cache_group information.
# Will extend this part to support multiple KV cache groups later.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
@@ -645,15 +634,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks,
scheduler_output.
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec,
self.attn_metadata_builders[kv_cache_group_id],
)
attn_metadata_i = self.attn_metadata_builder.build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
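Editor's note: after this loop, attn_metadata maps every attention layer name to the metadata object built for its KV cache group, so all layers in one group share a single object. A hedged illustration with made-up layer names and placeholder metadata objects:

groups = {0: ["model.layers.0.self_attn", "model.layers.2.self_attn"],
          1: ["model.layers.1.self_attn"]}
metadata_per_group = {0: object(), 1: object()}
attn_metadata = {}
for group_id, layer_names in groups.items():
    for name in layer_names:
        attn_metadata[name] = metadata_per_group[group_id]
# Layers of the same group share one metadata object.
assert (attn_metadata["model.layers.0.self_attn"]
        is attn_metadata["model.layers.2.self_attn"])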
@@ -691,6 +684,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
num_scheduled_tokens: np.ndarray,
num_common_prefix_blocks: int,
kv_cache_spec: KVCacheSpec,
attn_metadata_builder: AttentionMetadataBuilder,
) -> int:
"""Compute the length of the common prefix for cascade attention.
@@ -709,7 +704,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
Returns:
int: Length of common prefix in tokens.
"""
common_prefix_len = num_common_prefix_blocks * self.block_size
common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
if common_prefix_len == 0:
# Common case.
return 0
@@ -758,15 +753,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_prefix_len,
self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
# common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size *
self.block_size)
use_cascade = self.attn_metadata_builder.use_cascade_attention(
common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
kv_cache_spec.block_size)
use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
(isinstance(kv_cache_spec, FullAttentionSpec)
and kv_cache_spec.sliding_window is not None))
assert isinstance(kv_cache_spec, AttentionSpec)
use_cascade = attn_metadata_builder.use_cascade_attention(
common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads,
num_kv_heads=self.num_kv_heads,
num_kv_heads=kv_cache_spec.num_kv_heads,
use_alibi=self.use_alibi,
use_sliding_window=self.window_size is not None,
use_sliding_window=use_sliding_window,
num_sms=self.num_sms,
)
return common_prefix_len if use_cascade else 0
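Editor's note: the rounding above now uses the block size of the specific KV cache group rather than a runner-wide value. A short worked example with assumed numbers:

# Assumed values: a shared prefix of 70 tokens, group block size of 16.
common_prefix_len = 70
block_size = 16
# Round down to a multiple of the group's block size, as in the diff.
common_prefix_len = common_prefix_len // block_size * block_size
assert common_prefix_len == 64  # 4 full blocks of 16 tokens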
@@ -1661,7 +1660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype=np.int32)
if skip_attn:
attn_metadata = None
attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
else:
query_start_loc = self.query_start_loc[:num_reqs + 1]
seq_lens = self.seq_lens[:num_reqs]
@@ -1669,13 +1668,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata = self.attn_metadata_builder.build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
attn_metadata = {}
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_tokens,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
@@ -1909,6 +1914,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the attention backends and attention metadata builders.
"""
assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
raise NotImplementedError(
"Only AttentionSpec is supported for now.")
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (
f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if self.vllm_config.compilation_config.full_cuda_graph:
attn_backend_name = attn_backend_i.__name__
flash_attn_version = get_flash_attn_version()
if attn_backend_name != "FlashAttentionBackend" or \
flash_attn_version != 3:
raise ValueError(
f"full_cuda_graph is only supported with "
f"FA3. Current attention backend is "
f"{attn_backend_name}, FlashAttention version is "
f"{flash_attn_version}.")
block_table_i = self.input_batch.block_table[i]
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
weakref.proxy(self), kv_cache_spec, block_table_i)
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
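Editor's note: a hedged, standalone sketch of the per-group initialization pattern introduced above. The names below (GroupSpec, Builder, select_backend) are hypothetical, not the vLLM API; the real code selects a backend with get_attn_backend() and constructs the builder via the backend's get_builder_cls(). The point is that each KV cache group picks a backend from its own spec and pairs it with a metadata builder bound to that group's block table.

from dataclasses import dataclass

@dataclass
class GroupSpec:
    head_size: int
    block_size: int
    use_mla: bool

class Builder:
    def __init__(self, spec: GroupSpec, block_table) -> None:
        self.spec = spec
        self.block_table = block_table

def select_backend(spec: GroupSpec) -> str:
    # Placeholder selection logic keyed on the group's own spec.
    return "MLA" if spec.use_mla else "FLASH_ATTN"

def init_attn_backends(group_specs, block_tables):
    backends, builders = [], []
    for i, spec in enumerate(group_specs):
        backends.append(select_backend(spec))
        builders.append(Builder(spec, block_tables[i]))
    return backends, builders

backends, builders = init_attn_backends(
    [GroupSpec(128, 16, False), GroupSpec(64, 16, True)], ["bt0", "bt1"])
assert backends == ["FLASH_ATTN", "MLA"]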
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -1921,10 +1976,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.kv_cache_config = kv_cache_config
self.initialize_attn_backend(kv_cache_config)
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups:
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group.kv_cache_spec
for layer_name in kv_cache_group.layer_names:
tensor_config = kv_cache_config.tensors[layer_name]
@@ -1939,7 +1995,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the min of all `num_blocks`. Verify it here.
assert num_blocks >= kv_cache_config.num_blocks
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
@@ -1959,11 +2015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(kv_caches)
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self),
kv_cache_config.kv_cache_groups[0].kv_cache_spec,
self.input_batch.block_table)
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each

View File: vllm/v1/worker/tpu_model_runner.py

@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.kv_caches: list[torch.Tensor] = []
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# self.input_batch: InputBatch # Persistent batch.
# Request states.
self.requests: dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.vocab_size,
)
# Cached torch/numpy tensor
# The pytorch tensor and numpy array share the same buffer.
@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.block_table_cpu = torch.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
dtype=torch.int32,
device="cpu")
self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.input_batch.block_table.
out=self.input_batch.block_table[0].
slot_mapping_np[:total_num_scheduled_tokens])
# Prepare the attention metadata.
@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
self.position_ids = self.positions_cpu[:
padded_total_num_scheduled_tokens].to(
self.device)
self.input_batch.block_table.slot_mapping_cpu[
self.input_batch.block_table[0].slot_mapping_cpu[
total_num_scheduled_tokens:] = _PAD_SLOT_ID
slot_mapping = (
self.input_batch.block_table.
self.input_batch.block_table[0].
slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
self.device))
block_tables = self.block_table_cpu[:self.max_num_reqs]
block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
block_tables = block_tables.to(self.device)
query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
self.device)
@@ -1263,6 +1254,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
"Hybrid models with more than one KV cache type are not "
"supported yet.")
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
block_size,
)
assert self.block_table_cpu.dtype == self.input_batch.block_table[
0].get_cpu_tensor().dtype
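Editor's note: a minimal sketch of the ordering constraint that motivates moving the InputBatch construction into initialize_kv_cache here. The names below are hypothetical stand-ins: the block size is only known once the KV cache groups have been decided, so any object that sizes its block tables from it has to be created afterwards.

from dataclasses import dataclass

@dataclass
class KVCacheGroupSpec:
    block_size: int

@dataclass
class Batch:
    max_model_len: int
    block_size: int

    @property
    def max_num_blocks_per_req(self) -> int:
        # Ceil-divide the model length by the group's block size.
        return -(-self.max_model_len // self.block_size)

def initialize(max_model_len: int, groups: list[KVCacheGroupSpec]) -> Batch:
    # The batch can only be sized after the KV cache groups are known.
    return Batch(max_model_len, groups[0].block_size)

batch = initialize(8192, [KVCacheGroupSpec(block_size=16)])
assert batch.max_num_blocks_per_req == 512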
kv_caches: dict[str, torch.Tensor] = {}
for kv_cache_group in kv_cache_config.kv_cache_groups: