[v1] Support multiple KV cache groups in GPU model runner (#17945)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -9,9 +9,11 @@ import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
|
||||
InputBatch)
|
||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
VOCAB_SIZE = 1024
|
||||
NUM_OUTPUT_TOKENS = 20
|
||||
@@ -22,6 +24,27 @@ CUDA_DEVICES = [
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def get_kv_cache_config() -> KVCacheConfig:
|
||||
return KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer.0": KVCacheTensor(size=1024),
|
||||
},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
layer_names=["layer.0"],
|
||||
kv_cache_spec=FullAttentionSpec(
|
||||
block_size=1,
|
||||
num_kv_heads=1,
|
||||
head_size=16,
|
||||
dtype=torch.float16,
|
||||
use_mla=False,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _compare_objs(obj1, obj2):
|
||||
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
|
||||
attr_names = set([
|
||||
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
|
||||
elif isinstance(a, np.ndarray):
|
||||
if np.allclose(a, b):
|
||||
is_same = True
|
||||
elif isinstance(a, MultiGroupBlockTable):
|
||||
for a_i, b_i in zip(a.block_tables, b.block_tables):
|
||||
_compare_objs(a_i, b_i)
|
||||
is_same = True
|
||||
elif isinstance(a, (BlockTable, SamplingMetadata)):
|
||||
_compare_objs(a, b)
|
||||
is_same = True # if we make it here must be same
|
||||
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
|
||||
sampling_params=_create_sampling_params(),
|
||||
mm_inputs=[],
|
||||
mm_positions=[],
|
||||
block_ids=[],
|
||||
block_ids=[[]],
|
||||
generator=None,
|
||||
num_computed_tokens=len(output_token_ids),
|
||||
output_token_ids=output_token_ids,
|
||||
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
max_model_len=1024,
|
||||
max_num_blocks_per_req=10,
|
||||
max_num_batched_tokens=1024,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
kv_cache_config=get_kv_cache_config(),
|
||||
)
|
||||
reqs: list[CachedRequestState] = []
|
||||
req_id_reqs = {}
|
||||
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
max_model_len=1024,
|
||||
max_num_blocks_per_req=10,
|
||||
max_num_batched_tokens=1024,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
kv_cache_config=get_kv_cache_config(),
|
||||
)
|
||||
ref_input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
max_model_len=1024,
|
||||
max_num_blocks_per_req=10,
|
||||
max_num_batched_tokens=1024,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
kv_cache_config=get_kv_cache_config(),
|
||||
)
|
||||
|
||||
reqs: list[CachedRequestState] = []
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
|
||||
SchedulerOutput)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheTensor)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
|
||||
@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
|
||||
"""
|
||||
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
|
||||
"""
|
||||
kv_cache_spec = FullAttentionSpec(block_size=16,
|
||||
num_kv_heads=1,
|
||||
head_size=64,
|
||||
dtype=torch.float16,
|
||||
use_mla=False)
|
||||
runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
|
||||
weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=10,
|
||||
tensors={
|
||||
"layer.0": KVCacheTensor(size=1024),
|
||||
},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
layer_names=["layer.0"],
|
||||
kv_cache_spec=FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=runner.model_config.get_num_kv_heads(
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
))
|
||||
])
|
||||
runner.kv_cache_config = kv_cache_config
|
||||
runner.input_batch = InputBatch(
|
||||
max_num_reqs=runner.max_num_reqs,
|
||||
max_model_len=runner.max_model_len,
|
||||
max_num_batched_tokens=runner.max_num_tokens,
|
||||
device=runner.device,
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,10 +70,12 @@ def model_runner():
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
)
|
||||
parallel_config = ParallelConfig()
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
scheduler_config=scheduler_config,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
mm_hashes=[],
|
||||
mm_positions=[],
|
||||
sampling_params=SamplingParams(),
|
||||
block_ids=[0],
|
||||
block_ids=[[0]],
|
||||
num_computed_tokens=0,
|
||||
lora_request=None,
|
||||
))
|
||||
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,
|
||||
|
||||
def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
|
||||
req_index = model_runner.input_batch.req_id_to_index[req_id]
|
||||
block_table = model_runner.input_batch.block_table
|
||||
block_table = model_runner.input_batch.block_table[0]
|
||||
req_state = model_runner.requests[req_id]
|
||||
if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
|
||||
if block_table.num_blocks_per_row[req_index] != len(
|
||||
req_state.block_ids[0]):
|
||||
return False
|
||||
num_blocks = block_table.num_blocks_per_row[req_index]
|
||||
return (block_table.block_table_np[req_index, :num_blocks] ==
|
||||
req_state.block_ids).all()
|
||||
req_state.block_ids[0]).all()
|
||||
|
||||
|
||||
def test_update_states_new_request(model_runner):
|
||||
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
req_id=req_id,
|
||||
resumed_from_preemption=False,
|
||||
new_token_ids=[],
|
||||
new_block_ids=[],
|
||||
new_block_ids=[[]],
|
||||
num_computed_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user