Revert "[v1] Support multiple KV cache groups in GPU model runner (#17945) (#18459)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Author: Mark McLoughlin
Date: 2025-05-21 18:25:23 +01:00
Committed by: GitHub
Parent: dd5fa7e04f
Commit: bb0a311213
15 changed files with 214 additions and 481 deletions


@@ -9,11 +9,9 @@ import torch
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                         KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
+                                             InputBatch)
 
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
@@ -24,27 +22,6 @@ CUDA_DEVICES = [
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
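For orientation: the helper removed above is what tied these test fixtures to the multi-group KV cache interface; its KVCacheConfig was passed straight into InputBatch (see the constructor hunks further down). A minimal sketch of that pre-revert wiring, using only names and values visible in this diff (the helper name make_multi_group_input_batch is invented for the sketch, and the InputBatch signature shown is the pre-revert one that this commit removes):

# Sketch only: mirrors the pre-revert fixture wiring that this revert undoes.
import torch

from vllm.utils import is_pin_memory_available
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.worker.gpu_input_batch import InputBatch


def make_multi_group_input_batch(batch_size: int, device: str) -> InputBatch:
    # One KV cache group covering a single full-attention layer, exactly as in
    # the removed get_kv_cache_config() helper above.
    kv_cache_config = KVCacheConfig(
        num_blocks=10,
        tensors={"layer.0": KVCacheTensor(size=1024)},
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(
                    block_size=1,
                    num_kv_heads=1,
                    head_size=16,
                    dtype=torch.float16,
                    use_mla=False,
                ),
            ),
        ],
    )
    # Pre-revert: the batch derives its block tables from the config rather
    # than from an explicit max_num_blocks_per_req argument.
    return InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        kv_cache_config=kv_cache_config,
    )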
@@ -64,10 +41,6 @@ def _compare_objs(obj1, obj2):
         elif isinstance(a, np.ndarray):
             if np.allclose(a, b):
                 is_same = True
-        elif isinstance(a, MultiGroupBlockTable):
-            for a_i, b_i in zip(a.block_tables, b.block_tables):
-                _compare_objs(a_i, b_i)
-            is_same = True
         elif isinstance(a, (BlockTable, SamplingMetadata)):
             _compare_objs(a, b)
             is_same = True  # if we make it here must be same
@@ -225,7 +198,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         sampling_params=_create_sampling_params(),
         mm_inputs=[],
         mm_positions=[],
-        block_ids=[[]],
+        block_ids=[],
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
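The block_ids change above drops one level of nesting: in the multi-group code each request appears to carry one list of block IDs per KV cache group, while the reverted code keeps a single flat list. A made-up illustration:

# Illustration only (values invented).
block_ids_multi_group = [[0, 1, 2]]   # pre-revert: one inner list per KV cache group
block_ids_flat = [0, 1, 2]            # post-revert: single flat list of block IDs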
@@ -247,11 +220,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
@@ -337,20 +310,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
     input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
         max_model_len=1024,
+        max_num_blocks_per_req=10,
         max_num_batched_tokens=1024,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        kv_cache_config=get_kv_cache_config(),
     )
     reqs: list[CachedRequestState] = []
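Taken together, the constructor hunks mean the tests go back to sizing the block table explicitly rather than deriving it from a KVCacheConfig. A hedged sketch of the post-revert call, assembled from the arguments shown above (batch_size and device are placeholders for the test parameters):

# Sketch of the post-revert InputBatch construction used by these tests.
import torch

from vllm.utils import is_pin_memory_available
from vllm.v1.worker.gpu_input_batch import InputBatch

batch_size = 32       # placeholder for the parametrized batch size
device = "cuda:0"     # placeholder for the parametrized device

input_batch = InputBatch(
    max_num_reqs=batch_size,
    max_model_len=1024,
    max_num_blocks_per_req=10,   # block-table size passed explicitly again
    max_num_batched_tokens=1024,
    device=torch.device(device),
    pin_memory=is_pin_memory_available(),
    vocab_size=1024,
)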