[v1] Support multiple KV cache groups in GPU model runner (#17945)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-05-15 09:54:54 +08:00
committed by GitHub
parent f25e0d1125
commit e60f550b38
16 changed files with 482 additions and 215 deletions

View File

@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import weakref
import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
kv_cache_spec = FullAttentionSpec(block_size=16,
num_kv_heads=1,
head_size=64,
dtype=torch.float16,
use_mla=False)
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
kv_cache_config = KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
max_num_reqs=runner.max_num_reqs,
max_model_len=runner.max_model_len,
max_num_batched_tokens=runner.max_num_tokens,
device=runner.device,
pin_memory=runner.pin_memory,
vocab_size=runner.model_config.get_vocab_size(),
kv_cache_config=kv_cache_config,
)
runner.initialize_attn_backend(kv_cache_config)
@pytest.fixture
@@ -48,10 +70,12 @@ def model_runner():
swap_space=0,
cache_dtype="auto",
)
parallel_config = ParallelConfig()
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
device = "cuda"
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
mm_hashes=[],
mm_positions=[],
sampling_params=SamplingParams(),
block_ids=[0],
block_ids=[[0]],
num_computed_tokens=0,
lora_request=None,
))
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
req_index = model_runner.input_batch.req_id_to_index[req_id]
block_table = model_runner.input_batch.block_table
block_table = model_runner.input_batch.block_table[0]
req_state = model_runner.requests[req_id]
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
if block_table.num_blocks_per_row[req_index] != len(
req_state.block_ids[0]):
return False
num_blocks = block_table.num_blocks_per_row[req_index]
return (block_table.block_table_np[req_index, :num_blocks] ==
req_state.block_ids).all()
req_state.block_ids[0]).all()
def test_update_states_new_request(model_runner):
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
req_id=req_id,
resumed_from_preemption=False,
new_token_ids=[],
new_block_ids=[],
new_block_ids=[[]],
num_computed_tokens=0,
)