[Core] Use individual MM items in P0/P1 cache and model runner (#22570)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-08-13 22:18:07 +08:00
Committed by: GitHub
Parent: 20d65aa755
Commit: 19b927e52d
24 changed files with 549 additions and 486 deletions
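
What changed, in brief: the v1 Request API renames multi_modal_inputs to
multi_modal_kwargs and, instead of one batched MultiModalKwargs per request,
now carries one MultiModalKwargsItem per multimodal placeholder. Every test
helper in this commit follows the same construction pattern; here is a minimal
standalone sketch of it (the "dummy_m"/"dummy_k" values are the stand-ins the
tests use, not required names, and the comments are illustrative, not part of
the diff):

    from vllm.multimodal.inputs import (MultiModalBatchedField,
                                        MultiModalFieldElem,
                                        MultiModalKwargsItem, PlaceholderRange)

    # A field element ties one (modality, key) pair to its data and a field
    # type describing how it batches; the tests use None data for simplicity.
    mm_elem = MultiModalFieldElem(
        modality="dummy_m",
        key="dummy_k",
        data=None,
        field=MultiModalBatchedField(),
    )

    # An item groups the field elements belonging to one multimodal input.
    mm_item = MultiModalKwargsItem.from_elems([mm_elem])

    # Requests now take a list of items, one per placeholder in the prompt.
    mm_positions = [PlaceholderRange(offset=0, length=5)]
    mm_kwargs = [mm_item] * len(mm_positions)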

View File

@@ -1,12 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
+from typing import Optional

 import pytest
 import torch

 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -27,20 +30,29 @@ from vllm.v1.request import Request
 # yapf: enable


-def make_request(request_id,
-                 prompt_token_ids,
-                 mm_positions=None,
-                 mm_hashes=None,
-                 cache_salt=None):
+def make_request(
+    request_id: str,
+    prompt_token_ids: list[int],
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_hashes: Optional[list[str]] = None,
+    cache_salt: Optional[str] = None,
+):
     if mm_positions is None:
-        multi_modal_inputs = None
+        mm_kwargs = None
     else:
-        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+        mm_elem = MultiModalFieldElem(
+            modality="dummy_m",
+            key="dummy_k",
+            data=None,
+            field=MultiModalBatchedField(),
+        )
+        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_kwargs = [mm_item] * len(mm_positions)
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
-        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_kwargs=mm_kwargs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
@@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():

 def test_generate_block_hash_extra_keys():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(20)],
         mm_positions=[
             PlaceholderRange(offset=0, length=5),
@@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():

 def test_generate_block_hash_extra_keys_no_mm_inputs():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():

 def test_generate_block_hash_extra_keys_cache_salt():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():

     # works together with other extra keys
     request_mm = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(20)],
         mm_positions=[
             PlaceholderRange(offset=0, length=5),
@@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
     import vllm.v1.core.kv_cache_utils
     init_none_hash(hash_fn)
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):

     init_none_hash(hash_fn)
     request1 = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
         mm_hashes=["hash1", "hash2"],
     )
     request2 = make_request(
-        request_id=1,
+        request_id="1",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):

     init_none_hash(hash_fn)
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -844,7 +856,7 @@ def test_allocate_with_lookahead():

     )
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[],
         mm_positions=None,
         mm_hashes=None,
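
Two things to note across the hunks above: request_id is typed as str now
(hence the request_id=0 -> request_id="0" changes throughout), and a list of
per-placeholder items replaces the old multi_modal_inputs list. A hypothetical
call to the updated helper, with argument values modeled on the tests in this
file:

    request = make_request(
        request_id="0",
        prompt_token_ids=[_ for _ in range(6)],
        mm_positions=[PlaceholderRange(offset=0, length=3)],
        mm_hashes=["hash1"],
    )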

View File

@@ -9,7 +9,9 @@ import pytest
 import torch

 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, SlidingWindowSpec)


-def make_request(request_id,
-                 prompt_token_ids,
-                 mm_positions=None,
-                 mm_hashes=None,
-                 prompt_logprobs: Optional[int] = None,
-                 cache_salt: Optional[str] = None):
+def make_request(
+    request_id: str,
+    prompt_token_ids: list[int],
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_hashes: Optional[list[str]] = None,
+    prompt_logprobs: Optional[int] = None,
+    cache_salt: Optional[str] = None,
+):
     if mm_positions is None:
-        multi_modal_inputs = None
+        mm_kwargs = None
     else:
-        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+        mm_elem = MultiModalFieldElem(
+            modality="dummy_m",
+            key="dummy_k",
+            data=None,
+            field=MultiModalBatchedField(),
+        )
+        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_kwargs = [mm_item] * len(mm_positions)
     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
-        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_kwargs=mm_kwargs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17,

View File

@@ -8,7 +8,9 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1304,7 +1306,7 @@ def create_requests_with_priority(
     priorities: list[int],
     arrival_times: Optional[list[float]] = None,
     num_tokens: int = 10,
-    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     prompt_logprobs: Optional[int] = None):
@@ -1323,16 +1325,23 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+            mm_elem = MultiModalFieldElem(
+                modality="dummy_m",
+                key="dummy_k",
+                data=None,
+                field=MultiModalBatchedField(),
+            )
+            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
-            mm_inputs = None
+            mm_kwargs = None
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_kwargs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
@@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
-        multi_modal_inputs=None,
+        multi_modal_kwargs=None,
         multi_modal_hashes=None,
         multi_modal_placeholders=None,
         sampling_params=sampling_params,
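
Note that mm_positions in create_requests_with_priority is now nested: one
inner list per request, one PlaceholderRange per multimodal item within that
request, matching the per-item mm_kwargs built in the loop above. A
hypothetical call under that reading (assuming the helper's leading
num_requests parameter, as the loop suggests; offsets and lengths are made up
for illustration):

    requests = create_requests_with_priority(
        num_requests=2,
        priorities=[0, 1],
        mm_positions=[
            # Request 0 has two MM items -> two ranges, two kwargs items.
            [PlaceholderRange(offset=0, length=3),
             PlaceholderRange(offset=5, length=3)],
            # Request 1 has one MM item.
            [PlaceholderRange(offset=0, length=3)],
        ],
    )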

View File

@@ -6,7 +6,9 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -115,7 +117,7 @@ def create_scheduler(
 def create_requests(
     num_requests: int,
     num_tokens: int = 10,
-    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     prompt_logprobs: Optional[int] = None,
@@ -129,10 +131,17 @@ def create_requests(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+            mm_elem = MultiModalFieldElem(
+                modality="dummy_m",
+                key="dummy_k",
+                data=None,
+                field=MultiModalBatchedField(),
+            )
+            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
-            mm_inputs = None
+            mm_kwargs = None
         prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
                             num_tokens)
         request = Request(
@@ -140,7 +149,7 @@ def create_requests(
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_kwargs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
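
The same nesting applies to create_requests here, and the else branch keeps
text-only requests unchanged: with mm_positions left as None, each request is
built with multi_modal_kwargs=None rather than a list of items. A hypothetical
pair of calls (assuming the helper returns the requests it builds; offsets are
illustrative):

    # Text-only: no placeholders, so multi_modal_kwargs stays None.
    text_requests = create_requests(num_requests=2)

    # Multimodal: one inner list per request; each PlaceholderRange gets its
    # own MultiModalKwargsItem in the loop above.
    mm_requests = create_requests(
        num_requests=2,
        mm_positions=[
            [PlaceholderRange(offset=0, length=3)],
            [PlaceholderRange(offset=0, length=3),
             PlaceholderRange(offset=8, length=3)],
        ],
    )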