[Renderer] Move InputPreprocessor into Renderer (2/2) (#34560)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -195,18 +195,15 @@ def test_chat_batch_failure_cleanup(llm_for_failure_test):
|
||||
valid_msg = [{"role": "user", "content": "Hello"}]
|
||||
long_text = "This is a very long text to test the error " * 50
|
||||
invalid_msg = [{"role": "user", "content": long_text}]
|
||||
batch_1 = [
|
||||
valid_msg,
|
||||
valid_msg,
|
||||
invalid_msg,
|
||||
]
|
||||
batch_2 = [
|
||||
valid_msg,
|
||||
valid_msg,
|
||||
]
|
||||
|
||||
batch_1 = [valid_msg, valid_msg, invalid_msg]
|
||||
batch_2 = [valid_msg, valid_msg]
|
||||
sampling_params = SamplingParams(temperature=0, max_tokens=10)
|
||||
|
||||
with pytest.raises(ValueError, match="context length is only"):
|
||||
llm.chat(batch_1, sampling_params=sampling_params)
|
||||
assert llm.llm_engine.get_num_unfinished_requests() == 0
|
||||
|
||||
outputs_2 = llm.chat(batch_2, sampling_params=sampling_params)
|
||||
assert len(outputs_2) == len(batch_2)
|
||||
assert llm.llm_engine.get_num_unfinished_requests() == 0
|
||||
|
||||
@@ -489,8 +489,9 @@ def _assert_inputs_equal(
|
||||
if ignore_mm_keys is None:
|
||||
ignore_mm_keys = set()
|
||||
|
||||
a_rest = {k: v for k, v in a.items() if k != "mm_kwargs"}
|
||||
b_rest = {k: v for k, v in b.items() if k != "mm_kwargs"}
|
||||
ignore_prompt_keys = ("prompt", "mm_kwargs")
|
||||
a_rest = {k: v for k, v in a.items() if k not in ignore_prompt_keys}
|
||||
b_rest = {k: v for k, v in b.items() if k not in ignore_prompt_keys}
|
||||
|
||||
assert a_rest == b_rest, msg
|
||||
|
||||
|
||||
165
tests/renderers/test_process_multi_modal_uuids.py
Normal file
165
tests/renderers/test_process_multi_modal_uuids.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.renderers.hf import HfRenderer
|
||||
from vllm.tokenizers.registry import tokenizer_args_from_config
|
||||
|
||||
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
|
||||
stop_pil_image = ImageAsset("stop_sign").pil_image
|
||||
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
|
||||
|
||||
|
||||
def _build_renderer(
|
||||
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
|
||||
) -> HfRenderer:
|
||||
model_config = ModelConfig(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=128,
|
||||
mm_processor_cache_gb=mm_cache_gb,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
|
||||
)
|
||||
|
||||
_, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
|
||||
|
||||
return HfRenderer.from_config(
|
||||
vllm_config,
|
||||
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
|
||||
)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_length_mismatch_raises():
|
||||
renderer = _build_renderer()
|
||||
|
||||
mm_data = {"image": [cherry_pil_image, stop_pil_image]}
|
||||
|
||||
# Mismatch: 2 items but only 1 uuid provided
|
||||
mm_uuids = {"image": ["hash_cherry"]}
|
||||
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
|
||||
with pytest.raises(ValueError, match="must have same length as"):
|
||||
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-1")
|
||||
|
||||
|
||||
def test_multi_modal_uuids_missing_modality_raises():
|
||||
renderer = _build_renderer()
|
||||
|
||||
mm_data = {
|
||||
"image": [cherry_pil_image],
|
||||
"video": None,
|
||||
}
|
||||
|
||||
# Only image uuids provided; video missing should raise
|
||||
mm_uuids = {"image": ["hash_cherry"]}
|
||||
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
|
||||
with pytest.raises(ValueError, match="is empty but .* is missing"):
|
||||
renderer._process_mm_uuids(mm_data, mm_items, mm_uuids, "req-2")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_cache_gb, enable_prefix_caching",
|
||||
[
|
||||
(4.0, True), # default behavior
|
||||
(4.0, False), # prefix caching disabled
|
||||
(0.0, True), # processor cache disabled
|
||||
],
|
||||
)
|
||||
def test_multi_modal_uuids_accepts_none_and_passes_through(
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
|
||||
):
|
||||
renderer = _build_renderer(
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
|
||||
mm_data = {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": baby_reading_np_ndarrays,
|
||||
}
|
||||
|
||||
# Use a consistent two-image scenario across all configurations
|
||||
mm_uuids = {"image": [None, "hash_stop"], "video": None}
|
||||
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
processed_mm_uuids = renderer._process_mm_uuids(
|
||||
mm_data, mm_items, mm_uuids, "req-3"
|
||||
)
|
||||
|
||||
assert processed_mm_uuids == mm_uuids
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_cache_gb, enable_prefix_caching",
|
||||
[
|
||||
(4.0, True), # default behavior
|
||||
(4.0, False), # prefix caching disabled
|
||||
(0.0, True), # processor cache disabled
|
||||
],
|
||||
)
|
||||
def test_multi_modal_uuids_accepts_empty(
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
|
||||
):
|
||||
renderer = _build_renderer(
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
|
||||
# While None means cached multi-modal input requiring UUIDs
|
||||
# an empty list means no multi-modal input
|
||||
mm_data = {"image": [], "video": []} # type: ignore[var-annotated]
|
||||
mm_uuids = {"image": [], "video": None} # type: ignore[var-annotated]
|
||||
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
processed_mm_uuids = renderer._process_mm_uuids(
|
||||
mm_data, mm_items, mm_uuids, "req-4"
|
||||
)
|
||||
|
||||
assert processed_mm_uuids == mm_uuids
|
||||
|
||||
|
||||
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
# When both processor cache is 0 and prefix caching disabled, the
|
||||
# processor builds overrides from request id instead of using user UUIDs.
|
||||
renderer = _build_renderer(mm_cache_gb=0.0, enable_prefix_caching=False)
|
||||
|
||||
request_id = "req-42"
|
||||
mm_data = {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": baby_reading_np_ndarrays,
|
||||
}
|
||||
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
|
||||
|
||||
mm_processor = renderer.get_mm_processor()
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
processed_mm_uuids = renderer._process_mm_uuids(
|
||||
mm_data, mm_items, mm_uuids, request_id
|
||||
)
|
||||
|
||||
# Expect request-id-based overrides are passed through
|
||||
assert set(mm_uuids.keys()) == {"image", "video"}
|
||||
assert len(mm_uuids["image"]) == 2
|
||||
assert len(mm_uuids["video"]) == 1
|
||||
assert processed_mm_uuids["image"][0].startswith(
|
||||
f"{request_id}-image-"
|
||||
) and processed_mm_uuids["image"][0].endswith("-0")
|
||||
assert processed_mm_uuids["image"][1].startswith(
|
||||
f"{request_id}-image-"
|
||||
) and processed_mm_uuids["image"][1].endswith("-1")
|
||||
assert processed_mm_uuids["video"][0].startswith(
|
||||
f"{request_id}-video-"
|
||||
) and processed_mm_uuids["video"][0].endswith("-0")
|
||||
@@ -20,7 +20,6 @@ MM_BEAM_WIDTHS = [2]
|
||||
MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"]
|
||||
|
||||
|
||||
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
@@ -62,7 +61,6 @@ def test_beam_search_single_input(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_v1 # V1 engine does not yet support beam search
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
|
||||
|
||||
@@ -1,174 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.assets.video import VideoAsset
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.multimodal import MultiModalUUIDDict
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine.input_processor import InputProcessor
|
||||
|
||||
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
|
||||
stop_pil_image = ImageAsset("stop_sign").pil_image
|
||||
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
|
||||
|
||||
|
||||
def _build_input_processor(
|
||||
*, mm_cache_gb: float = 4.0, enable_prefix_caching: bool = True
|
||||
) -> InputProcessor:
|
||||
model_config = ModelConfig(
|
||||
model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
max_model_len=128,
|
||||
mm_processor_cache_gb=mm_cache_gb,
|
||||
)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
|
||||
)
|
||||
|
||||
return InputProcessor(vllm_config)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_length_mismatch_raises():
|
||||
input_processor = _build_input_processor()
|
||||
|
||||
prompt = {
|
||||
"prompt": "USER: <image>\nDescribe\nASSISTANT:",
|
||||
"multi_modal_data": {"image": [cherry_pil_image, stop_pil_image]},
|
||||
# Mismatch: 2 items but only 1 uuid provided
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="must have same length as"):
|
||||
input_processor.process_inputs(
|
||||
request_id="req-1",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
|
||||
def test_multi_modal_uuids_missing_modality_raises():
|
||||
input_processor = _build_input_processor()
|
||||
|
||||
prompt = {
|
||||
"prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
|
||||
# Two modalities provided in data
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image],
|
||||
"video": None,
|
||||
},
|
||||
# Only image uuids provided; video missing should raise
|
||||
"multi_modal_uuids": {"image": ["hash_cherry"]},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="is empty but .* is missing"):
|
||||
input_processor.process_inputs(
|
||||
request_id="req-2",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mm_cache_gb, enable_prefix_caching",
|
||||
[
|
||||
(4.0, True), # default behavior
|
||||
(4.0, False), # prefix caching disabled
|
||||
(0.0, True), # processor cache disabled
|
||||
],
|
||||
)
|
||||
def test_multi_modal_uuids_accepts_none_and_passes_through(
|
||||
monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool
|
||||
):
|
||||
input_processor = _build_input_processor(
|
||||
mm_cache_gb=mm_cache_gb,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
)
|
||||
|
||||
# Capture the overrides passed to InputPreprocessor.preprocess
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
# Minimal processed inputs for decoder-only flow
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
# Monkeypatch only the bound preprocess method on this instance
|
||||
monkeypatch.setattr(
|
||||
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
# Use a consistent two-image scenario across all configurations
|
||||
mm_uuids = {"image": [None, "hash_stop"], "video": None}
|
||||
prompt = {
|
||||
"prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": baby_reading_np_ndarrays,
|
||||
},
|
||||
"multi_modal_uuids": mm_uuids,
|
||||
}
|
||||
|
||||
input_processor.process_inputs(
|
||||
request_id="req-3",
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
assert captured["mm_uuids"] == mm_uuids
|
||||
|
||||
|
||||
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
|
||||
# When both processor cache is 0 and prefix caching disabled, the
|
||||
# processor builds overrides from request id instead of using user UUIDs.
|
||||
input_processor = _build_input_processor(
|
||||
mm_cache_gb=0.0, enable_prefix_caching=False
|
||||
)
|
||||
|
||||
captured: dict[str, MultiModalUUIDDict] = {}
|
||||
|
||||
def fake_preprocess(
|
||||
prompt, *, tokenization_kwargs=None, lora_request=None, mm_uuids=None
|
||||
):
|
||||
captured["mm_uuids"] = mm_uuids
|
||||
return {"type": "token", "prompt_token_ids": [1]}
|
||||
|
||||
monkeypatch.setattr(
|
||||
input_processor.input_preprocessor, "preprocess", fake_preprocess, raising=True
|
||||
)
|
||||
|
||||
request_id = "req-42"
|
||||
mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": ["hash_video"]}
|
||||
prompt = {
|
||||
"prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
|
||||
"multi_modal_data": {
|
||||
"image": [cherry_pil_image, stop_pil_image],
|
||||
"video": [baby_reading_np_ndarrays],
|
||||
},
|
||||
"multi_modal_uuids": mm_uuids,
|
||||
}
|
||||
|
||||
input_processor.process_inputs(
|
||||
request_id=request_id,
|
||||
prompt=prompt, # type: ignore[arg-type]
|
||||
params=SamplingParams(),
|
||||
)
|
||||
|
||||
# Expect request-id-based overrides are passed through
|
||||
assert set(mm_uuids.keys()) == {"image", "video"}
|
||||
assert len(mm_uuids["image"]) == 2
|
||||
assert len(mm_uuids["video"]) == 1
|
||||
assert captured["mm_uuids"]["image"][0].startswith(
|
||||
f"{request_id}-image-"
|
||||
) and captured["mm_uuids"]["image"][0].endswith("-0")
|
||||
assert captured["mm_uuids"]["image"][1].startswith(
|
||||
f"{request_id}-image-"
|
||||
) and captured["mm_uuids"]["image"][1].endswith("-1")
|
||||
assert captured["mm_uuids"]["video"][0].startswith(
|
||||
f"{request_id}-video-"
|
||||
) and captured["mm_uuids"]["video"][0].endswith("-0")
|
||||
@@ -2,13 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.inputs import TokenInputs, token_inputs
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.inputs import MultiModalInputs, mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -19,6 +17,8 @@ class BeamSearchSequence:
|
||||
about to be returned to the user.
|
||||
"""
|
||||
|
||||
orig_prompt: TokenInputs | MultiModalInputs
|
||||
|
||||
# The tokens include the prompt.
|
||||
tokens: list[int]
|
||||
logprobs: list[dict[int, Logprob]]
|
||||
@@ -27,8 +27,28 @@ class BeamSearchSequence:
|
||||
text: str | None = None
|
||||
finish_reason: str | None = None
|
||||
stop_reason: int | str | None = None
|
||||
multi_modal_data: "MultiModalDataDict | None" = None
|
||||
mm_processor_kwargs: dict[str, Any] | None = None
|
||||
|
||||
def get_prompt(self):
|
||||
prompt = self.orig_prompt
|
||||
|
||||
prompt_text = prompt.get("prompt")
|
||||
cache_salt = prompt.get("cache_salt")
|
||||
|
||||
if prompt["type"] == "token":
|
||||
return token_inputs(
|
||||
self.tokens,
|
||||
prompt=prompt_text,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
return mm_inputs(
|
||||
prompt_token_ids=self.tokens,
|
||||
mm_kwargs=prompt["mm_kwargs"],
|
||||
mm_hashes=prompt["mm_hashes"],
|
||||
mm_placeholders=prompt["mm_placeholders"],
|
||||
prompt=prompt_text,
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -44,14 +64,15 @@ class BeamSearchOutput:
|
||||
class BeamSearchInstance:
|
||||
def __init__(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
prompt: TokenInputs | MultiModalInputs,
|
||||
lora_request: LoRARequest | None = None,
|
||||
logprobs: list[dict[int, Logprob]] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.beams: list[BeamSearchSequence] = [
|
||||
BeamSearchSequence(
|
||||
tokens=prompt_tokens,
|
||||
orig_prompt=prompt,
|
||||
tokens=prompt["prompt_token_ids"],
|
||||
logprobs=[] if logprobs is None else list(logprobs),
|
||||
lora_request=lora_request,
|
||||
**kwargs,
|
||||
|
||||
@@ -11,13 +11,12 @@ from vllm.distributed.weight_transfer.base import (
|
||||
WeightTransferInitRequest,
|
||||
WeightTransferUpdateRequest,
|
||||
)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import ProcessorInputs, PromptType
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
@@ -35,7 +34,7 @@ class StreamingInput:
|
||||
where inputs are provided via an async generator.
|
||||
"""
|
||||
|
||||
prompt: PromptType
|
||||
prompt: ProcessorInputs
|
||||
sampling_params: SamplingParams | None = None
|
||||
|
||||
|
||||
@@ -69,8 +68,7 @@ class EngineClient(ABC):
|
||||
self,
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| ProcessorInputs
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
@@ -81,6 +79,7 @@ class EngineClient(ABC):
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
reasoning_ended: bool | None = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""Generate outputs for a request."""
|
||||
...
|
||||
@@ -88,13 +87,14 @@ class EngineClient(ABC):
|
||||
@abstractmethod
|
||||
def encode(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
prompt: PromptType | ProcessorInputs,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
reasoning_ended: bool | None = None,
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""Generate outputs for a request from a pooling model."""
|
||||
...
|
||||
|
||||
@@ -3,8 +3,8 @@
|
||||
|
||||
import itertools
|
||||
import warnings
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import cloudpickle
|
||||
import torch.nn as nn
|
||||
@@ -55,6 +55,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
from vllm.entrypoints.utils import log_non_default_args
|
||||
from vllm.inputs.data import (
|
||||
DataPrompt,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
@@ -73,10 +74,8 @@ from vllm.outputs import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, merge_kwargs
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
conversation_to_seq,
|
||||
extract_prompt_components,
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
@@ -86,6 +85,7 @@ from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.counter import Counter
|
||||
from vllm.utils.tqdm_utils import maybe_tqdm
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
@@ -400,7 +400,7 @@ class LLM:
|
||||
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[RequestOutput]:
|
||||
@@ -462,7 +462,7 @@ class LLM:
|
||||
self,
|
||||
prompts: PromptType | Sequence[PromptType],
|
||||
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
@@ -495,34 +495,32 @@ class LLM:
|
||||
# Use the same preprocessing as _run_completion
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in seq_params):
|
||||
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
|
||||
engine_prompt
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_cmpl(
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_cmpl(
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
seq_priority = self._priority_to_seq(priority, len(prompts))
|
||||
|
||||
request_ids = self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
params=seq_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
request_ids = self._render_and_add_requests(
|
||||
prompts=(
|
||||
self._preprocess_cmpl_one(prompt, tok_kwargs)
|
||||
for prompt, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering prompts",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
lora_requests=seq_lora_requests,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
priorities=seq_priority,
|
||||
)
|
||||
|
||||
return request_ids
|
||||
@@ -545,53 +543,41 @@ class LLM:
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||
|
||||
def _get_modality_specific_lora_reqs(
|
||||
def _resolve_lora_reqs(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
prompts: Sequence[ProcessorInputs],
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
):
|
||||
# Grab the lora config off the vllm config on the engine,
|
||||
# since this is the same for both v0 & v1.
|
||||
lora_config = self.llm_engine.vllm_config.lora_config
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))
|
||||
|
||||
# If there's no lora config / default_mm_loras, or the model
|
||||
# isn't multimodal, leave the lora as is.
|
||||
if (
|
||||
lora_config is None
|
||||
or not self.model_config.is_multimodal_model
|
||||
or (lora_config and lora_config.default_mm_loras is None)
|
||||
):
|
||||
return lora_request
|
||||
|
||||
optional_loras = (
|
||||
[lora_request] * len(prompts)
|
||||
if not isinstance(lora_request, Sequence)
|
||||
else lora_request
|
||||
)
|
||||
return seq_lora_requests
|
||||
|
||||
return [
|
||||
self._resolve_single_prompt_mm_lora(
|
||||
prompt,
|
||||
opt_lora_req,
|
||||
lora_req,
|
||||
lora_config.default_mm_loras,
|
||||
)
|
||||
for prompt, opt_lora_req in zip(prompts, optional_loras)
|
||||
for prompt, lora_req in zip(prompts, seq_lora_requests)
|
||||
]
|
||||
|
||||
def _resolve_single_prompt_mm_lora(
|
||||
self,
|
||||
prompt: DictPrompt | TokPrompt,
|
||||
prompt: ProcessorInputs,
|
||||
lora_request: LoRARequest | None,
|
||||
default_mm_loras: dict[str, str] | None,
|
||||
):
|
||||
if not default_mm_loras or not (
|
||||
mm_data := prompt.get("multi_modal_data") or {}
|
||||
):
|
||||
if not default_mm_loras or prompt["type"] != "multimodal":
|
||||
return lora_request
|
||||
|
||||
intersection = set(
|
||||
mm_data.keys() # type: ignore
|
||||
).intersection(default_mm_loras.keys())
|
||||
prompt_modalities = prompt["mm_placeholders"].keys()
|
||||
intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
|
||||
if not intersection:
|
||||
return lora_request
|
||||
if len(intersection) > 1:
|
||||
@@ -674,22 +660,6 @@ class LLM:
|
||||
"""
|
||||
return self.llm_engine.apply_model(func)
|
||||
|
||||
def _get_beam_search_lora_requests(
|
||||
self,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None,
|
||||
prompts: list[TokensPrompt | TextPrompt],
|
||||
) -> list[LoRARequest | None]:
|
||||
"""Get the optional lora request corresponding to each prompt."""
|
||||
if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
|
||||
raise ValueError(
|
||||
"Lora request list should be the same length as the prompts"
|
||||
)
|
||||
|
||||
if lora_request is None or isinstance(lora_request, LoRARequest):
|
||||
return [lora_request] * len(prompts)
|
||||
|
||||
raise TypeError(f"Invalid lora_request type {type(lora_request)}")
|
||||
|
||||
def beam_search(
|
||||
self,
|
||||
prompts: list[TokensPrompt | TextPrompt],
|
||||
@@ -718,13 +688,12 @@ class LLM:
|
||||
ignore_eos = params.ignore_eos
|
||||
length_penalty = params.length_penalty
|
||||
|
||||
lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
sort_beams_key = create_sort_beams_key_function(
|
||||
tokenizer.eos_token_id,
|
||||
length_penalty,
|
||||
)
|
||||
engine_prompts = self._preprocess_cmpl(prompts)
|
||||
lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))
|
||||
|
||||
if use_tqdm and concurrency_limit is not None:
|
||||
logger.warning(
|
||||
@@ -734,21 +703,12 @@ class LLM:
|
||||
use_tqdm = False
|
||||
|
||||
if concurrency_limit is None:
|
||||
concurrency_limit = len(prompts)
|
||||
|
||||
def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
|
||||
token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
|
||||
if beam.multi_modal_data is not None:
|
||||
token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data
|
||||
|
||||
if beam.mm_processor_kwargs is not None:
|
||||
token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
|
||||
return TokensPrompt(**token_prompt_kwargs)
|
||||
concurrency_limit = len(engine_prompts)
|
||||
|
||||
# generate 2 * beam_width candidates at each step
|
||||
# following the huggingface transformers implementation
|
||||
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
|
||||
beam_search_params = SamplingParams(
|
||||
sampling_params = SamplingParams(
|
||||
logprobs=2 * beam_width,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
@@ -756,30 +716,25 @@ class LLM:
|
||||
)
|
||||
instances: list[BeamSearchInstance] = []
|
||||
|
||||
for lora_req, prompt in zip(lora_requests, prompts):
|
||||
# Add multimodal processor kwargs & data
|
||||
mm_kwargs = {}
|
||||
if "multi_modal_data" in prompt:
|
||||
mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
|
||||
if "mm_processor_kwargs" in prompt:
|
||||
mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
|
||||
|
||||
if "prompt_token_ids" in prompt:
|
||||
prompt = cast(TokensPrompt, prompt) # Needed for mypy
|
||||
prompt_tokens = prompt["prompt_token_ids"]
|
||||
else:
|
||||
prompt_tokens = tokenizer.encode(prompt["prompt"])
|
||||
for lora_req, prompt in zip(lora_requests, engine_prompts):
|
||||
if prompt["type"] == "embeds":
|
||||
raise NotImplementedError(
|
||||
"Embedding prompt not supported for beam search"
|
||||
)
|
||||
if prompt["type"] == "enc_dec":
|
||||
raise NotImplementedError(
|
||||
"Encoder-decoder prompt not supported for beam search"
|
||||
)
|
||||
|
||||
instances.append(
|
||||
BeamSearchInstance(
|
||||
prompt_tokens,
|
||||
prompt,
|
||||
lora_request=lora_req,
|
||||
logprobs=None,
|
||||
**mm_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
for prompt_start in range(0, len(prompts), concurrency_limit):
|
||||
for prompt_start in range(0, len(instances), concurrency_limit):
|
||||
instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
|
||||
|
||||
token_iter = range(max_tokens)
|
||||
@@ -808,22 +763,15 @@ class LLM:
|
||||
if len(all_beams) == 0:
|
||||
break
|
||||
|
||||
# create corresponding batch entries for prompt & optional lora
|
||||
prompts_batch, lora_req_batch = zip(
|
||||
*[
|
||||
(create_tokens_prompt_from_beam(beam), beam.lora_request)
|
||||
for beam in all_beams
|
||||
]
|
||||
)
|
||||
|
||||
# only runs for one step
|
||||
# we don't need to use tqdm here
|
||||
output = self.generate(
|
||||
prompts_batch,
|
||||
sampling_params=beam_search_params,
|
||||
raw_output = self._render_and_run_requests(
|
||||
prompts=(beam.get_prompt() for beam in all_beams),
|
||||
params=self._params_to_seq(sampling_params, len(all_beams)),
|
||||
lora_requests=[beam.lora_request for beam in all_beams],
|
||||
use_tqdm=False,
|
||||
lora_request=lora_req_batch,
|
||||
)
|
||||
output = self.engine_class.validate_outputs(raw_output, RequestOutput)
|
||||
|
||||
for (start, end), instance in zip(
|
||||
instance_start_and_end, instances_batch
|
||||
@@ -841,19 +789,15 @@ class LLM:
|
||||
logprobs = result.outputs[0].logprobs[0]
|
||||
for token_id, logprob_obj in logprobs.items():
|
||||
new_beam = BeamSearchSequence(
|
||||
current_beam.orig_prompt,
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=current_beam.cum_logprob
|
||||
+ logprob_obj.logprob,
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if (
|
||||
token_id == tokenizer.eos_token_id
|
||||
and not ignore_eos
|
||||
):
|
||||
if token_id == eos_token_id and not ignore_eos:
|
||||
instance.completed.append(new_beam)
|
||||
else:
|
||||
instance_new_beams.append(new_beam)
|
||||
@@ -872,6 +816,7 @@ class LLM:
|
||||
|
||||
for beam in best_beams:
|
||||
beam.text = tokenizer.decode(beam.tokens)
|
||||
|
||||
outputs.append(BeamSearchOutput(sequences=best_beams))
|
||||
|
||||
return outputs
|
||||
@@ -880,7 +825,7 @@ class LLM:
|
||||
self,
|
||||
prompts: Sequence[PromptType],
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[DictPrompt | TokPrompt]:
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
"""
|
||||
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
|
||||
a format that can be passed to `_add_request`.
|
||||
@@ -888,8 +833,7 @@ class LLM:
|
||||
Refer to [LLM.generate][] for a complete description of the arguments.
|
||||
|
||||
Returns:
|
||||
A list of `TokPrompt` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
|
||||
"""
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
@@ -903,6 +847,14 @@ class LLM:
|
||||
|
||||
return renderer.render_cmpl(parsed_prompts, tok_params)
|
||||
|
||||
def _preprocess_cmpl_one(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
(engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
|
||||
return engine_prompt
|
||||
|
||||
def _preprocess_chat(
|
||||
self,
|
||||
conversations: Sequence[list[ChatCompletionMessageParam]],
|
||||
@@ -914,7 +866,7 @@ class LLM:
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> Sequence[TokPrompt]:
|
||||
) -> Sequence[ProcessorInputs]:
|
||||
"""
|
||||
Convert a list of conversations into prompts so that they can then
|
||||
be used as input for other LLM APIs.
|
||||
@@ -922,8 +874,7 @@ class LLM:
|
||||
Refer to [LLM.chat][] for a complete description of the arguments.
|
||||
|
||||
Returns:
|
||||
A list of `TokPrompt` objects containing the tokenized prompt
|
||||
after chat template interpolation, and the raw multi-modal inputs.
|
||||
A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
|
||||
"""
|
||||
renderer = self.renderer
|
||||
|
||||
@@ -953,13 +904,39 @@ class LLM:
|
||||
|
||||
return engine_prompts
|
||||
|
||||
def _preprocess_chat_one(
|
||||
self,
|
||||
conversation: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
add_generation_prompt: bool = True,
|
||||
continue_final_message: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
(engine_prompt,) = self._preprocess_chat(
|
||||
[conversation],
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
return engine_prompt
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
| Sequence[list[ChatCompletionMessageParam]],
|
||||
sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: LoRARequest | None = None,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
add_generation_prompt: bool = True,
|
||||
@@ -1805,47 +1782,41 @@ class LLM:
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: list[LoRARequest] | LoRARequest | None = None,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||
priority: list[int] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
seq_prompts = prompt_to_seq(prompts)
|
||||
seq_params = self._params_to_seq(params, len(seq_prompts))
|
||||
|
||||
if any(param.truncate_prompt_tokens is not None for param in seq_params):
|
||||
# TODO: Remove this after deprecating `param.truncate_prompt_tokens`
|
||||
# Then, move the code from the `else` block to the top and let
|
||||
# `self._preprocess_cmpl` handle prompt normalization
|
||||
engine_prompts: Sequence[DictPrompt | TokPrompt] = [
|
||||
engine_prompt
|
||||
for prompt, param in zip(seq_prompts, seq_params)
|
||||
for engine_prompt in self._preprocess_cmpl(
|
||||
[prompt],
|
||||
tokenization_kwargs=merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
),
|
||||
)
|
||||
]
|
||||
else:
|
||||
engine_prompts = self._preprocess_cmpl(
|
||||
seq_prompts,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
seq_priority = self._priority_to_seq(priority, len(prompts))
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
return self._render_and_run_requests(
|
||||
prompts=(
|
||||
self._preprocess_cmpl_one(prompt, tok_kwargs)
|
||||
for prompt, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_prompts,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering prompts",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
),
|
||||
lora_requests=seq_lora_requests,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
priorities=seq_priority,
|
||||
)
|
||||
|
||||
return self._run_engine(use_tqdm=use_tqdm)
|
||||
|
||||
def _run_chat(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam]
|
||||
@@ -1855,7 +1826,7 @@ class LLM:
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: LoRARequest | None = None,
|
||||
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
|
||||
chat_template: str | None = None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
|
||||
add_generation_prompt: bool = True,
|
||||
@@ -1865,68 +1836,94 @@ class LLM:
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
):
|
||||
engine_prompts = self._preprocess_chat(
|
||||
conversation_to_seq(messages),
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
seq_convs = conversation_to_seq(messages)
|
||||
seq_params = self._params_to_seq(params, len(seq_convs))
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
|
||||
seq_tok_kwargs = [
|
||||
merge_kwargs(
|
||||
tokenization_kwargs,
|
||||
dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
|
||||
)
|
||||
for param in seq_params
|
||||
]
|
||||
|
||||
return self._render_and_run_requests(
|
||||
prompts=(
|
||||
self._preprocess_chat_one(
|
||||
conversation,
|
||||
chat_template=chat_template,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
chat_template_kwargs=chat_template_kwargs,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
tokenization_kwargs=tok_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
for conversation, tok_kwargs in zip(
|
||||
maybe_tqdm(
|
||||
seq_convs,
|
||||
use_tqdm=use_tqdm,
|
||||
desc="Rendering conversations",
|
||||
),
|
||||
seq_tok_kwargs,
|
||||
)
|
||||
),
|
||||
params=seq_params,
|
||||
lora_requests=seq_lora_requests,
|
||||
use_tqdm=use_tqdm,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=engine_prompts,
|
||||
def _render_and_run_requests(
|
||||
self,
|
||||
prompts: Iterable[ProcessorInputs],
|
||||
params: Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
lora_requests: Sequence[LoRARequest | None] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priorities: Sequence[int] | None = None,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
):
|
||||
if isinstance(prompts, (list, tuple)):
|
||||
logger.warning_once(
|
||||
"Rendering all prompts before adding them to the engine "
|
||||
"is less efficient than performing both on the same prompt "
|
||||
"before processing the next prompt. You should instead pass "
|
||||
"a generator that renders one prompt per iteration, as that allows "
|
||||
"engine execution to begin for the first prompt while processing "
|
||||
"the next prompt."
|
||||
)
|
||||
|
||||
self._render_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=params,
|
||||
use_tqdm=use_tqdm,
|
||||
lora_request=self._get_modality_specific_lora_reqs(
|
||||
engine_prompts, lora_request
|
||||
),
|
||||
lora_requests=lora_requests,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priorities=priorities,
|
||||
)
|
||||
|
||||
return self._run_engine(use_tqdm=use_tqdm)
|
||||
|
||||
def _validate_and_add_requests(
|
||||
def _render_and_add_requests(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
params: SamplingParams
|
||||
| PoolingParams
|
||||
| Sequence[SamplingParams | PoolingParams],
|
||||
prompts: Iterable[ProcessorInputs],
|
||||
params: Sequence[SamplingParams | PoolingParams],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm] = True,
|
||||
lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
|
||||
lora_requests: Sequence[LoRARequest | None] | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: list[int] | None = None,
|
||||
priorities: Sequence[int] | None = None,
|
||||
) -> list[str]:
|
||||
num_requests = len(prompts)
|
||||
seq_params = self._params_to_seq(params, num_requests)
|
||||
seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
|
||||
seq_priority = self._priority_to_seq(priority, num_requests)
|
||||
|
||||
for sp in seq_params:
|
||||
if isinstance(sp, SamplingParams):
|
||||
# We only care about the final output
|
||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
# Add requests to the engine.
|
||||
it = prompts
|
||||
if use_tqdm:
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
it = tqdm_func(it, desc="Adding requests")
|
||||
|
||||
added_request_ids: list[str] = []
|
||||
|
||||
try:
|
||||
for i, prompt in enumerate(it):
|
||||
for i, prompt in enumerate(prompts):
|
||||
request_id = self._add_request(
|
||||
prompt,
|
||||
seq_params[i],
|
||||
lora_request=seq_lora_requests[i],
|
||||
params[i],
|
||||
lora_request=None if lora_requests is None else lora_requests[i],
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=seq_priority[i],
|
||||
priority=0 if priorities is None else priorities[i],
|
||||
)
|
||||
added_request_ids.append(request_id)
|
||||
except Exception as e:
|
||||
@@ -1938,13 +1935,16 @@ class LLM:
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
prompt: ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
priority: int = 0,
|
||||
) -> str:
|
||||
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
|
||||
if isinstance(params, SamplingParams):
|
||||
# We only care about the final output
|
||||
params.output_kind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
request_id = str(next(self.request_counter))
|
||||
|
||||
if params.truncate_prompt_tokens is not None:
|
||||
@@ -1962,33 +1962,15 @@ class LLM:
|
||||
dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
|
||||
)
|
||||
|
||||
renderer = self.renderer
|
||||
tok_params = renderer.default_cmpl_tok_params.with_kwargs(
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
return self.llm_engine.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
supported_tasks=self.supported_tasks,
|
||||
)
|
||||
|
||||
self.llm_engine.add_request(
|
||||
request_id,
|
||||
engine_request,
|
||||
params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
priority=priority,
|
||||
prompt_text=prompt_text,
|
||||
)
|
||||
return engine_request.request_id
|
||||
|
||||
def _run_engine(
|
||||
self,
|
||||
*,
|
||||
|
||||
@@ -67,13 +67,12 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
)
|
||||
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import ProcessorInputs, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.mistral import (
|
||||
@@ -221,7 +220,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
async def render_chat_request(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]] | ErrorResponse:
|
||||
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
|
||||
"""
|
||||
render chat request by validating and preprocessing inputs.
|
||||
|
||||
@@ -380,7 +379,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
prompt_token_ids = self._extract_prompt_components(
|
||||
engine_prompt
|
||||
).token_ids
|
||||
|
||||
# If we are creating sub requests for multiple prompts, ensure that they
|
||||
# have unique request ids.
|
||||
@@ -431,35 +432,21 @@ class OpenAIServingChat(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
reasoning_ended = (
|
||||
reasoning_parser.is_reasoning_end(prompt_token_ids or [])
|
||||
if reasoning_parser
|
||||
else None
|
||||
)
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
sub_request_id,
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
reasoning_ended = None
|
||||
if reasoning_parser:
|
||||
reasoning_ended = reasoning_parser.is_reasoning_end(
|
||||
engine_request.prompt_token_ids or [] # type: ignore[attr-defined]
|
||||
)
|
||||
engine_request.reasoning_ended = reasoning_ended
|
||||
generator = self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
sub_request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
reasoning_ended=reasoning_ended,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
||||
@@ -34,10 +34,10 @@ from vllm.entrypoints.openai.engine.serving import (
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
@@ -80,7 +80,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
async def render_completion_request(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
) -> list[TokPrompt] | ErrorResponse:
|
||||
) -> list[ProcessorInputs] | ErrorResponse:
|
||||
"""
|
||||
render completion request by validating and preprocessing inputs.
|
||||
|
||||
@@ -163,8 +163,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
request.max_tokens,
|
||||
@@ -208,29 +206,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
else:
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id_item,
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
|
||||
@@ -312,7 +294,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
engine_prompts: list[TokPrompt],
|
||||
engine_prompts: list[ProcessorInputs],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
|
||||
@@ -96,15 +96,19 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens, sanitize_message
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import PromptType, SingletonPrompt, TokensPrompt
|
||||
from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TokensPrompt,
|
||||
token_inputs,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob, PromptLogprobs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
extract_prompt_components,
|
||||
extract_prompt_len,
|
||||
@@ -206,7 +210,7 @@ class ServeContext(Generic[RequestT]):
|
||||
request_id: str
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
engine_prompts: list[TokPrompt] | None = None
|
||||
engine_prompts: list[ProcessorInputs] | None = None
|
||||
|
||||
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
|
||||
None
|
||||
@@ -249,7 +253,7 @@ class OpenAIServing:
|
||||
|
||||
async def beam_search(
|
||||
self,
|
||||
prompt: TokPrompt,
|
||||
prompt: ProcessorInputs,
|
||||
request_id: str,
|
||||
params: BeamSearchParams,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -262,86 +266,53 @@ class OpenAIServing:
|
||||
length_penalty = params.length_penalty
|
||||
include_stop_str_in_output = params.include_stop_str_in_output
|
||||
|
||||
input_processor = self.input_processor
|
||||
tokenizer = input_processor.tokenizer
|
||||
if tokenizer is None:
|
||||
raise VLLMValidationError(
|
||||
"You cannot use beam search when `skip_tokenizer_init=True`",
|
||||
parameter="skip_tokenizer_init",
|
||||
value=True,
|
||||
)
|
||||
|
||||
eos_token_id: int = tokenizer.eos_token_id # type: ignore
|
||||
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
raise NotImplementedError("Encoder-decoder prompt not supported")
|
||||
|
||||
prompt_text: str | None = prompt.get("prompt") # type: ignore
|
||||
prompt_token_ids: list[int] = prompt.get("prompt_token_ids", []) # type: ignore
|
||||
multi_modal_data: MultiModalDataDict | None = prompt.get("multi_modal_data") # type: ignore
|
||||
|
||||
mm_processor_kwargs: dict[str, Any] | None = None
|
||||
|
||||
# This is a workaround to fix multimodal beam search; this is a
|
||||
# bandaid fix for 2 small problems:
|
||||
# 1. Multi_modal_data on the processed_inputs currently resolves to
|
||||
# `None`.
|
||||
# 2. preprocessing above expands the multimodal placeholders. However,
|
||||
# this happens again in generation, so the double expansion causes
|
||||
# a mismatch.
|
||||
# TODO - would be ideal to handle this more gracefully.
|
||||
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
eos_token_id = tokenizer.eos_token_id
|
||||
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
|
||||
|
||||
if prompt["type"] == "embeds":
|
||||
raise NotImplementedError("Embedding prompt not supported for beam search")
|
||||
if prompt["type"] == "enc_dec":
|
||||
raise NotImplementedError(
|
||||
"Encoder-decoder prompt not supported for beam search"
|
||||
)
|
||||
|
||||
prompt_text = prompt.get("prompt")
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
tokenized_length = len(prompt_token_ids)
|
||||
|
||||
logprobs_num = 2 * beam_width
|
||||
beam_search_params = SamplingParams(
|
||||
sampling_params = SamplingParams(
|
||||
logprobs=logprobs_num,
|
||||
max_tokens=1,
|
||||
temperature=temperature,
|
||||
)
|
||||
all_beams = [
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=prompt_token_ids,
|
||||
cum_logprob=0,
|
||||
logprobs=[],
|
||||
multi_modal_data=multi_modal_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
]
|
||||
completed = []
|
||||
|
||||
for _ in range(max_tokens):
|
||||
prompts_batch, lora_req_batch = zip(
|
||||
*[
|
||||
(
|
||||
TokensPrompt(
|
||||
prompt_token_ids=beam.tokens,
|
||||
multi_modal_data=beam.multi_modal_data,
|
||||
mm_processor_kwargs=beam.mm_processor_kwargs,
|
||||
),
|
||||
beam.lora_request,
|
||||
)
|
||||
for beam in all_beams
|
||||
]
|
||||
)
|
||||
|
||||
tasks = []
|
||||
request_id_batch = f"{request_id}-{random_uuid()}"
|
||||
|
||||
for i, (individual_prompt, lora_req) in enumerate(
|
||||
zip(prompts_batch, lora_req_batch)
|
||||
):
|
||||
for i, beam in enumerate(all_beams):
|
||||
prompt_item = beam.get_prompt()
|
||||
lora_request_item = beam.lora_request
|
||||
request_id_item = f"{request_id_batch}-beam-{i}"
|
||||
task = asyncio.create_task(
|
||||
collect_from_async_generator(
|
||||
self.engine_client.generate(
|
||||
individual_prompt,
|
||||
beam_search_params,
|
||||
prompt_item,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_req,
|
||||
lora_request=lora_request_item,
|
||||
trace_headers=trace_headers,
|
||||
)
|
||||
)
|
||||
@@ -406,6 +377,7 @@ class OpenAIServing:
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
completed.append(
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=current_beam.tokens + [eos_token_id]
|
||||
if include_stop_str_in_output
|
||||
else current_beam.tokens,
|
||||
@@ -433,12 +405,11 @@ class OpenAIServing:
|
||||
logprobs_entry = result.outputs[0].logprobs[0]
|
||||
new_beams.append(
|
||||
BeamSearchSequence(
|
||||
orig_prompt=prompt,
|
||||
tokens=current_beam.tokens + [token_id],
|
||||
logprobs=current_beam.logprobs + [logprobs_entry],
|
||||
lora_request=current_beam.lora_request,
|
||||
cum_logprob=float(all_beams_logprob[idx]),
|
||||
multi_modal_data=current_beam.multi_modal_data,
|
||||
mm_processor_kwargs=current_beam.mm_processor_kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -958,7 +929,7 @@ class OpenAIServing:
|
||||
request: RendererRequest,
|
||||
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
|
||||
prompt_embeds: bytes | list[bytes] | None,
|
||||
) -> list[TokPrompt]:
|
||||
) -> list[ProcessorInputs]:
|
||||
prompts = list[SingletonPrompt | bytes]()
|
||||
if prompt_embeds is not None: # embeds take higher priority
|
||||
prompts.extend(prompt_to_seq(prompt_embeds))
|
||||
@@ -971,7 +942,7 @@ class OpenAIServing:
|
||||
self,
|
||||
request: RendererRequest,
|
||||
prompts: Sequence[PromptType | bytes],
|
||||
) -> list[TokPrompt]:
|
||||
) -> list[ProcessorInputs]:
|
||||
renderer = self.renderer
|
||||
model_config = self.model_config
|
||||
|
||||
@@ -1004,7 +975,7 @@ class OpenAIServing:
|
||||
default_template_kwargs: dict[str, Any] | None,
|
||||
tool_dicts: list[dict[str, Any]] | None = None,
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
) -> tuple[list[ConversationMessage], list[TokPrompt]]:
|
||||
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
renderer = self.renderer
|
||||
@@ -1052,13 +1023,13 @@ class OpenAIServing:
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
def _extract_prompt_components(self, prompt: object):
|
||||
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs):
|
||||
return extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
def _extract_prompt_text(self, prompt: object):
|
||||
def _extract_prompt_text(self, prompt: ProcessorInputs):
|
||||
return self._extract_prompt_components(prompt).text
|
||||
|
||||
def _extract_prompt_len(self, prompt: object):
|
||||
def _extract_prompt_len(self, prompt: ProcessorInputs):
|
||||
return extract_prompt_len(self.model_config, prompt)
|
||||
|
||||
async def _render_next_turn(
|
||||
@@ -1088,16 +1059,14 @@ class OpenAIServing:
|
||||
async def _generate_with_builtin_tools(
|
||||
self,
|
||||
request_id: str,
|
||||
engine_prompt: TokPrompt,
|
||||
engine_prompt: ProcessorInputs,
|
||||
sampling_params: SamplingParams,
|
||||
tok_params: TokenizeParams,
|
||||
context: ConversationContext,
|
||||
lora_request: LoRARequest | None = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
):
|
||||
max_model_len = self.model_config.max_model_len
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
orig_priority = priority
|
||||
sub_request = 0
|
||||
@@ -1112,26 +1081,13 @@ class OpenAIServing:
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
sub_request_id,
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
sub_request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
prompt_text=prompt_text,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
async for res in generator:
|
||||
@@ -1154,11 +1110,11 @@ class OpenAIServing:
|
||||
# Render the next prompt token ids and update sampling_params.
|
||||
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||
token_ids = context.render_for_completion()
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
|
||||
engine_prompt = token_inputs(token_ids)
|
||||
|
||||
sampling_params.max_tokens = max_model_len - len(token_ids)
|
||||
elif isinstance(context, ParsableContext):
|
||||
engine_prompts = await self._render_next_turn(
|
||||
(engine_prompt,) = await self._render_next_turn(
|
||||
context.request,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
@@ -1166,8 +1122,6 @@ class OpenAIServing:
|
||||
context.chat_template,
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
engine_prompt = engine_prompts[0]
|
||||
prompt_text = self._extract_prompt_text(engine_prompt)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
@@ -1184,7 +1138,7 @@ class OpenAIServing:
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptType | TokPrompt,
|
||||
inputs: PromptType | ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams | BeamSearchParams | None,
|
||||
lora_request: LoRARequest | None,
|
||||
) -> None:
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsRealtime
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -70,15 +71,20 @@ class OpenAIServingRealtime(OpenAIServing):
|
||||
Yields:
|
||||
StreamingInput objects containing audio prompts for the engine
|
||||
"""
|
||||
model_config = self.model_config
|
||||
renderer = self.renderer
|
||||
|
||||
# mypy is being stupid
|
||||
# TODO(Patrick) - fix this
|
||||
stream_input_iter = cast(
|
||||
AsyncGenerator[PromptType, None],
|
||||
self.model_cls.buffer_realtime_audio(
|
||||
audio_stream, input_stream, self.model_config
|
||||
audio_stream, input_stream, model_config
|
||||
),
|
||||
)
|
||||
|
||||
async for prompt in stream_input_iter:
|
||||
yield StreamingInput(prompt=prompt)
|
||||
parsed_prompt = parse_model_prompt(model_config, prompt)
|
||||
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt])
|
||||
|
||||
yield StreamingInput(prompt=engine_prompt)
|
||||
|
||||
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import AsyncExitStack
|
||||
from dataclasses import replace
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Final, Union
|
||||
|
||||
from openai.types.responses.response_function_tool_call_output_item import (
|
||||
ResponseFunctionToolCallOutputItem,
|
||||
@@ -304,7 +304,7 @@ class ParsableContext(ConversationContext):
|
||||
|
||||
self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format = chat_template_content_format
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
|
||||
self.input_messages: list[ResponseRawMessageAndToken] = []
|
||||
self.output_messages: list[ResponseRawMessageAndToken] = []
|
||||
|
||||
@@ -116,13 +116,12 @@ from vllm.entrypoints.openai.responses.utils import (
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import ProcessorInputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
@@ -298,7 +297,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
def _validate_generator_input(
|
||||
self,
|
||||
engine_prompt: TokPrompt,
|
||||
engine_prompt: ProcessorInputs,
|
||||
) -> ErrorResponse | None:
|
||||
"""Add validations to the input to the generator here."""
|
||||
prompt_len = self._extract_prompt_len(engine_prompt)
|
||||
@@ -458,7 +457,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params
|
||||
)
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
@@ -512,7 +510,6 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
request_id=request.request_id,
|
||||
engine_prompt=engine_prompt,
|
||||
sampling_params=sampling_params,
|
||||
tok_params=tok_params,
|
||||
context=context,
|
||||
lora_request=lora_request,
|
||||
priority=request.priority,
|
||||
@@ -647,7 +644,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
messages = self._construct_input_messages_with_harmony(request, prev_response)
|
||||
prompt_token_ids = render_for_completion(messages)
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
engine_prompt = token_inputs(prompt_token_ids)
|
||||
|
||||
# Add cache_salt if provided in the request
|
||||
if request.cache_salt is not None:
|
||||
|
||||
@@ -36,14 +36,15 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
|
||||
TranslationSegment,
|
||||
TranslationStreamResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import SupportsTranscription, supports_transcription
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt
|
||||
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
@@ -202,8 +203,6 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
return
|
||||
|
||||
try:
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
warmup_start = time.perf_counter()
|
||||
logger.info("Warming up multimodal input processor...")
|
||||
|
||||
@@ -221,21 +220,11 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
request_prompt="",
|
||||
to_language=None,
|
||||
)
|
||||
|
||||
# Create minimal sampling params
|
||||
dummy_params = SamplingParams(
|
||||
max_tokens=1,
|
||||
temperature=0.0,
|
||||
skip_clone=True, # Internal warmup, safe to skip clone
|
||||
)
|
||||
parsed_prompt = parse_model_prompt(self.model_config, dummy_prompt)
|
||||
|
||||
# Process the dummy input through the input processor
|
||||
# This will trigger all the multimodal processing initialization
|
||||
_ = self.input_processor.process_inputs(
|
||||
request_id="warmup",
|
||||
prompt=dummy_prompt,
|
||||
params=dummy_params,
|
||||
)
|
||||
_ = self.renderer.render_cmpl([parsed_prompt])
|
||||
|
||||
warmup_elapsed = time.perf_counter() - warmup_start
|
||||
logger.info("Input processor warmup completed in %.2fs", warmup_elapsed)
|
||||
@@ -257,7 +246,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
) -> tuple[list[PromptType], float]:
|
||||
) -> tuple[list[ProcessorInputs], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
# Skip to_language validation to avoid extra logging for Whisper.
|
||||
@@ -285,7 +274,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
and duration > self.asr_config.max_audio_clip_s
|
||||
)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
prompts = []
|
||||
parsed_prompts: list[DictPrompt] = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
# returns a valid PromptType.
|
||||
@@ -298,12 +287,19 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
request_prompt=request.prompt,
|
||||
to_language=to_language,
|
||||
)
|
||||
|
||||
parsed_prompt: DictPrompt
|
||||
if request.response_format == "verbose_json":
|
||||
prompt = self._preprocess_verbose_prompt(parse_enc_dec_prompt(prompt))
|
||||
parsed_prompt = parse_enc_dec_prompt(prompt)
|
||||
parsed_prompt = self._preprocess_verbose_prompt(parsed_prompt)
|
||||
else:
|
||||
parsed_prompt = parse_model_prompt(self.model_config, prompt)
|
||||
|
||||
prompts.append(prompt)
|
||||
parsed_prompts.append(parsed_prompt)
|
||||
|
||||
return prompts, duration
|
||||
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts)
|
||||
|
||||
return engine_prompts, duration
|
||||
|
||||
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
@@ -436,7 +432,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
engine_prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
)
|
||||
@@ -445,57 +441,54 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(e)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
max_model_len = self.model_config.max_model_len
|
||||
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
|
||||
try:
|
||||
# Unlike most decoder-only models, whisper generation length is not
|
||||
# constrained by the size of the input audio, which is mapped to a
|
||||
# fixed-size log-mel-spectogram. Still, allow for fewer tokens to be
|
||||
# generated by respecting the extra completion tokens arg.
|
||||
if request.max_completion_tokens is None:
|
||||
default_max_tokens = self.model_config.max_model_len
|
||||
else:
|
||||
default_max_tokens = min(
|
||||
self.model_config.max_model_len, request.max_completion_tokens
|
||||
)
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len,
|
||||
request.max_completion_tokens,
|
||||
0,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens, self.default_sampling_params
|
||||
max_tokens,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
if request.response_format == "verbose_json":
|
||||
sampling_params.logprobs = 1
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
# It will not display special tokens like <|startoftranscript|>
|
||||
request.prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
list_result_generator = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}_{i}"
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
prompt,
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
trace_headers = (
|
||||
None
|
||||
if raw_request is None
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=0,
|
||||
)
|
||||
list_result_generator.append(
|
||||
self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
)
|
||||
|
||||
list_result_generator.append(generator)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(e)
|
||||
|
||||
|
||||
@@ -28,11 +28,10 @@ from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_output_base64,
|
||||
encode_pooling_output_float,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.collection_utils import chunk_list
|
||||
from vllm.utils.serial_utils import EmbedDType, Endianness
|
||||
@@ -256,7 +255,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
|
||||
|
||||
# Create engine prompt for this chunk
|
||||
chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
|
||||
chunk_engine_prompt = token_inputs(chunk_tokens)
|
||||
|
||||
# Log the chunk
|
||||
self._log_inputs(
|
||||
@@ -266,16 +265,12 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
tok_params = ctx.request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
# Create generator for this chunk and wrap it to return indices
|
||||
original_generator = self.engine_client.encode(
|
||||
chunk_engine_prompt,
|
||||
pooling_params,
|
||||
chunk_request_id,
|
||||
lora_request=ctx.lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=ctx.request.priority,
|
||||
)
|
||||
@@ -362,7 +357,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
async def _create_single_prompt_generator(
|
||||
self,
|
||||
ctx: EmbeddingServeContext,
|
||||
engine_prompt: TokPrompt,
|
||||
engine_prompt: ProcessorInputs,
|
||||
pooling_params: PoolingParams,
|
||||
trace_headers: Mapping[str, str] | None,
|
||||
prompt_index: int,
|
||||
@@ -377,16 +372,12 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
lora_request=ctx.lora_request,
|
||||
)
|
||||
|
||||
tok_params = ctx.request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
# Return the original generator without wrapping
|
||||
return self.engine_client.encode(
|
||||
engine_prompt,
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=ctx.lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=ctx.request.priority,
|
||||
)
|
||||
|
||||
@@ -33,10 +33,9 @@ from vllm.entrypoints.pooling.utils import (
|
||||
encode_pooling_output_base64,
|
||||
encode_pooling_output_float,
|
||||
)
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.inputs import TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import prompt_to_seq
|
||||
from vllm.utils.async_utils import merge_async_iterators
|
||||
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
|
||||
@@ -93,7 +92,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
"dimensions is currently not supported"
|
||||
)
|
||||
|
||||
engine_prompts: Sequence[PromptType | TokPrompt]
|
||||
engine_prompts: Sequence[ProcessorInputs]
|
||||
if use_io_processor := isinstance(request, IOProcessorRequest):
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
@@ -152,9 +151,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
else:
|
||||
pooling_params = request.to_pooling_params() # type: ignore
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
@@ -176,7 +172,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
pooling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
|
||||
get_score_prompt,
|
||||
validate_score_input,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
|
||||
@@ -108,12 +108,15 @@ class ServingScores(OpenAIServing):
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
engine_prompts: list[ProcessorInputs] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
|
||||
token_inputs(
|
||||
text_token_prompt["prompt_token_ids"],
|
||||
prompt=input_text,
|
||||
)
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
@@ -125,7 +128,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
input_texts[i],
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
@@ -207,12 +210,15 @@ class ServingScores(OpenAIServing):
|
||||
*(encode_async(t, **tokenization_kwargs) for t in input_texts)
|
||||
)
|
||||
|
||||
engine_prompts: list[TokensPrompt] = []
|
||||
engine_prompts: list[ProcessorInputs] = []
|
||||
for tok_result, input_text in zip(tokenized_prompts, input_texts):
|
||||
text_token_prompt = self._validate_input(request, tok_result, input_text)
|
||||
|
||||
engine_prompts.append(
|
||||
TokensPrompt(prompt_token_ids=text_token_prompt["prompt_token_ids"])
|
||||
token_inputs(
|
||||
text_token_prompt["prompt_token_ids"],
|
||||
prompt=input_text,
|
||||
)
|
||||
)
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
@@ -225,7 +231,7 @@ class ServingScores(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id_item,
|
||||
input_texts[i],
|
||||
engine_prompt,
|
||||
params=pooling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateResponse,
|
||||
GenerateResponseChoice,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
@@ -116,7 +115,7 @@ class ServingTokens(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
TokensPrompt(prompt_token_ids=request.token_ids),
|
||||
engine_prompt,
|
||||
params=sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
@@ -127,27 +126,13 @@ class ServingTokens(OpenAIServing):
|
||||
else await self._get_trace_headers(raw_request.headers)
|
||||
)
|
||||
|
||||
tok_params = request.build_tok_params(self.model_config)
|
||||
tokenization_kwargs = tok_params.get_encode_kwargs()
|
||||
|
||||
engine_request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_prompt,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_request,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeResponse,
|
||||
TokenizerInfoResponse,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.inputs import TokensPrompt, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
@@ -135,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
TokensPrompt(prompt_token_ids=request.tokens),
|
||||
token_inputs(request.tokens),
|
||||
params=None,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
@@ -187,6 +187,9 @@ class _InputOptions(TypedDict):
|
||||
Additional options available to all input types.
|
||||
"""
|
||||
|
||||
arrival_time: NotRequired[float]
|
||||
"""The time when the input was received (before rendering)."""
|
||||
|
||||
cache_salt: NotRequired[str]
|
||||
"""Optional cache salt to be used for prefix caching."""
|
||||
|
||||
@@ -300,6 +303,9 @@ class EncoderDecoderInputs(TypedDict):
|
||||
decoder_prompt: DecoderInputs
|
||||
"""The inputs for the decoder portion."""
|
||||
|
||||
arrival_time: NotRequired[float]
|
||||
"""The time when the input was received (before rendering)."""
|
||||
|
||||
|
||||
ProcessorInputs: TypeAlias = DecoderOnlyInputs | EncoderDecoderInputs
|
||||
"""
|
||||
|
||||
@@ -19,11 +19,9 @@ from vllm.renderers import BaseRenderer, renderer_from_config
|
||||
from vllm.renderers.inputs import (
|
||||
DecoderDictPrompt,
|
||||
DecoderOnlyDictPrompt,
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDictPrompt,
|
||||
SingletonDictPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
from vllm.renderers.inputs.preprocess import parse_dec_only_prompt, parse_enc_dec_prompt
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
@@ -41,7 +39,6 @@ from .data import (
|
||||
TextPrompt,
|
||||
TokenInputs,
|
||||
TokensPrompt,
|
||||
embeds_inputs,
|
||||
token_inputs,
|
||||
)
|
||||
|
||||
@@ -83,7 +80,7 @@ class InputPreprocessor:
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
tok_prompt = renderer.tokenize_prompt(
|
||||
tok_prompt = renderer._tokenize_singleton_prompt(
|
||||
TextPrompt(prompt=prompt),
|
||||
tok_params,
|
||||
)
|
||||
@@ -103,17 +100,10 @@ class InputPreprocessor:
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
returning the corresponding token IDs and metadata.
|
||||
"""
|
||||
mm_processor = self.renderer.get_mm_processor()
|
||||
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
|
||||
return mm_processor.apply(
|
||||
return self.renderer._process_multimodal(
|
||||
prompt,
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs,
|
||||
mm_data,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
@@ -122,31 +112,7 @@ class InputPreprocessor:
|
||||
self,
|
||||
parsed_content: EmbedsPrompt,
|
||||
) -> EmbedsInputs:
|
||||
if not self.model_config.enable_prompt_embeds:
|
||||
raise ValueError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
|
||||
)
|
||||
|
||||
prompt_embeds = parsed_content["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
# Tensors must be on CPU for serialization between processes
|
||||
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
|
||||
# hidden device transfer in the critical path of generation.
|
||||
prompt_embeds = prompt_embeds.cpu()
|
||||
|
||||
return embeds_inputs(
|
||||
prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
|
||||
)
|
||||
return self.renderer._process_embeds(parsed_content)
|
||||
|
||||
def _truncate_inputs(
|
||||
self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
|
||||
@@ -157,7 +123,7 @@ class InputPreprocessor:
|
||||
**(tokenization_kwargs or {})
|
||||
)
|
||||
|
||||
tok_prompt = renderer.tokenize_prompt(
|
||||
tok_prompt = renderer._tokenize_singleton_prompt(
|
||||
TokensPrompt(prompt_token_ids=inputs),
|
||||
tok_params,
|
||||
)
|
||||
@@ -168,8 +134,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
parsed_content: TokensPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> TokenInputs | MultiModalInputs:
|
||||
prompt_token_ids = self._truncate_inputs(
|
||||
parsed_content["prompt_token_ids"], tokenization_kwargs
|
||||
@@ -182,11 +146,13 @@ class InputPreprocessor:
|
||||
multi_modal_data,
|
||||
parsed_content.get("mm_processor_kwargs") or {},
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
mm_uuids=parsed_content.get("multi_modal_uuids"),
|
||||
)
|
||||
else:
|
||||
inputs = token_inputs(prompt_token_ids)
|
||||
|
||||
if prompt_text := parsed_content.get("prompt"):
|
||||
inputs["prompt"] = prompt_text
|
||||
if cache_salt := parsed_content.get("cache_salt"):
|
||||
inputs["cache_salt"] = cache_salt
|
||||
|
||||
@@ -196,8 +162,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
parsed_content: TextPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> TokenInputs | MultiModalInputs:
|
||||
prompt_text = parsed_content["prompt"]
|
||||
|
||||
@@ -208,7 +172,6 @@ class InputPreprocessor:
|
||||
multi_modal_data,
|
||||
parsed_content.get("mm_processor_kwargs") or {},
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
@@ -217,6 +180,8 @@ class InputPreprocessor:
|
||||
)
|
||||
inputs = token_inputs(prompt_token_ids)
|
||||
|
||||
inputs["prompt"] = prompt_text
|
||||
|
||||
if cache_salt := parsed_content.get("cache_salt"):
|
||||
inputs["cache_salt"] = cache_salt
|
||||
|
||||
@@ -227,8 +192,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: EncoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> EncoderInputs: ...
|
||||
|
||||
@overload
|
||||
@@ -236,8 +199,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: DecoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> DecoderInputs: ...
|
||||
|
||||
@overload
|
||||
@@ -245,16 +206,12 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: DecoderOnlyDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> DecoderOnlyInputs: ...
|
||||
|
||||
def _prompt_to_llm_inputs(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> SingletonInputs:
|
||||
"""
|
||||
Extract the singleton inputs from a prompt.
|
||||
@@ -271,16 +228,12 @@ class InputPreprocessor:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
if "prompt_token_ids" in prompt:
|
||||
return self._process_tokens(
|
||||
prompt, # type: ignore[arg-type]
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||||
|
||||
if "prompt" in prompt:
|
||||
return self._process_text(
|
||||
prompt, # type: ignore[arg-type]
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
assert_never(prompt) # type: ignore[arg-type]
|
||||
@@ -289,8 +242,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> EncoderDecoderInputs:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
@@ -314,7 +265,6 @@ class InputPreprocessor:
|
||||
encoder_inputs=self._prompt_to_llm_inputs(
|
||||
encoder_prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
),
|
||||
decoder_inputs=(
|
||||
None
|
||||
@@ -331,8 +281,6 @@ class InputPreprocessor:
|
||||
self,
|
||||
prompt: DecoderOnlyDictPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
"""
|
||||
For decoder-only models:
|
||||
@@ -350,41 +298,23 @@ class InputPreprocessor:
|
||||
return self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
def _preprocess(
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
prompt: PromptType,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
if self.model_config.is_encoder_decoder:
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder.
|
||||
return self._process_encoder_decoder_prompt(
|
||||
parse_enc_dec_prompt(prompt),
|
||||
tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
return self._process_decoder_only_prompt(
|
||||
parse_dec_only_prompt(prompt),
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
def preprocess(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
*,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> ProcessorInputs:
|
||||
"""Preprocess the input prompt."""
|
||||
res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)
|
||||
|
||||
self.renderer.update_mm_cache_stats()
|
||||
|
||||
return res
|
||||
|
||||
@@ -48,7 +48,6 @@ from vllm.multimodal.processing import (
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor
|
||||
@@ -810,13 +809,7 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
# Use getattr with default to be compatible with transformers<4.48
|
||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||
|
||||
audio_token_id = vocab[audio_token]
|
||||
audio_token_id = processor.audio_token_id
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
|
||||
@@ -836,17 +829,12 @@ class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]):
|
||||
assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor"
|
||||
num_features = audio_embeds.shape[0]
|
||||
|
||||
audio_tokens = [audio_token_id] * num_features
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
audio_tokens,
|
||||
embed_token_id=audio_token_id,
|
||||
)
|
||||
return [audio_token_id] * num_features
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_token,
|
||||
target=[audio_token_id],
|
||||
replacement=get_replacement_qwen2_audio,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -59,7 +59,6 @@ from vllm.multimodal.processing import (
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@@ -187,8 +186,10 @@ class Qwen2AudioDummyInputsBuilder(BaseDummyInputsBuilder[Qwen2AudioProcessingIn
|
||||
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
audio_token = hf_processor.audio_token
|
||||
audio_bos_token = hf_processor.audio_bos_token
|
||||
audio_eos_token = hf_processor.audio_eos_token
|
||||
|
||||
return audio_token * num_audios
|
||||
return (audio_bos_token + audio_token + audio_eos_token) * num_audios
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
@@ -262,17 +263,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
# Use getattr with default to be compatible with transformers<4.48
|
||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||
audio_bos_token = getattr(processor, "audio_bos_token", "<|audio_bos|>")
|
||||
audio_eos_token = getattr(processor, "audio_eos_token", "<|audio_eos|>")
|
||||
|
||||
audio_token_id = vocab[audio_token]
|
||||
audio_bos_id = vocab[audio_bos_token]
|
||||
audio_eos_id = vocab[audio_eos_token]
|
||||
audio_token_id = processor.audio_token_id
|
||||
|
||||
out_mm_data = out_mm_kwargs.get_data()
|
||||
feature_attention_mask = out_mm_data.get("feature_attention_mask")
|
||||
@@ -303,17 +294,12 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
|
||||
"to be represented inside the model"
|
||||
)
|
||||
|
||||
audio_tokens = [audio_token_id] * num_features
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
[audio_bos_id] + audio_tokens + [audio_eos_id],
|
||||
embed_token_id=audio_token_id,
|
||||
)
|
||||
return [audio_token_id] * num_features
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_token,
|
||||
target=[audio_token_id],
|
||||
replacement=get_replacement_qwen2_audio,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1843,15 +1843,18 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
decoder_prompt_raw = self.create_decoder_prompt(prompt, mm_items)
|
||||
if isinstance(decoder_prompt_raw, str):
|
||||
decoder_prompt_text = decoder_prompt_raw
|
||||
decoder_prompt_ids = tokenizer.encode(
|
||||
decoder_prompt_raw, add_special_tokens=False
|
||||
)
|
||||
else:
|
||||
decoder_prompt_text = None
|
||||
decoder_prompt_ids = decoder_prompt_raw
|
||||
|
||||
return mm_enc_dec_inputs(
|
||||
encoder_inputs,
|
||||
decoder_prompt_ids,
|
||||
decoder_prompt=decoder_prompt_text,
|
||||
)
|
||||
|
||||
def apply(
|
||||
|
||||
@@ -19,7 +19,6 @@ if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.attention.selector import AttentionSelectorConfig
|
||||
@@ -569,7 +568,7 @@ class Platform:
|
||||
@classmethod
|
||||
def validate_request(
|
||||
cls,
|
||||
prompt: "PromptType | DictPrompt | TokPrompt",
|
||||
prompt: "PromptType | ProcessorInputs",
|
||||
params: "SamplingParams | PoolingParams",
|
||||
processed_inputs: "ProcessorInputs",
|
||||
) -> None:
|
||||
|
||||
@@ -1,17 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any, Generic, overload
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import (
|
||||
EmbedsInputs,
|
||||
EmbedsPrompt,
|
||||
EncoderDecoderInputs,
|
||||
ProcessorInputs,
|
||||
SingletonInputs,
|
||||
TextPrompt,
|
||||
TokenInputs,
|
||||
TokensPrompt,
|
||||
)
|
||||
from vllm.inputs.data import build_enc_dec_inputs, embeds_inputs, token_inputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
|
||||
from vllm.utils.counter import AtomicCounter
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.metrics.stats import MultiModalCacheStats
|
||||
|
||||
@@ -20,6 +32,8 @@ from .inputs import (
|
||||
DictPrompt,
|
||||
EncoderDecoderDictPrompt,
|
||||
EncoderDecoderTokPrompt,
|
||||
SingletonDictPrompt,
|
||||
SingletonTokPrompt,
|
||||
TokPrompt,
|
||||
)
|
||||
from .inputs.preprocess import extract_target_prompt
|
||||
@@ -32,6 +46,12 @@ if TYPE_CHECKING:
|
||||
ConversationMessage,
|
||||
)
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalInputs,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -79,6 +99,10 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
if mm_processor_cache:
|
||||
self._mm_cache_stats = MultiModalCacheStats()
|
||||
|
||||
# This is used to generate internal request ID for MM processing
|
||||
# It has no relation to the request ID for engine core
|
||||
self._mm_req_counter = AtomicCounter()
|
||||
|
||||
def get_tokenizer(self) -> _T:
|
||||
tokenizer = self.tokenizer
|
||||
if tokenizer is None:
|
||||
@@ -284,17 +308,79 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
return prompt
|
||||
|
||||
@overload
|
||||
def _tokenize_singleton_prompt(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
def _tokenize_singleton_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
def _tokenize_singleton_prompt(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> SingletonTokPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
prompt = self._tokenize_prompt(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
@overload
|
||||
async def _tokenize_singleton_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
async def _tokenize_singleton_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
async def _tokenize_singleton_prompt_async(
|
||||
self,
|
||||
prompt: SingletonDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> SingletonTokPrompt:
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
prompt = await self._tokenize_prompt_async(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
|
||||
def _tokenize_enc_dec_prompt(
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt:
|
||||
enc_prompt, dec_prompt = (
|
||||
self.tokenize_prompt(prompt["encoder_prompt"], params),
|
||||
self._tokenize_singleton_prompt(prompt["encoder_prompt"], params),
|
||||
(
|
||||
None
|
||||
if prompt["decoder_prompt"] is None
|
||||
else self.tokenize_prompt(prompt["decoder_prompt"], params)
|
||||
else self._tokenize_singleton_prompt(prompt["decoder_prompt"], params)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -309,11 +395,13 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt:
|
||||
enc_prompt, dec_prompt = await asyncio.gather(
|
||||
self.tokenize_prompt_async(prompt["encoder_prompt"], params),
|
||||
self._tokenize_singleton_prompt_async(prompt["encoder_prompt"], params),
|
||||
(
|
||||
asyncio.sleep(0)
|
||||
if prompt["decoder_prompt"] is None
|
||||
else self.tokenize_prompt_async(prompt["decoder_prompt"], params)
|
||||
else self._tokenize_singleton_prompt_async(
|
||||
prompt["decoder_prompt"], params
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -322,27 +410,6 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
decoder_prompt=dec_prompt,
|
||||
)
|
||||
|
||||
@overload
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
def tokenize_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
@overload
|
||||
def tokenize_prompt( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt: ...
|
||||
|
||||
def tokenize_prompt(
|
||||
self,
|
||||
prompt: DictPrompt,
|
||||
@@ -351,17 +418,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
if "encoder_prompt" in prompt:
|
||||
return self._tokenize_enc_dec_prompt(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
prompt = self._tokenize_prompt(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = self._detokenize_prompt(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
return self._tokenize_singleton_prompt(prompt, params)
|
||||
|
||||
def tokenize_prompts(
|
||||
self,
|
||||
@@ -370,27 +427,6 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
) -> list[TokPrompt]:
|
||||
return [self.tokenize_prompt(prompt, params) for prompt in prompts]
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: TextPrompt | TokensPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> TokensPrompt: ...
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EmbedsPrompt: ...
|
||||
|
||||
@overload
|
||||
async def tokenize_prompt_async( # type: ignore[misc]
|
||||
self,
|
||||
prompt: EncoderDecoderDictPrompt,
|
||||
params: TokenizeParams,
|
||||
) -> EncoderDecoderTokPrompt: ...
|
||||
|
||||
async def tokenize_prompt_async(
|
||||
self,
|
||||
prompt: DictPrompt,
|
||||
@@ -399,17 +435,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
if "encoder_prompt" in prompt:
|
||||
return await self._tokenize_enc_dec_prompt_async(prompt, params) # type: ignore[arg-type]
|
||||
|
||||
if "prompt_token_ids" not in prompt and "prompt_embeds" not in prompt:
|
||||
prompt = params.apply_pre_tokenization(self.tokenizer, prompt)
|
||||
prompt = await self._tokenize_prompt_async(prompt, params)
|
||||
|
||||
if params.needs_detokenization and "prompt" not in prompt:
|
||||
if "prompt_token_ids" not in prompt:
|
||||
raise RuntimeError("Cannot run detokenization on embeddings")
|
||||
|
||||
prompt = await self._detokenize_prompt_async(prompt) # type: ignore[arg-type]
|
||||
|
||||
return params.apply_post_tokenization(self.tokenizer, prompt) # type: ignore[arg-type]
|
||||
return await self._tokenize_singleton_prompt_async(prompt, params)
|
||||
|
||||
async def tokenize_prompts_async(
|
||||
self,
|
||||
@@ -423,7 +449,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
# Step 3: Add extra keys to the prompts
|
||||
def _apply_prompt_extras(
|
||||
self,
|
||||
prompts: Sequence[DictPrompt | TokPrompt],
|
||||
prompts: Sequence[TokPrompt],
|
||||
prompt_extras: dict[str, Any] | None,
|
||||
):
|
||||
if not prompt_extras:
|
||||
@@ -433,6 +459,200 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
target_prompt = extract_target_prompt(self.model_config, prompt)
|
||||
target_prompt.update(prompt_extras) # type: ignore[arg-type]
|
||||
|
||||
# Step 4: Convert to engine inputs
|
||||
def _validate_mm_uuids(
|
||||
self,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_items: "MultiModalDataItems",
|
||||
mm_uuids: "MultiModalUUIDDict | None",
|
||||
) -> None:
|
||||
if mm_uuids is None:
|
||||
mm_uuids = {}
|
||||
|
||||
# NOTE: Keys corresponding to `None` in `mm_data` don't appear in `mm_items`
|
||||
modalities = mm_data.keys() | mm_uuids.keys()
|
||||
|
||||
for modality in modalities:
|
||||
data_items = mm_items.get(modality) or list[Any]()
|
||||
|
||||
uuid_items = mm_uuids.get(modality) or list[str | None]()
|
||||
if isinstance(uuid_items, str):
|
||||
uuid_items = [uuid_items]
|
||||
|
||||
if len(data_items) > 0:
|
||||
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
|
||||
raise ValueError(
|
||||
f"If given, multi_modal_uuids[{modality!r}] must have "
|
||||
f"same length as multi_modal_data[{modality!r}], but "
|
||||
f"got {len(uuid_items)} vs {len(data_items)}."
|
||||
)
|
||||
|
||||
for i, item in enumerate(data_items):
|
||||
if item is None:
|
||||
if not uuid_items:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}] is missing."
|
||||
)
|
||||
|
||||
if uuid_items[i] is None:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}][{i}] is missing."
|
||||
)
|
||||
|
||||
def _process_mm_uuids(
|
||||
self,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_items: "MultiModalDataItems",
|
||||
mm_uuids: "MultiModalUUIDDict | None",
|
||||
mm_req_id: str,
|
||||
):
|
||||
model_config = self.model_config
|
||||
|
||||
# NOTE: When users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# `<mm_req_id>-<modality>-<index>`, overriding even user-provided ones.
|
||||
if (
|
||||
model_config.multimodal_config
|
||||
and model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.config.cache_config.enable_prefix_caching
|
||||
):
|
||||
mm_uuids = {
|
||||
modality: [f"{mm_req_id}-{modality}-{i}" for i in range(data_count)]
|
||||
for modality, data_count in mm_items.get_all_counts().items()
|
||||
}
|
||||
|
||||
self._validate_mm_uuids(mm_data, mm_items, mm_uuids)
|
||||
|
||||
return mm_uuids
|
||||
|
||||
# TODO: Remove str and tokenization_kwargs after deprecating InputPreprocessor
|
||||
def _process_multimodal(
|
||||
self,
|
||||
prompt: list[int] | str,
|
||||
mm_data: "MultiModalDataDict",
|
||||
mm_processor_kwargs: Mapping[str, object] | None,
|
||||
tokenization_kwargs: dict[str, Any] | None,
|
||||
mm_uuids: "MultiModalUUIDDict | None",
|
||||
) -> "MultiModalInputs":
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
|
||||
mm_req_id = f"renderer-mm-{self._mm_req_counter.inc(1)}"
|
||||
|
||||
mm_processor = self.get_mm_processor()
|
||||
|
||||
mm_items = mm_processor.info.parse_mm_data(mm_data)
|
||||
mm_uuids = self._process_mm_uuids(mm_data, mm_items, mm_uuids, mm_req_id)
|
||||
|
||||
with set_request_id(mm_req_id), set_default_torch_num_threads():
|
||||
mm_inputs = mm_processor.apply(
|
||||
prompt,
|
||||
mm_items,
|
||||
hf_processor_mm_kwargs=mm_processor_kwargs or {},
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
self.update_mm_cache_stats()
|
||||
|
||||
return mm_inputs
|
||||
|
||||
def _process_tokens(
|
||||
self,
|
||||
prompt: TokensPrompt,
|
||||
) -> "TokenInputs | MultiModalInputs":
|
||||
prompt_token_ids = prompt["prompt_token_ids"]
|
||||
|
||||
inputs: TokenInputs | MultiModalInputs
|
||||
if multi_modal_data := prompt.get("multi_modal_data"):
|
||||
inputs = self._process_multimodal(
|
||||
prompt_token_ids,
|
||||
multi_modal_data,
|
||||
mm_processor_kwargs=prompt.get("mm_processor_kwargs"),
|
||||
tokenization_kwargs=None, # Tokenization already done in Step 2
|
||||
mm_uuids=prompt.get("multi_modal_uuids"),
|
||||
)
|
||||
else:
|
||||
inputs = token_inputs(prompt_token_ids)
|
||||
|
||||
if prompt_text := prompt.get("prompt"):
|
||||
inputs["prompt"] = prompt_text
|
||||
if cache_salt := prompt.get("cache_salt"):
|
||||
inputs["cache_salt"] = cache_salt
|
||||
|
||||
return inputs
|
||||
|
||||
def _process_embeds(
|
||||
self,
|
||||
prompt: EmbedsPrompt,
|
||||
) -> EmbedsInputs:
|
||||
if not self.model_config.enable_prompt_embeds:
|
||||
raise ValueError(
|
||||
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
|
||||
)
|
||||
|
||||
prompt_embeds = prompt["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
# Tensors must be on CPU for serialization between processes
|
||||
# in the MsgpackEncoder. Casting to CPU here ensures that there is no
|
||||
# hidden device transfer in the critical path of generation.
|
||||
prompt_embeds = prompt_embeds.cpu()
|
||||
|
||||
return embeds_inputs(
|
||||
prompt_embeds=prompt_embeds,
|
||||
cache_salt=prompt.get("cache_salt"),
|
||||
)
|
||||
|
||||
def _process_singleton(
|
||||
self,
|
||||
prompt: SingletonTokPrompt,
|
||||
) -> SingletonInputs:
|
||||
if "prompt_embeds" in prompt:
|
||||
return self._process_embeds(prompt) # type: ignore[arg-type]
|
||||
|
||||
return self._process_tokens(prompt) # type: ignore[arg-type]
|
||||
|
||||
def _process_enc_dec(
|
||||
self,
|
||||
prompt: EncoderDecoderTokPrompt,
|
||||
) -> EncoderDecoderInputs:
|
||||
enc_prompt = prompt["encoder_prompt"]
|
||||
dec_prompt = prompt["decoder_prompt"]
|
||||
|
||||
return build_enc_dec_inputs(
|
||||
encoder_inputs=self._process_singleton(enc_prompt),
|
||||
decoder_inputs=(
|
||||
None if dec_prompt is None else self._process_singleton(dec_prompt)
|
||||
),
|
||||
decoder_start_token_id=self.get_dec_start_token_id(),
|
||||
)
|
||||
|
||||
def process_for_engine(
|
||||
self, prompt: TokPrompt, arrival_time: float
|
||||
) -> ProcessorInputs:
|
||||
engine_prompt: ProcessorInputs
|
||||
if "encoder_prompt" in prompt:
|
||||
engine_prompt = self._process_enc_dec(prompt) # type: ignore[arg-type]
|
||||
else:
|
||||
engine_prompt = self._process_singleton(prompt)
|
||||
|
||||
engine_prompt["arrival_time"] = arrival_time
|
||||
|
||||
return engine_prompt
|
||||
|
||||
# Top-level methods
|
||||
def render_cmpl(
|
||||
self,
|
||||
@@ -441,6 +661,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_cmpl_tok_params
|
||||
|
||||
@@ -449,8 +671,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
# TODO: Apply multi-modal processor
|
||||
return tok_prompts
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
|
||||
async def render_cmpl_async(
|
||||
self,
|
||||
@@ -459,6 +680,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_cmpl_tok_params
|
||||
|
||||
@@ -467,8 +690,7 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
# TODO: Apply multi-modal processor
|
||||
return tok_prompts
|
||||
return [self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts]
|
||||
|
||||
def render_chat(
|
||||
self,
|
||||
@@ -478,6 +700,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_chat_tok_params
|
||||
|
||||
@@ -496,8 +720,11 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
# TODO: Apply multi-modal processor
|
||||
return out_conversations, tok_prompts
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
]
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
|
||||
async def render_chat_async(
|
||||
self,
|
||||
@@ -507,6 +734,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
*,
|
||||
prompt_extras: dict[str, Any] | None = None,
|
||||
):
|
||||
arrival_time = time.time()
|
||||
|
||||
if tok_params is None:
|
||||
tok_params = self.default_chat_tok_params
|
||||
|
||||
@@ -525,5 +754,8 @@ class BaseRenderer(ABC, Generic[_T]):
|
||||
|
||||
self._apply_prompt_extras(tok_prompts, prompt_extras)
|
||||
|
||||
# TODO: Apply multi-modal processor
|
||||
return out_conversations, tok_prompts
|
||||
eng_prompts = [
|
||||
self.process_for_engine(prompt, arrival_time) for prompt in tok_prompts
|
||||
]
|
||||
|
||||
return out_conversations, eng_prompts
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, NamedTuple, TypeAlias, TypedDict, overload
|
||||
from vllm.inputs import (
|
||||
EmbedsPrompt,
|
||||
ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonPrompt,
|
||||
TextPrompt,
|
||||
@@ -115,7 +116,7 @@ that has been standardized into a dictionary.
|
||||
"""
|
||||
|
||||
|
||||
def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
|
||||
def parse_dec_only_prompt(prompt: PromptType | object) -> DecoderOnlyDictPrompt:
|
||||
"""
|
||||
Parse a prompt for a decoder-only model and normalize it to a dictionary.
|
||||
"""
|
||||
@@ -144,7 +145,7 @@ def parse_dec_only_prompt(prompt: object) -> DecoderOnlyDictPrompt:
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
|
||||
def _parse_enc_prompt(prompt: PromptType | object) -> EncoderDictPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return TextPrompt(prompt=prompt)
|
||||
|
||||
@@ -166,7 +167,7 @@ def _parse_enc_prompt(prompt: object) -> EncoderDictPrompt:
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
|
||||
def _parse_dec_prompt(prompt: PromptType | object) -> DecoderDictPrompt:
|
||||
if isinstance(prompt, str):
|
||||
return TextPrompt(prompt=prompt)
|
||||
|
||||
@@ -195,13 +196,13 @@ def _parse_dec_prompt(prompt: object) -> DecoderDictPrompt:
|
||||
raise TypeError("Prompt should be a string, list of tokens, or dictionary")
|
||||
|
||||
|
||||
def parse_enc_dec_prompt(prompt: object) -> EncoderDecoderDictPrompt:
|
||||
def parse_enc_dec_prompt(prompt: PromptType | object) -> EncoderDecoderDictPrompt:
|
||||
"""
|
||||
Parse a prompt for an encoder-decoder model and normalize it to a dictionary.
|
||||
"""
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
enc_prompt: object = prompt["encoder_prompt"] # type: ignore[typeddict-item]
|
||||
dec_prompt: object | None = prompt["decoder_prompt"] # type: ignore[typeddict-item]
|
||||
enc_prompt = prompt["encoder_prompt"] # type: ignore[typeddict-item]
|
||||
dec_prompt = prompt["decoder_prompt"] # type: ignore[typeddict-item]
|
||||
else:
|
||||
enc_prompt = prompt
|
||||
dec_prompt = None
|
||||
@@ -235,21 +236,23 @@ def extract_target_prompt(model_config: "ModelConfig", prompt: object):
|
||||
|
||||
def extract_prompt_components(
|
||||
model_config: "ModelConfig",
|
||||
prompt: object,
|
||||
prompt: PromptType | ProcessorInputs,
|
||||
) -> PromptComponents:
|
||||
target_prompt = extract_target_prompt(model_config, prompt)
|
||||
|
||||
return PromptComponents(
|
||||
text=target_prompt.get("prompt"),
|
||||
token_ids=target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
token_ids=target_prompt.get("prompt_token_ids"),
|
||||
embeds=target_prompt.get("prompt_embeds"),
|
||||
)
|
||||
|
||||
|
||||
def extract_prompt_len(model_config: "ModelConfig", prompt: object):
|
||||
def extract_prompt_len(
|
||||
model_config: "ModelConfig", prompt: PromptType | ProcessorInputs
|
||||
):
|
||||
target_prompt = extract_target_prompt(model_config, prompt)
|
||||
|
||||
return length_from_prompt_token_ids_or_embeds(
|
||||
target_prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
target_prompt.get("prompt_token_ids"),
|
||||
target_prompt.get("prompt_embeds"),
|
||||
)
|
||||
|
||||
39
vllm/utils/tqdm_utils.py
Normal file
39
vllm/utils/tqdm_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
_T = TypeVar("_T", bound=Iterable)
|
||||
|
||||
|
||||
@overload
|
||||
def maybe_tqdm(
|
||||
it: Sequence[_T],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
**tqdm_kwargs: Any,
|
||||
) -> Sequence[_T]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def maybe_tqdm(
|
||||
it: Iterable[_T],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
**tqdm_kwargs: Any,
|
||||
) -> Iterable[_T]: ...
|
||||
|
||||
|
||||
def maybe_tqdm(
|
||||
it: Iterable[_T],
|
||||
*,
|
||||
use_tqdm: bool | Callable[..., tqdm],
|
||||
**tqdm_kwargs: Any,
|
||||
) -> Iterable[_T]:
|
||||
if not use_tqdm:
|
||||
return it
|
||||
|
||||
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
|
||||
return tqdm_func(it, **tqdm_kwargs)
|
||||
@@ -20,7 +20,7 @@ from vllm.distributed.weight_transfer.base import (
|
||||
)
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient, StreamingInput
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
@@ -28,7 +28,6 @@ from vllm.outputs import STREAM_FINISHED, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import merge_kwargs, renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
@@ -290,8 +289,7 @@ class AsyncLLM(EngineClient):
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| ProcessorInputs
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
@@ -301,6 +299,7 @@ class AsyncLLM(EngineClient):
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
prompt_text: str | None = None,
|
||||
reasoning_ended: bool | None = None,
|
||||
) -> RequestOutputCollector:
|
||||
"""Add new request to the AsyncLLM."""
|
||||
|
||||
@@ -336,6 +335,9 @@ class AsyncLLM(EngineClient):
|
||||
)
|
||||
|
||||
if isinstance(prompt, AsyncGenerator):
|
||||
if reasoning_ended is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
# Streaming input case.
|
||||
return await self._add_streaming_input_request(
|
||||
request_id,
|
||||
@@ -359,10 +361,6 @@ class AsyncLLM(EngineClient):
|
||||
"latter will be used, and the former will be ignored."
|
||||
)
|
||||
else:
|
||||
if prompt_text is not None:
|
||||
raise ValueError(
|
||||
"should only provide prompt_text with EngineCoreRequest"
|
||||
)
|
||||
request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
@@ -377,6 +375,9 @@ class AsyncLLM(EngineClient):
|
||||
)
|
||||
prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
|
||||
|
||||
if reasoning_ended is not None:
|
||||
request.reasoning_ended = reasoning_ended
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
# We start the output_handler on the first call to add_request() so
|
||||
@@ -536,8 +537,7 @@ class AsyncLLM(EngineClient):
|
||||
self,
|
||||
prompt: EngineCoreRequest
|
||||
| PromptType
|
||||
| DictPrompt
|
||||
| TokPrompt
|
||||
| ProcessorInputs
|
||||
| AsyncGenerator[StreamingInput, None],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
@@ -548,6 +548,7 @@ class AsyncLLM(EngineClient):
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
data_parallel_rank: int | None = None,
|
||||
reasoning_ended: bool | None = None,
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
@@ -576,6 +577,7 @@ class AsyncLLM(EngineClient):
|
||||
priority=priority,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
prompt_text=prompt_text,
|
||||
reasoning_ended=reasoning_ended,
|
||||
)
|
||||
|
||||
# The output_handler task pushes items into the queue.
|
||||
@@ -770,13 +772,14 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
prompt: PromptType | ProcessorInputs,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
lora_request: LoRARequest | None = None,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
tokenization_kwargs: dict[str, Any] | None = None,
|
||||
reasoning_ended: bool | None = None,
|
||||
) -> AsyncGenerator[PoolingRequestOutput, None]:
|
||||
"""
|
||||
Main function called by the API server to kick off a request
|
||||
@@ -802,6 +805,7 @@ class AsyncLLM(EngineClient):
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
reasoning_ended=reasoning_ended,
|
||||
)
|
||||
|
||||
# The output_handler task pushes items into the queue.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, Literal
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
@@ -11,7 +11,6 @@ from vllm.inputs.data import (
|
||||
ProcessorInputs,
|
||||
PromptType,
|
||||
SingletonInputs,
|
||||
SingletonPrompt,
|
||||
)
|
||||
from vllm.inputs.parse import split_enc_dec_inputs
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
@@ -20,22 +19,16 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.encoder_budget import MultiModalBudget
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFeatureSpec,
|
||||
MultiModalUUIDDict,
|
||||
)
|
||||
from vllm.multimodal.parse import ModalityDataItems, MultiModalDataItems
|
||||
from vllm.multimodal.processing.context import set_request_id
|
||||
from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import BaseRenderer, renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, random_uuid
|
||||
from vllm.utils.jsontree import json_iter_leaves
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -133,81 +126,6 @@ class InputProcessor:
|
||||
f"but got {type(params).__name__}"
|
||||
)
|
||||
|
||||
def _parse_mm_items(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:
|
||||
mm_processor = self.renderer.get_mm_processor()
|
||||
return mm_processor.info.parse_mm_data(mm_data)
|
||||
|
||||
def _validate_singleton_mm_uuids(self, prompt: SingletonPrompt) -> None:
|
||||
if not isinstance(prompt, dict):
|
||||
return
|
||||
|
||||
mm_data = cast(MultiModalDataDict, prompt.get("multi_modal_data") or {})
|
||||
mm_uuids = cast(MultiModalUUIDDict, prompt.get("multi_modal_uuids") or {})
|
||||
if not mm_data and not mm_uuids:
|
||||
return
|
||||
|
||||
mm_data_parsed = self._parse_mm_items(
|
||||
{k: v for k, v in mm_data.items() if v is not None}
|
||||
)
|
||||
mm_uuids_parsed = {
|
||||
k: [v] if isinstance(v, str) else v
|
||||
for k, v in mm_uuids.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# NOTE: Include the keys corresponding to `None`
|
||||
modalities = mm_data.keys() | mm_uuids.keys()
|
||||
|
||||
for modality in modalities:
|
||||
data_items = cast(
|
||||
ModalityDataItems | list[Any], mm_data_parsed.get(modality, [])
|
||||
)
|
||||
uuid_items = cast(list[str | None], mm_uuids_parsed.get(modality, []))
|
||||
|
||||
if len(data_items) > 0:
|
||||
if len(uuid_items) > 0 and len(data_items) != len(uuid_items):
|
||||
raise ValueError(
|
||||
f"If given, multi_modal_uuids[{modality!r}] must have "
|
||||
f"same length as multi_modal_data[{modality!r}], but "
|
||||
f"got {len(uuid_items)} vs {len(data_items)}."
|
||||
)
|
||||
|
||||
for i, item in enumerate(data_items):
|
||||
if item is None:
|
||||
if not uuid_items:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}] is missing."
|
||||
)
|
||||
|
||||
if uuid_items[i] is None:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}][{i}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}][{i}] is missing."
|
||||
)
|
||||
else:
|
||||
if len(uuid_items) == 0:
|
||||
raise ValueError(
|
||||
f"multi_modal_data[{modality!r}] is empty but "
|
||||
f"multi_modal_uuids[{modality!r}] is missing."
|
||||
)
|
||||
|
||||
def _validate_mm_uuids(self, prompt: PromptType | DictPrompt | TokPrompt) -> None:
|
||||
"""
|
||||
Validate that user-provided multi_modal_uuids align with
|
||||
multi_modal_data in the incoming request prompt(s).
|
||||
Only checks lengths; `None` entries are allowed and will be
|
||||
auto-hashed downstream.
|
||||
"""
|
||||
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
self._validate_singleton_mm_uuids(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
|
||||
|
||||
if (dec_prompt := prompt["decoder_prompt"]) is not None: # type: ignore[typeddict-item]
|
||||
self._validate_singleton_mm_uuids(dec_prompt)
|
||||
else:
|
||||
self._validate_singleton_mm_uuids(prompt)
|
||||
|
||||
def _validate_lora(self, lora_request: LoRARequest | None) -> None:
|
||||
if lora_request is None:
|
||||
return
|
||||
@@ -227,47 +145,6 @@ class InputProcessor:
|
||||
"[lora_path]` to use the LoRA tokenizer."
|
||||
)
|
||||
|
||||
def _extract_singleton_mm_data(
|
||||
self, prompt: SingletonPrompt
|
||||
) -> MultiModalDataDict | None:
|
||||
if not isinstance(prompt, dict):
|
||||
return None
|
||||
|
||||
return prompt.get("multi_modal_data")
|
||||
|
||||
def _extract_mm_data(
|
||||
self, prompt: PromptType | DictPrompt | TokPrompt
|
||||
) -> MultiModalDataDict | None:
|
||||
if isinstance(prompt, dict) and "encoder_prompt" in prompt:
|
||||
return self._extract_singleton_mm_data(prompt["encoder_prompt"]) # type: ignore[typeddict-item]
|
||||
else:
|
||||
return self._extract_singleton_mm_data(prompt)
|
||||
|
||||
def _maybe_build_mm_uuids(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
) -> MultiModalUUIDDict | None:
|
||||
"""Build per-item multimodal hash overrides when enabled. In this case,
|
||||
multimodal data items are identified by their request id, modality and
|
||||
index rather than their content.
|
||||
|
||||
Returns a dictionary of modality -> list[str] of overrides, or None if
|
||||
disabled or no multimodal data is present.
|
||||
"""
|
||||
mm_data = self._extract_mm_data(prompt)
|
||||
if not mm_data:
|
||||
return None
|
||||
|
||||
mm_items = self._parse_mm_items(
|
||||
{k: v for k, v in mm_data.items() if v is not None}
|
||||
)
|
||||
|
||||
return {
|
||||
modality: [f"{request_id}-{modality}-{i}" for i in range(data_count)]
|
||||
for modality, data_count in mm_items.get_all_counts().items()
|
||||
}
|
||||
|
||||
def _get_mm_identifier(
|
||||
self,
|
||||
mm_hash: str,
|
||||
@@ -309,7 +186,7 @@ class InputProcessor:
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: PromptType | DictPrompt | TokPrompt,
|
||||
prompt: PromptType | ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -333,43 +210,18 @@ class InputProcessor:
|
||||
f"is out of range [0, {num_ranks})."
|
||||
)
|
||||
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
if isinstance(prompt, dict) and "type" in prompt:
|
||||
if arrival_time is None:
|
||||
arrival_time = prompt.get("arrival_time", time.time()) # type: ignore[assignment]
|
||||
|
||||
# Optionally generate multimodal hash overrides to avoid hashing
|
||||
# multimodal data items by their content as their identifiers.
|
||||
|
||||
# NOTE: when users explicitly turn off BOTH prefix caching and input
|
||||
# processing caching, no multimodal features or embeddings will be
|
||||
# reused across requests, therefore identifying multimodal data items
|
||||
# by their content is no longer necessary, and we create uuids with
|
||||
# request id-modality-index as multimodal hash overrides.
|
||||
if (
|
||||
self.model_config.multimodal_config
|
||||
and self.model_config.multimodal_config.mm_processor_cache_gb == 0
|
||||
and not self.cache_config.enable_prefix_caching
|
||||
):
|
||||
mm_uuids = self._maybe_build_mm_uuids(request_id, prompt)
|
||||
processed_inputs: ProcessorInputs = prompt # type: ignore[assignment]
|
||||
else:
|
||||
# Otherwise, use user-provided uuids as multimodal hash overrides
|
||||
# if provided.
|
||||
self._validate_mm_uuids(prompt)
|
||||
if isinstance(prompt, dict):
|
||||
mm_uuids = cast(
|
||||
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
|
||||
)
|
||||
else:
|
||||
mm_uuids = None
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
# Process inputs, which includes:
|
||||
# 1. Tokenize text prompt, with LoRA request if one exists.
|
||||
# 2. For multimodal models with a merged preprocessor, preprocess
|
||||
# multimodal data and expand prompt token ids accordingly.
|
||||
with set_request_id(request_id), set_default_torch_num_threads():
|
||||
processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
|
||||
processed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
mm_uuids=mm_uuids,
|
||||
)
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
@@ -22,7 +22,6 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.renderers import renderer_from_config
|
||||
from vllm.renderers.inputs import DictPrompt, TokPrompt
|
||||
from vllm.renderers.inputs.preprocess import extract_prompt_components
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.tasks import SupportedTask
|
||||
@@ -220,7 +219,7 @@ class LLMEngine:
|
||||
def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: EngineCoreRequest | PromptType | DictPrompt | TokPrompt,
|
||||
prompt: EngineCoreRequest | PromptType | ProcessorInputs,
|
||||
params: SamplingParams | PoolingParams,
|
||||
arrival_time: float | None = None,
|
||||
lora_request: LoRARequest | None = None,
|
||||
@@ -228,7 +227,7 @@ class LLMEngine:
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
priority: int = 0,
|
||||
prompt_text: str | None = None,
|
||||
) -> None:
|
||||
) -> str:
|
||||
# Validate the request_id type.
|
||||
if not isinstance(request_id, str):
|
||||
raise TypeError(f"request_id must be a string, got {type(request_id)}")
|
||||
@@ -243,7 +242,6 @@ class LLMEngine:
|
||||
"latter will be used, and the former will be ignored."
|
||||
)
|
||||
else:
|
||||
assert prompt_text is None
|
||||
request = self.input_processor.process_inputs(
|
||||
request_id,
|
||||
prompt,
|
||||
@@ -259,6 +257,8 @@ class LLMEngine:
|
||||
|
||||
self.input_processor.assign_request_id(request)
|
||||
|
||||
req_id = request.request_id
|
||||
|
||||
# Use cloned params that may have been updated in process_inputs()
|
||||
params = request.params
|
||||
|
||||
@@ -269,7 +269,7 @@ class LLMEngine:
|
||||
self.output_processor.add_request(request, prompt_text, None, 0)
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(request)
|
||||
return
|
||||
return req_id
|
||||
|
||||
# Fan out child requests (for n>1).
|
||||
parent_req = ParentRequest(request)
|
||||
@@ -286,6 +286,8 @@ class LLMEngine:
|
||||
# Add the request to EngineCore.
|
||||
self.engine_core.add_request(child_request)
|
||||
|
||||
return req_id
|
||||
|
||||
def step(self) -> list[RequestOutput | PoolingRequestOutput]:
|
||||
if self.should_execute_dummy_batch:
|
||||
self.should_execute_dummy_batch = False
|
||||
|
||||
Reference in New Issue
Block a user