[v1] Re-init input batch for multiple kv cache groups (#18654)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -10,8 +10,6 @@ import torch
 
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
-                                        KVCacheGroupSpec, KVCacheTensor)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -25,27 +23,6 @@ CUDA_DEVICES = [
 MAX_NUM_PROMPT_TOKENS = 64
 
 
-def get_kv_cache_config() -> KVCacheConfig:
-    return KVCacheConfig(
-        num_blocks=10,
-        tensors={
-            "layer.0": KVCacheTensor(size=1024),
-        },
-        kv_cache_groups=[
-            KVCacheGroupSpec(
-                layer_names=["layer.0"],
-                kv_cache_spec=FullAttentionSpec(
-                    block_size=1,
-                    num_kv_heads=1,
-                    head_size=16,
-                    dtype=torch.float16,
-                    use_mla=False,
-                ),
-            ),
-        ],
-    )
-
-
 def _compare_objs(obj1, obj2):
     attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
     attr_names = set([
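Note: the helper removed above could only describe a single KV cache group; lifting that restriction is the point of this change. For orientation, a config with two groups would look roughly like the sketch below. It reuses only names visible in this diff (KVCacheConfig, KVCacheTensor, KVCacheGroupSpec, FullAttentionSpec); the second layer, its parameters, and the helper name are illustrative assumptions, not code from this commit.

import torch

from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)


def get_multi_group_kv_cache_config() -> KVCacheConfig:
    # Hypothetical: two layers whose cache layouts differ (here, the
    # block size), so they cannot share one block table and fall into
    # separate KV cache groups.
    return KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
            "layer.1": KVCacheTensor(size=1024),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(block_size=1,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
            KVCacheGroupSpec(
                layer_names=["layer.1"],
                kv_cache_spec=FullAttentionSpec(block_size=2,
                                                num_kv_heads=1,
                                                head_size=16,
                                                dtype=torch.float16,
                                                use_mla=False),
            ),
        ],
    )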
@@ -252,7 +229,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     reqs: list[CachedRequestState] = []
     req_id_reqs = {}
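The one-line change above, repeated in the two hunks that follow, is the visible API shift: InputBatch now takes block_sizes, a list with one block size per KV cache group, instead of a single scalar block_size. Presumably each entry backs one table inside the MultiGroupBlockTable imported above, and single-group tests keep working by passing a one-element list. Under the assumption that each group's block size comes from its spec, as in the deleted helper, a hypothetical way to derive the argument:

def block_sizes_from_config(cfg: KVCacheConfig) -> list[int]:
    # One entry per KV cache group; a single-group config yields [1]
    # here, matching the literal block_sizes=[1] in these tests.
    return [group.kv_cache_spec.block_size for group in cfg.kv_cache_groups]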
@@ -342,7 +319,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
     ref_input_batch: InputBatch = InputBatch(
         max_num_reqs=batch_size,
@@ -351,7 +328,7 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
         device=torch.device(device),
         pin_memory=is_pin_memory_available(),
         vocab_size=1024,
-        block_size=1,
+        block_sizes=[1],
     )
 
     reqs: list[CachedRequestState] = []
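Taken together, a hypothetical call site after this change might look as follows. This is a sketch only: it shows just the InputBatch keywords visible in these hunks, and the real constructor takes further arguments that the diff context does not include.

cfg = get_multi_group_kv_cache_config()
input_batch = InputBatch(
    max_num_reqs=4,
    device=torch.device("cpu"),
    pin_memory=is_pin_memory_available(),
    vocab_size=1024,
    block_sizes=block_sizes_from_config(cfg),  # was: block_size=1
    # ... remaining constructor arguments omitted; not shown in this diff
)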