[v1] Re-init input batch for multiple kv cache groups (#18654)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang
2025-06-04 05:41:36 +08:00
committed by GitHub
parent 6865fe0074
commit 6cac54f4d1
6 changed files with 61 additions and 46 deletions

View File

@@ -10,8 +10,6 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -25,27 +23,6 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64
def get_kv_cache_config() -> KVCacheConfig:
return KVCacheConfig(
num_blocks=10,
tensors={
"layer.0": KVCacheTensor(size=1024),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=1,
num_kv_heads=1,
head_size=16,
dtype=torch.float16,
use_mla=False,
),
),
],
)
def _compare_objs(obj1, obj2):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
block_size=1,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []