[Core] Use individual MM items in P0/P1 cache and model runner (#22570)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
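In short, the test helpers stop fabricating empty MultiModalKwargs({})
placeholders and instead build one MultiModalKwargsItem per multimodal
placeholder, matching how the P0/P1 caches and the model runner now track
individual MM items. A minimal sketch of the new pattern, lifted from the
test helpers in the diff below (the standalone function name make_mm_kwargs
is illustrative, not part of the commit; the "dummy_m"/"dummy_k" labels come
from the tests):

    from typing import Optional

    from vllm.multimodal.inputs import (MultiModalBatchedField,
                                        MultiModalFieldElem,
                                        MultiModalKwargsItem,
                                        PlaceholderRange)

    def make_mm_kwargs(
        mm_positions: Optional[list[PlaceholderRange]],
    ) -> Optional[list[MultiModalKwargsItem]]:
        """Build one stub MultiModalKwargsItem per placeholder range."""
        if mm_positions is None:
            return None
        # A single dummy field element; data=None keeps the item a stub.
        mm_elem = MultiModalFieldElem(
            modality="dummy_m",
            key="dummy_k",
            data=None,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
        return [mm_item] * len(mm_positions)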
@@ -1,12 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import importlib
+from typing import Optional

 import pytest
 import torch

 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -27,20 +30,29 @@ from vllm.v1.request import Request
 # yapf: enable


-def make_request(request_id,
-                 prompt_token_ids,
-                 mm_positions=None,
-                 mm_hashes=None,
-                 cache_salt=None):
+def make_request(
+    request_id: str,
+    prompt_token_ids: list[int],
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_hashes: Optional[list[str]] = None,
+    cache_salt: Optional[str] = None,
+):
     if mm_positions is None:
-        multi_modal_inputs = None
+        mm_kwargs = None
     else:
-        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+        mm_elem = MultiModalFieldElem(
+            modality="dummy_m",
+            key="dummy_k",
+            data=None,
+            field=MultiModalBatchedField(),
+        )
+        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_kwargs = [mm_item] * len(mm_positions)

     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
-        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_kwargs=mm_kwargs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17),
@@ -316,7 +328,7 @@ def test_free_kv_cache_block_queue_get_all_free_blocks():

 def test_generate_block_hash_extra_keys():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(20)],
         mm_positions=[
             PlaceholderRange(offset=0, length=5),
@@ -348,7 +360,7 @@ def test_generate_block_hash_extra_keys():

 def test_generate_block_hash_extra_keys_no_mm_inputs():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -361,7 +373,7 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():

 def test_generate_block_hash_extra_keys_cache_salt():
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -382,7 +394,7 @@ def test_generate_block_hash_extra_keys_cache_salt():

     # works together with other extra keys
     request_mm = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(20)],
         mm_positions=[
             PlaceholderRange(offset=0, length=5),
@@ -420,7 +432,7 @@ def test_hash_request_tokens(hash_fn):
     import vllm.v1.core.kv_cache_utils
     init_none_hash(hash_fn)
     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -450,7 +462,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
     init_none_hash(hash_fn)

     request1 = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -459,7 +471,7 @@ def test_hash_tokens_different_mm_input(hash_fn):
         mm_hashes=["hash1", "hash2"],
     )
     request2 = make_request(
-        request_id=1,
+        request_id="1",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=[
             PlaceholderRange(offset=0, length=3),
@@ -479,7 +491,7 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
     init_none_hash(hash_fn)

     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[_ for _ in range(6)],
         mm_positions=None,
         mm_hashes=None,
@@ -844,7 +856,7 @@ def test_allocate_with_lookahead():
     )

     request = make_request(
-        request_id=0,
+        request_id="0",
         prompt_token_ids=[],
         mm_positions=None,
         mm_hashes=None,

@@ -9,7 +9,9 @@ import pytest
 import torch

 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -21,21 +23,30 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, SlidingWindowSpec)


-def make_request(request_id,
-                 prompt_token_ids,
-                 mm_positions=None,
-                 mm_hashes=None,
-                 prompt_logprobs: Optional[int] = None,
-                 cache_salt: Optional[str] = None):
+def make_request(
+    request_id: str,
+    prompt_token_ids: list[int],
+    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_hashes: Optional[list[str]] = None,
+    prompt_logprobs: Optional[int] = None,
+    cache_salt: Optional[str] = None,
+):
     if mm_positions is None:
-        multi_modal_inputs = None
+        mm_kwargs = None
     else:
-        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)
+        mm_elem = MultiModalFieldElem(
+            modality="dummy_m",
+            key="dummy_k",
+            data=None,
+            field=MultiModalBatchedField(),
+        )
+        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_kwargs = [mm_item] * len(mm_positions)

     return Request(
         request_id=request_id,
         prompt_token_ids=prompt_token_ids,
-        multi_modal_inputs=multi_modal_inputs,
+        multi_modal_kwargs=mm_kwargs,
         multi_modal_hashes=mm_hashes,
         multi_modal_placeholders=mm_positions,
         sampling_params=SamplingParams(max_tokens=17,

@@ -8,7 +8,9 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1304,7 +1306,7 @@ def create_requests_with_priority(
     priorities: list[int],
     arrival_times: Optional[list[float]] = None,
     num_tokens: int = 10,
-    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     prompt_logprobs: Optional[int] = None):
@@ -1323,16 +1325,23 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+            mm_elem = MultiModalFieldElem(
+                modality="dummy_m",
+                key="dummy_k",
+                data=None,
+                field=MultiModalBatchedField(),
+            )
+            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
-            mm_inputs = None
+            mm_kwargs = None
         request = Request(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_kwargs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
@@ -1816,7 +1825,7 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
     request = Request(
         request_id="0",
         prompt_token_ids=[0, 1],
-        multi_modal_inputs=None,
+        multi_modal_kwargs=None,
         multi_modal_hashes=None,
         multi_modal_placeholders=None,
         sampling_params=sampling_params,

@@ -6,7 +6,9 @@ import torch

 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFieldElem, MultiModalKwargsItem,
+                                    PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.async_scheduler import AsyncScheduler
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -115,7 +117,7 @@ def create_scheduler(
 def create_requests(
     num_requests: int,
     num_tokens: int = 10,
-    mm_positions: Optional[list[PlaceholderRange]] = None,
+    mm_positions: Optional[list[list[PlaceholderRange]]] = None,
     max_tokens: int = 16,
     stop_token_ids: Optional[list[int]] = None,
     prompt_logprobs: Optional[int] = None,
@@ -129,10 +131,17 @@ def create_requests(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
+            mm_elem = MultiModalFieldElem(
+                modality="dummy_m",
+                key="dummy_k",
+                data=None,
+                field=MultiModalBatchedField(),
+            )
+            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
-            mm_inputs = None
+            mm_kwargs = None
         prompt_token_ids = ([0] * num_tokens if same_prompt else [i] *
                             num_tokens)
         request = Request(
@@ -140,7 +149,7 @@ def create_requests(
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_kwargs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,

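Note also that create_requests and create_requests_with_priority now take
mm_positions as one placeholder list per request rather than a flat list.
A hypothetical call under that assumption (the offsets and the two-image
layout are invented for illustration):

    # Two requests: the first carries two MM items, the second carries one.
    requests = create_requests(
        num_requests=2,
        mm_positions=[
            [PlaceholderRange(offset=0, length=3),
             PlaceholderRange(offset=5, length=3)],
            [PlaceholderRange(offset=0, length=3)],
        ],
    )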