[Bug] Fix Failure in /v1/chat/completions/render for Multimodal Requests (https://github.com/vllm-project/vllm/issues/35665) (#35684)
This commit is contained in:
@@ -167,6 +167,7 @@ fo = "fo"
|
||||
nd = "nd"
|
||||
eles = "eles"
|
||||
datas = "datas"
|
||||
ser = "ser"
|
||||
ure = "ure"
|
||||
|
||||
[tool.uv]
|
||||
|
||||
0
tests/entrypoints/openai/cpu/__init__.py
Normal file
0
tests/entrypoints/openai/cpu/__init__.py
Normal file
@@ -7,7 +7,7 @@ import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.utils import RemoteLaunchRenderServer
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
@@ -16,7 +16,7 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
def server():
|
||||
args: list[str] = []
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
with RemoteLaunchRenderServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -43,23 +43,20 @@ async def test_completion_render_basic(client):
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify response structure
|
||||
# Verify response structure - list of GenerateRequest
|
||||
assert isinstance(data, list)
|
||||
assert len(data) > 0
|
||||
|
||||
# Verify first prompt
|
||||
# Verify first prompt is a GenerateRequest
|
||||
first_prompt = data[0]
|
||||
assert "prompt_token_ids" in first_prompt
|
||||
assert "prompt" in first_prompt
|
||||
assert isinstance(first_prompt["prompt_token_ids"], list)
|
||||
assert len(first_prompt["prompt_token_ids"]) > 0
|
||||
assert isinstance(first_prompt["prompt"], str)
|
||||
|
||||
# Verify prompt text is preserved
|
||||
assert (
|
||||
"When should a chat-completions handler return an empty string?"
|
||||
in first_prompt["prompt"]
|
||||
)
|
||||
assert "token_ids" in first_prompt
|
||||
assert "sampling_params" in first_prompt
|
||||
assert "model" in first_prompt
|
||||
assert "request_id" in first_prompt
|
||||
assert isinstance(first_prompt["token_ids"], list)
|
||||
assert len(first_prompt["token_ids"]) > 0
|
||||
assert first_prompt["model"] == MODEL_NAME
|
||||
assert first_prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -84,36 +81,15 @@ async def test_chat_completion_render_basic(client):
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Verify response structure - should be [conversation, engine_prompts]
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
# Verify response structure - should be a GenerateRequest
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
|
||||
conversation, engine_prompts = data
|
||||
|
||||
# Verify conversation
|
||||
assert isinstance(conversation, list)
|
||||
assert len(conversation) > 0
|
||||
assert conversation[0]["role"] == "user"
|
||||
assert "empty string" in conversation[0]["content"]
|
||||
|
||||
# Verify engine_prompts
|
||||
assert isinstance(engine_prompts, list)
|
||||
assert len(engine_prompts) > 0
|
||||
|
||||
first_prompt = engine_prompts[0]
|
||||
assert "prompt_token_ids" in first_prompt
|
||||
assert "prompt" in first_prompt
|
||||
assert isinstance(first_prompt["prompt_token_ids"], list)
|
||||
assert len(first_prompt["prompt_token_ids"]) > 0
|
||||
|
||||
# Verify chat template was applied (should have instruction markers)
|
||||
assert "[INST]" in first_prompt["prompt"]
|
||||
assert "[/INST]" in first_prompt["prompt"]
|
||||
|
||||
# Verify token IDs are correctly preserved as integers
|
||||
token_ids = first_prompt["prompt_token_ids"]
|
||||
# Verify token IDs are integers and BOS token is present
|
||||
token_ids = data["token_ids"]
|
||||
assert all(isinstance(tid, int) for tid in token_ids)
|
||||
# Verify BOS token (usually 1 for LLaMA models)
|
||||
assert token_ids[0] == 1
|
||||
|
||||
|
||||
@@ -131,15 +107,18 @@ async def test_completion_render_multiple_prompts(client):
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Should return two prompts
|
||||
# Should return two GenerateRequest items
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
# Verify both prompts have required fields
|
||||
# Verify both prompts have GenerateRequest fields
|
||||
for prompt in data:
|
||||
assert "prompt_token_ids" in prompt
|
||||
assert "prompt" in prompt
|
||||
assert len(prompt["prompt_token_ids"]) > 0
|
||||
assert "token_ids" in prompt
|
||||
assert "sampling_params" in prompt
|
||||
assert "model" in prompt
|
||||
assert "request_id" in prompt
|
||||
assert len(prompt["token_ids"]) > 0
|
||||
assert prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -160,17 +139,49 @@ async def test_chat_completion_render_multi_turn(client):
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
conversation, engine_prompts = data
|
||||
|
||||
# Verify all messages preserved
|
||||
assert len(conversation) == 3
|
||||
assert conversation[0]["role"] == "user"
|
||||
assert conversation[1]["role"] == "assistant"
|
||||
assert conversation[2]["role"] == "user"
|
||||
|
||||
# Verify tokenization occurred
|
||||
assert len(engine_prompts) > 0
|
||||
assert len(engine_prompts[0]["prompt_token_ids"]) > 0
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completion_render_with_stream_true(client):
    """Check that /render tolerates streaming flags yet replies with JSON.

    The returned token-in request is expected to echo back the ``stream``
    and ``stream_options`` values that the caller supplied.
    """
    payload = {
        "model": MODEL_NAME,
        "stream": True,
        "stream_options": {
            "include_usage": True,
            "continuous_usage_stats": True,
        },
        "messages": [
            {
                "role": "user",
                "content": "Stream options should be accepted by /render.",
            }
        ],
    }

    response = await client.post("/v1/chat/completions/render", json=payload)

    assert response.status_code == 200
    content_type = response.headers.get("content-type", "")
    assert content_type.startswith("application/json")

    rendered = response.json()
    assert isinstance(rendered, dict)
    assert "token_ids" in rendered
    assert isinstance(rendered["token_ids"], list)
    assert len(rendered["token_ids"]) > 0

    # The streaming fields must survive the trip through /render.
    assert rendered.get("stream") is True
    assert isinstance(rendered.get("stream_options"), dict)
    echoed_options = rendered["stream_options"]
    for flag in ("include_usage", "continuous_usage_stats"):
        assert echoed_options.get(flag) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -224,3 +235,31 @@ async def test_completion_render_no_generation(client):
|
||||
assert response.status_code == 200
|
||||
# Render should be fast (< 1 second) since no generation
|
||||
assert elapsed < 1.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completion_render_with_sampling_params(client):
    """Ensure /render reports the sampling parameters sent by the caller."""
    requested = {
        "temperature": 0.123,
        "top_p": 0.456,
        "frequency_penalty": 1.1,
    }
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": "Test sampling params"}],
        **requested,
    }

    response = await client.post("/v1/chat/completions/render", json=payload)
    assert response.status_code == 200

    rendered = response.json()
    assert "sampling_params" in rendered
    sampling_params = rendered["sampling_params"]

    # Each requested value must come back exactly as it was submitted.
    for name, expected in requested.items():
        assert sampling_params.get(name) == expected

    # Internal bookkeeping fields must not appear in the serialized output.
    assert "_all_stop_token_ids" not in sampling_params
|
||||
|
||||
155
tests/entrypoints/openai/cpu/test_render_multimodal.py
Normal file
155
tests/entrypoints/openai/cpu/test_render_multimodal.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
"""Multimodal tests for the /render endpoints that expose prompt preprocessing."""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.multimodal.utils import encode_image_url
|
||||
|
||||
VISION_MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def vision_server():
    """Start one vision-capable server shared by the tests in this module."""
    valued_options = (
        ("--max-model-len", "100"),
        ("--max-num-seqs", "1"),
        ("--limit-mm-per-prompt.image", "1"),
        ("--limit-mm-per-prompt.video", "0"),
    )
    cli_args = ["--enforce-eager"]
    for flag, value in valued_options:
        cli_args += [flag, value]

    # No environment overrides are needed; an empty mapping is passed as-is.
    no_env_overrides: dict[str, str] = {}

    with RemoteOpenAIServer(
        VISION_MODEL_NAME,
        cli_args,
        env_dict=no_env_overrides,
    ) as server:
        yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def vision_client(vision_server):
    """Yield an async HTTP client rooted at the vision server's base URL."""
    root_url = vision_server.url_for("")
    async with httpx.AsyncClient(base_url=root_url, timeout=60.0) as session:
        yield session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_completion_render_with_base64_image_url(
    vision_client,
    local_asset_server,
):
    """Render a chat request carrying an inline image and inspect the result.

    Besides the token ids, the response must expose ``features`` holding
    per-modality hashes and placeholder ranges for the image.
    """
    image = local_asset_server.get_image_asset("RGBA_comp.png")
    data_url = encode_image_url(image, format="PNG")

    # Sanity-check that the helper produced a base64 image data URL.
    assert data_url.startswith("data:image/")
    assert ";base64," in data_url

    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": data_url}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }
    response = await vision_client.post(
        "/v1/chat/completions/render",
        json={"model": VISION_MODEL_NAME, "messages": [user_message]},
    )
    assert response.status_code == 200

    rendered = response.json()
    assert isinstance(rendered, dict)
    assert "token_ids" in rendered
    assert isinstance(rendered["token_ids"], list)
    assert len(rendered["token_ids"]) > 0

    # Multimodal metadata has to be present for an image request.
    assert "features" in rendered
    features = rendered["features"]
    assert features is not None

    # Hashes: an "image" entry holding a non-empty list of strings.
    assert "mm_hashes" in features
    assert "image" in features["mm_hashes"]
    image_hashes = features["mm_hashes"]["image"]
    assert isinstance(image_hashes, list)
    assert len(image_hashes) > 0
    assert all(isinstance(digest, str) for digest in image_hashes)

    # Placeholders: an "image" entry holding offset/length dictionaries.
    assert "mm_placeholders" in features
    assert "image" in features["mm_placeholders"]
    image_placeholders = features["mm_placeholders"]["image"]
    assert isinstance(image_placeholders, list)
    assert len(image_placeholders) > 0
    for placeholder in image_placeholders:
        assert "offset" in placeholder
        assert "length" in placeholder
        offset, length = placeholder["offset"], placeholder["length"]
        assert isinstance(offset, int)
        assert isinstance(length, int)
        assert length > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tokenize_matches_render_for_multimodal_input(
    vision_client,
    local_asset_server,
):
    """`/tokenize` and `/v1/chat/completions/render` must yield equal tokens."""
    image = local_asset_server.get_image_asset("RGBA_comp.png")
    data_url = encode_image_url(image, format="PNG")

    # Both endpoints receive the very same request body.
    payload = {
        "model": VISION_MODEL_NAME,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": data_url}},
                    {"type": "text", "text": "What's in this image?"},
                ],
            }
        ],
    }

    render_response = await vision_client.post(
        "/v1/chat/completions/render", json=payload
    )
    assert render_response.status_code == 200
    render_data = render_response.json()

    tokenize_response = await vision_client.post("/tokenize", json=payload)
    assert tokenize_response.status_code == 200
    tokenize_data = tokenize_response.json()

    rendered_ids = render_data["token_ids"]
    assert tokenize_data["tokens"] == rendered_ids
    assert tokenize_data["count"] == len(rendered_ids)
|
||||
@@ -42,21 +42,12 @@ async def test_chat_render_basic(client):
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 2
|
||||
|
||||
conversation, engine_prompts = data
|
||||
|
||||
assert isinstance(conversation, list)
|
||||
assert conversation[0]["role"] == "user"
|
||||
|
||||
assert isinstance(engine_prompts, list)
|
||||
assert len(engine_prompts) > 0
|
||||
first_prompt = engine_prompts[0]
|
||||
assert "prompt_token_ids" in first_prompt
|
||||
assert "prompt" in first_prompt
|
||||
assert isinstance(first_prompt["prompt_token_ids"], list)
|
||||
assert all(isinstance(t, int) for t in first_prompt["prompt_token_ids"])
|
||||
# Response should be a GenerateRequest dict
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
assert all(isinstance(t, int) for t in data["token_ids"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -74,14 +65,12 @@ async def test_chat_render_multi_turn(client):
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
conversation, engine_prompts = response.json()
|
||||
data = response.json()
|
||||
|
||||
assert len(conversation) == 3
|
||||
assert conversation[0]["role"] == "user"
|
||||
assert conversation[1]["role"] == "assistant"
|
||||
assert conversation[2]["role"] == "user"
|
||||
assert len(engine_prompts) > 0
|
||||
assert len(engine_prompts[0]["prompt_token_ids"]) > 0
|
||||
assert isinstance(data, dict)
|
||||
assert "token_ids" in data
|
||||
assert isinstance(data["token_ids"], list)
|
||||
assert len(data["token_ids"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -118,11 +107,13 @@ async def test_completion_render_basic(client):
|
||||
assert len(data) > 0
|
||||
|
||||
first_prompt = data[0]
|
||||
assert "prompt_token_ids" in first_prompt
|
||||
assert "prompt" in first_prompt
|
||||
assert isinstance(first_prompt["prompt_token_ids"], list)
|
||||
assert len(first_prompt["prompt_token_ids"]) > 0
|
||||
assert "Once upon a time" in first_prompt["prompt"]
|
||||
assert "token_ids" in first_prompt
|
||||
assert "sampling_params" in first_prompt
|
||||
assert "model" in first_prompt
|
||||
assert "request_id" in first_prompt
|
||||
assert isinstance(first_prompt["token_ids"], list)
|
||||
assert len(first_prompt["token_ids"]) > 0
|
||||
assert first_prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -142,9 +133,12 @@ async def test_completion_render_multiple_prompts(client):
|
||||
assert len(data) == 2
|
||||
|
||||
for prompt in data:
|
||||
assert "prompt_token_ids" in prompt
|
||||
assert "prompt" in prompt
|
||||
assert len(prompt["prompt_token_ids"]) > 0
|
||||
assert "token_ids" in prompt
|
||||
assert "sampling_params" in prompt
|
||||
assert "model" in prompt
|
||||
assert "request_id" in prompt
|
||||
assert len(prompt["token_ids"]) > 0
|
||||
assert prompt["request_id"].startswith("cmpl-")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -368,6 +368,7 @@ async def init_app_state(
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
default_chat_template_kwargs=args.default_chat_template_kwargs,
|
||||
trust_request_chat_template=args.trust_request_chat_template,
|
||||
)
|
||||
|
||||
@@ -457,6 +458,9 @@ async def init_render_app_state(
|
||||
|
||||
state.openai_serving_models = model_registry
|
||||
|
||||
# Expose tokenization via the render handler (no engine required).
|
||||
state.openai_serving_tokenization = state.openai_serving_render
|
||||
|
||||
state.vllm_config = vllm_config
|
||||
# Disable stats logging — there is no engine to poll.
|
||||
state.log_stats = False
|
||||
|
||||
@@ -17,7 +17,6 @@ from pydantic import (
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
|
||||
@@ -269,53 +268,3 @@ class GenerationError(Exception):
|
||||
def __init__(self, message: str = "Internal server error"):
|
||||
super().__init__(message)
|
||||
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=random_uuid,
|
||||
description=(
|
||||
"The request_id related to this request. If the caller does "
|
||||
"not set it, a random_uuid will be generated. This id is used "
|
||||
"through out the inference process and return in response."
|
||||
),
|
||||
)
|
||||
token_ids: list[int]
|
||||
"""The token ids to generate text from."""
|
||||
|
||||
# features: MultiModalFeatureSpec
|
||||
# TODO (NickLucche): implement once Renderer work is completed
|
||||
features: str | None = None
|
||||
"""The processed MM inputs for the model."""
|
||||
|
||||
sampling_params: SamplingParams
|
||||
"""The sampling parameters for the model."""
|
||||
|
||||
model: str | None = None
|
||||
|
||||
stream: bool | None = False
|
||||
stream_options: StreamOptions | None = None
|
||||
cache_salt: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, the prefix cache will be salted with the provided "
|
||||
"string to prevent an attacker to guess prompts in multi-user "
|
||||
"environments. The salt should be random, protected from "
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit)."
|
||||
),
|
||||
)
|
||||
priority: int = Field(
|
||||
default=0,
|
||||
description=(
|
||||
"The priority of the request (lower means earlier handling; "
|
||||
"default: 0). Any priority other than 0 will raise an error "
|
||||
"if the served model does not use priority scheduling."
|
||||
),
|
||||
)
|
||||
kv_transfer_params: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="KVTransfer parameters used for disaggregated serving.",
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager
|
||||
from http import HTTPStatus
|
||||
|
||||
import pydantic
|
||||
from fastapi import FastAPI, HTTPException, Request, Response
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
@@ -350,7 +350,8 @@ async def engine_error_handler(
|
||||
server=req.app.state.server,
|
||||
engine=req.app.state.engine_client,
|
||||
)
|
||||
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
err = create_error_response(exc)
|
||||
return JSONResponse(err.model_dump(), status_code=err.error.code)
|
||||
|
||||
|
||||
async def exception_handler(req: Request, exc: Exception):
|
||||
|
||||
@@ -2,20 +2,55 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
SamplingParams,
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.renderers import TokenizeParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
|
||||
####### Tokens IN <> Tokens OUT #######
|
||||
|
||||
|
||||
class PlaceholderRangeInfo(BaseModel):
    """Serializable placeholder location for a single multi-modal item.

    Both fields describe a span of the prompt's token sequence, so negative
    values are rejected at validation time. This mirrors the non-negative
    check that ``GenerateRequest`` applies to ``token_ids``.
    """

    offset: int = Field(ge=0)
    """Start index of the placeholder tokens in the prompt."""

    length: int = Field(ge=0)
    """Number of placeholder tokens."""

    # TODO: add ``is_embed: list[bool] | None`` once the /generate side
    # consumes features — some models (e.g. Qwen-VL) use sparse
    # placeholder masks that cannot be recomputed from offset+length alone.
|
||||
|
||||
|
||||
class MultiModalFeatures(BaseModel):
    """Multimodal metadata emitted by the render step.

    Holds item hashes (for cache lookup / identification) together with
    placeholder positions, so the downstream ``/generate`` service can tell
    *where* each multimodal item sits in the token sequence.

    .. note:: Phase 1 — metadata only.
        Phase 2 should add ``mm_kwargs`` (processed tensor data) over a
        binary transport, letting the ``/generate`` side skip re-processing.
        The ``/generate`` endpoint must also be changed to inject these
        features into ``ProcessorInputs`` before they are handed to
        ``InputProcessor.process_inputs``.
    """

    mm_hashes: dict[str, list[str]]
    """Item hashes keyed by modality, e.g. ``{"image": ["abc", "def"]}``."""

    mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
    """Placeholder ranges in the token sequence, keyed by modality."""
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
request_id: str = Field(
|
||||
default_factory=lambda: f"{random_uuid()}",
|
||||
@@ -28,10 +63,15 @@ class GenerateRequest(BaseModel):
|
||||
token_ids: list[int]
|
||||
"""The token ids to generate text from."""
|
||||
|
||||
# features: MultiModalFeatureSpec
|
||||
# TODO (NickLucche): implement once Renderer work is completed
|
||||
features: str | None = None
|
||||
"""The processed MM inputs for the model."""
|
||||
@field_validator("token_ids")
|
||||
@classmethod
|
||||
def validate_token_ids(cls, v: list[int]) -> list[int]:
|
||||
if any(t < 0 for t in v):
|
||||
raise ValueError("token_ids must not contain negative values")
|
||||
return v
|
||||
|
||||
features: MultiModalFeatures | None = None
|
||||
"""Multimodal hashes and placeholder positions (populated for MM inputs)."""
|
||||
|
||||
sampling_params: SamplingParams
|
||||
"""The sampling parameters for the model."""
|
||||
|
||||
@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
|
||||
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.utils import validate_json_request
|
||||
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -24,7 +25,7 @@ def render(request: Request) -> OpenAIServingRender | None:
|
||||
@router.post(
|
||||
"/v1/chat/completions/render",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
response_model=list,
|
||||
response_model=GenerateRequest,
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
@@ -44,13 +45,13 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
|
||||
if isinstance(result, ErrorResponse):
|
||||
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
return JSONResponse(content=result.model_dump())
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/completions/render",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
response_model=list,
|
||||
response_model=list[GenerateRequest],
|
||||
responses={
|
||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
||||
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
|
||||
@@ -67,7 +68,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
|
||||
if isinstance(result, ErrorResponse):
|
||||
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
|
||||
|
||||
return JSONResponse(content=result)
|
||||
return JSONResponse(content=[item.model_dump() for item in result])
|
||||
|
||||
|
||||
def attach_router(app: FastAPI) -> None:
|
||||
|
||||
@@ -24,14 +24,29 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
parse_chat_inputs_to_harmony_messages,
|
||||
render_for_completion,
|
||||
)
|
||||
from vllm.entrypoints.utils import create_error_response
|
||||
from vllm.entrypoints.serve.disagg.protocol import (
|
||||
GenerateRequest,
|
||||
MultiModalFeatures,
|
||||
PlaceholderRangeInfo,
|
||||
)
|
||||
from vllm.entrypoints.utils import (
|
||||
create_error_response,
|
||||
get_max_tokens,
|
||||
)
|
||||
from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict
|
||||
from vllm.parser import ParserManager
|
||||
from vllm.renderers import BaseRenderer, merge_kwargs
|
||||
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
|
||||
from vllm.renderers.inputs.preprocess import (
|
||||
extract_prompt_components,
|
||||
extract_prompt_len,
|
||||
parse_model_prompt,
|
||||
prompt_to_seq,
|
||||
)
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.mistral import is_mistral_tokenizer
|
||||
from vllm.utils.mistral import mt as _mt
|
||||
|
||||
@@ -83,10 +98,18 @@ class OpenAIServingRender:
|
||||
self.supports_browsing = False
|
||||
self.supports_code_interpreter = False
|
||||
|
||||
self.default_sampling_params = model_config.get_diff_sampling_param()
|
||||
mc = model_config
|
||||
self.override_max_tokens = (
|
||||
self.default_sampling_params.get("max_tokens")
|
||||
if mc.generation_config not in ("auto", "vllm")
|
||||
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
|
||||
)
|
||||
|
||||
async def render_chat_request(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
|
||||
) -> GenerateRequest | ErrorResponse:
|
||||
"""Validate the model and preprocess a chat completion request.
|
||||
|
||||
This is the authoritative implementation used directly by the
|
||||
@@ -96,7 +119,56 @@ class OpenAIServingRender:
|
||||
if error_check_ret is not None:
|
||||
logger.error("Error with model %s", error_check_ret)
|
||||
return error_check_ret
|
||||
return await self.render_chat(request)
|
||||
|
||||
if request.use_beam_search:
|
||||
return self.create_error_response(
|
||||
"Beam search is not supported by the render endpoint"
|
||||
)
|
||||
|
||||
result = await self.render_chat(request)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
|
||||
_, engine_prompts = result
|
||||
|
||||
if len(engine_prompts) != 1:
|
||||
return self.create_error_response(
|
||||
f"Expected exactly 1 engine prompt, got {len(engine_prompts)}"
|
||||
)
|
||||
|
||||
engine_prompt = engine_prompts[0]
|
||||
|
||||
prompt_components = extract_prompt_components(self.model_config, engine_prompt)
|
||||
token_ids = prompt_components.token_ids
|
||||
if not token_ids:
|
||||
return self.create_error_response("No token_ids rendered")
|
||||
token_ids = list(token_ids)
|
||||
|
||||
input_length = extract_prompt_len(self.model_config, engine_prompt)
|
||||
max_tokens = get_max_tokens(
|
||||
self.model_config.max_model_len,
|
||||
request.max_completion_tokens
|
||||
if request.max_completion_tokens is not None
|
||||
else request.max_tokens,
|
||||
input_length,
|
||||
self.default_sampling_params,
|
||||
self.override_max_tokens,
|
||||
)
|
||||
params = request.to_sampling_params(max_tokens, self.default_sampling_params)
|
||||
|
||||
request_id = f"chatcmpl-{random_uuid()}"
|
||||
|
||||
return GenerateRequest(
|
||||
request_id=request_id,
|
||||
token_ids=token_ids,
|
||||
features=self._extract_mm_features(engine_prompt),
|
||||
sampling_params=params,
|
||||
model=request.model,
|
||||
stream=bool(request.stream),
|
||||
stream_options=(request.stream_options if request.stream else None),
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
async def render_chat(
|
||||
self,
|
||||
@@ -183,7 +255,7 @@ class OpenAIServingRender:
|
||||
async def render_completion_request(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
) -> list[ProcessorInputs] | ErrorResponse:
|
||||
) -> list[GenerateRequest] | ErrorResponse:
|
||||
"""Validate the model and preprocess a completion request.
|
||||
|
||||
This is the authoritative implementation used directly by the
|
||||
@@ -192,7 +264,48 @@ class OpenAIServingRender:
|
||||
error_check_ret = await self._check_model(request)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
return await self.render_completion(request)
|
||||
result = await self.render_completion(request)
|
||||
if isinstance(result, ErrorResponse):
|
||||
return result
|
||||
generate_requests: list[GenerateRequest] = []
|
||||
for engine_prompt in result:
|
||||
prompt_components = extract_prompt_components(
|
||||
self.model_config, engine_prompt
|
||||
)
|
||||
token_ids = prompt_components.token_ids
|
||||
if not token_ids:
|
||||
return self.create_error_response("No token_ids rendered")
|
||||
token_ids = list(token_ids)
|
||||
|
||||
input_length = extract_prompt_len(self.model_config, engine_prompt)
|
||||
max_tokens = get_max_tokens(
|
||||
self.model_config.max_model_len,
|
||||
request.max_tokens,
|
||||
input_length,
|
||||
self.default_sampling_params,
|
||||
self.override_max_tokens,
|
||||
)
|
||||
params = request.to_sampling_params(
|
||||
max_tokens, self.default_sampling_params
|
||||
)
|
||||
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
|
||||
generate_requests.append(
|
||||
GenerateRequest(
|
||||
request_id=request_id,
|
||||
token_ids=token_ids,
|
||||
features=self._extract_mm_features(engine_prompt),
|
||||
sampling_params=params,
|
||||
model=request.model,
|
||||
stream=bool(request.stream),
|
||||
stream_options=(request.stream_options if request.stream else None),
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
)
|
||||
)
|
||||
|
||||
return generate_requests
|
||||
|
||||
async def render_completion(
|
||||
self,
|
||||
@@ -223,6 +336,33 @@ class OpenAIServingRender:
|
||||
|
||||
return engine_prompts
|
||||
|
||||
@staticmethod
def _extract_mm_features(
    engine_prompt: ProcessorInputs,
) -> MultiModalFeatures | None:
    """Build the multimodal metadata carried by a rendered engine prompt.

    Args:
        engine_prompt: A rendered prompt mapping. Only prompts whose
            ``"type"`` entry equals ``"multimodal"`` are inspected further.

    Returns:
        ``None`` when the prompt's ``"type"`` entry is not ``"multimodal"``.
        Otherwise a ``MultiModalFeatures`` holding the prompt's
        ``"mm_hashes"`` value unchanged, plus its ``"mm_placeholders"``
        entries with every range reduced to a ``PlaceholderRangeInfo``
        carrying just ``offset`` and ``length``.
    """
    is_multimodal = engine_prompt.get("type") == "multimodal"
    if not is_multimodal:
        return None

    # Past the guard the prompt is treated as a MultiModalInputs TypedDict;
    # the ignores silence the checker, which still sees the wider union.
    hashes: MultiModalHashes = engine_prompt["mm_hashes"]  # type: ignore[typeddict-item]
    placeholder_ranges: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"]  # type: ignore[typeddict-item]

    # Keep only the offset/length of each placeholder range, per modality.
    placeholder_info = {}
    for modality, ranges in placeholder_ranges.items():
        placeholder_info[modality] = [
            PlaceholderRangeInfo(offset=rng.offset, length=rng.length)
            for rng in ranges
        ]

    return MultiModalFeatures(
        mm_hashes=hashes,
        mm_placeholders=placeholder_info,
    )
|
||||
|
||||
def _make_request_with_harmony(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
|
||||
@@ -35,6 +35,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
default_chat_template_kwargs: dict[str, Any] | None = None,
|
||||
trust_request_chat_template: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
@@ -45,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
self.default_chat_template_kwargs = default_chat_template_kwargs or {}
|
||||
self.trust_request_chat_template = trust_request_chat_template
|
||||
|
||||
async def create_tokenize(
|
||||
@@ -79,7 +81,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request.messages,
|
||||
default_template=self.chat_template,
|
||||
default_template_content_format=self.chat_template_content_format,
|
||||
default_template_kwargs=None,
|
||||
default_template_kwargs=self.default_chat_template_kwargs,
|
||||
tool_dicts=tool_dicts,
|
||||
)
|
||||
else:
|
||||
@@ -98,8 +100,9 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if "prompt_token_ids" in engine_prompt:
|
||||
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
|
||||
prompt_components = self._extract_prompt_components(engine_prompt)
|
||||
if prompt_components.token_ids is not None:
|
||||
input_ids.extend(prompt_components.token_ids)
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
|
||||
@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from types import FunctionType
|
||||
from typing import Any, TypeAlias, get_type_hints
|
||||
from typing import Any, ClassVar, TypeAlias, cast, get_type_hints
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
@@ -460,6 +460,19 @@ def run_method(
|
||||
|
||||
|
||||
class PydanticMsgspecMixin:
|
||||
"""Make a ``msgspec.Struct`` compatible with Pydantic for both
|
||||
**validation** (JSON/dict -> Struct) and **serialization**
|
||||
(Struct -> JSON-safe dict).
|
||||
|
||||
Subclasses may set ``__pydantic_msgspec_exclude__`` (a ``set[str]``)
|
||||
to list non-underscore field names that should also be stripped from
|
||||
serialized output. Fields whose names start with ``_`` are always
|
||||
excluded automatically.
|
||||
"""
|
||||
|
||||
# Subclasses can override to exclude additional public-but-internal keys.
|
||||
__pydantic_msgspec_exclude__: ClassVar[set[str]] = set()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
@@ -476,32 +489,62 @@ class PydanticMsgspecMixin:
|
||||
# Build the Pydantic typed_dict_field for each msgspec field
|
||||
fields = {}
|
||||
for name, hint in type_hints.items():
|
||||
if name not in msgspec_fields:
|
||||
# Skip ClassVar and other non-struct annotations.
|
||||
continue
|
||||
# Skip private fields — they are excluded from serialization
|
||||
# and should not appear in the generated JSON/OpenAPI schema.
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
msgspec_field = msgspec_fields[name]
|
||||
|
||||
# typed_dict_field using the handler to get the schema
|
||||
field_schema = handler(hint)
|
||||
|
||||
# Add default value to the schema.
|
||||
# Mark fields with defaults as not required so the generated
|
||||
# JSON Schema stays consistent with ``omit_defaults=True``
|
||||
# serialization (fields at their default value may be absent).
|
||||
if msgspec_field.default_factory is not msgspec.NODEFAULT:
|
||||
wrapped_schema = core_schema.with_default_schema(
|
||||
schema=field_schema,
|
||||
default_factory=msgspec_field.default_factory,
|
||||
)
|
||||
fields[name] = core_schema.typed_dict_field(wrapped_schema)
|
||||
fields[name] = core_schema.typed_dict_field(
|
||||
wrapped_schema, required=False
|
||||
)
|
||||
elif msgspec_field.default is not msgspec.NODEFAULT:
|
||||
wrapped_schema = core_schema.with_default_schema(
|
||||
schema=field_schema,
|
||||
default=msgspec_field.default,
|
||||
)
|
||||
fields[name] = core_schema.typed_dict_field(wrapped_schema)
|
||||
fields[name] = core_schema.typed_dict_field(
|
||||
wrapped_schema, required=False
|
||||
)
|
||||
else:
|
||||
# No default, so Pydantic will treat it as required
|
||||
fields[name] = core_schema.typed_dict_field(field_schema)
|
||||
return core_schema.no_info_after_validator_function(
|
||||
typed_dict_then_convert = core_schema.no_info_after_validator_function(
|
||||
cls._validate_msgspec,
|
||||
core_schema.typed_dict_schema(fields),
|
||||
)
|
||||
|
||||
# Build a serializer that strips private / excluded fields.
|
||||
serializer = core_schema.plain_serializer_function_ser_schema(
|
||||
cls._serialize_msgspec,
|
||||
info_arg=False,
|
||||
)
|
||||
|
||||
# Accept either an already-constructed msgspec.Struct instance or a
|
||||
# JSON/dict-like payload.
|
||||
return core_schema.union_schema(
|
||||
[
|
||||
core_schema.is_instance_schema(source_type),
|
||||
typed_dict_then_convert,
|
||||
],
|
||||
serialization=serializer,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_msgspec(cls, value: Any) -> Any:
|
||||
"""Validate and convert input to msgspec.Struct instance."""
|
||||
@@ -510,3 +553,25 @@ class PydanticMsgspecMixin:
|
||||
if isinstance(value, dict):
|
||||
return cls(**value)
|
||||
return msgspec.convert(value, type=cls)
|
||||
|
||||
@staticmethod
def _serialize_msgspec(value: Any) -> Any:
    """Turn a msgspec.Struct into a JSON-compatible dict without its
    private (``_``-prefixed) or explicitly excluded keys.

    The conversion is delegated to ``msgspec.to_builtins``. When that call
    does not produce a dict, its result is returned untouched. Otherwise
    every top-level key that starts with ``_`` or that is listed in the
    value type's ``__pydantic_msgspec_exclude__`` attribute is dropped;
    nested values are not inspected.
    """
    builtins_form = msgspec.to_builtins(value)
    if not isinstance(builtins_form, dict):
        return builtins_form

    # A type without the attribute is treated as excluding nothing extra.
    excluded_names: set[str] = cast(
        set[str],
        getattr(type(value), "__pydantic_msgspec_exclude__", set()),
    )

    # Rebuild the mapping rather than deleting in place; key order is kept.
    return {
        key: item
        for key, item in builtins_form.items()
        if not (key.startswith("_") or key in excluded_names)
    }
|
||||
|
||||
Reference in New Issue
Block a user