diff --git a/pyproject.toml b/pyproject.toml index 07d46f0ac..64a6de30e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,6 +167,7 @@ fo = "fo" nd = "nd" eles = "eles" datas = "datas" +ser = "ser" ure = "ure" [tool.uv] diff --git a/tests/entrypoints/openai/cpu/__init__.py b/tests/entrypoints/openai/cpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/entrypoints/openai/cpu/test_render.py b/tests/entrypoints/openai/cpu/test_render.py index 11389a2e4..7aacf4564 100644 --- a/tests/entrypoints/openai/cpu/test_render.py +++ b/tests/entrypoints/openai/cpu/test_render.py @@ -7,7 +7,7 @@ import httpx import pytest import pytest_asyncio -from tests.utils import RemoteOpenAIServer +from tests.utils import RemoteLaunchRenderServer MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" @@ -16,7 +16,7 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" def server(): args: list[str] = [] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + with RemoteLaunchRenderServer(MODEL_NAME, args) as remote_server: yield remote_server @@ -43,23 +43,20 @@ async def test_completion_render_basic(client): assert response.status_code == 200 data = response.json() - # Verify response structure + # Verify response structure - list of GenerateRequest assert isinstance(data, list) assert len(data) > 0 - # Verify first prompt + # Verify first prompt is a GenerateRequest first_prompt = data[0] - assert "prompt_token_ids" in first_prompt - assert "prompt" in first_prompt - assert isinstance(first_prompt["prompt_token_ids"], list) - assert len(first_prompt["prompt_token_ids"]) > 0 - assert isinstance(first_prompt["prompt"], str) - - # Verify prompt text is preserved - assert ( - "When should a chat-completions handler return an empty string?" - in first_prompt["prompt"] - ) + assert "token_ids" in first_prompt + assert "sampling_params" in first_prompt + assert "model" in first_prompt + assert "request_id" in first_prompt + assert isinstance(first_prompt["token_ids"], list) + assert len(first_prompt["token_ids"]) > 0 + assert first_prompt["model"] == MODEL_NAME + assert first_prompt["request_id"].startswith("cmpl-") @pytest.mark.asyncio @@ -84,36 +81,15 @@ async def test_chat_completion_render_basic(client): assert response.status_code == 200 data = response.json() - # Verify response structure - should be [conversation, engine_prompts] - assert isinstance(data, list) - assert len(data) == 2 + # Verify response structure - should be a GenerateRequest + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 - conversation, engine_prompts = data - - # Verify conversation - assert isinstance(conversation, list) - assert len(conversation) > 0 - assert conversation[0]["role"] == "user" - assert "empty string" in conversation[0]["content"] - - # Verify engine_prompts - assert isinstance(engine_prompts, list) - assert len(engine_prompts) > 0 - - first_prompt = engine_prompts[0] - assert "prompt_token_ids" in first_prompt - assert "prompt" in first_prompt - assert isinstance(first_prompt["prompt_token_ids"], list) - assert len(first_prompt["prompt_token_ids"]) > 0 - - # Verify chat template was applied (should have instruction markers) - assert "[INST]" in first_prompt["prompt"] - assert "[/INST]" in first_prompt["prompt"] - - # Verify token IDs are correctly preserved as integers - token_ids = first_prompt["prompt_token_ids"] + # Verify token IDs are integers and BOS token is present + token_ids = data["token_ids"] assert all(isinstance(tid, int) for tid in token_ids) - # Verify BOS token (usually 1 for LLaMA models) assert token_ids[0] == 1 @@ -131,15 +107,18 @@ async def test_completion_render_multiple_prompts(client): assert response.status_code == 200 data = response.json() - # Should return two prompts + # Should return two GenerateRequest items assert isinstance(data, list) assert len(data) == 2 - # Verify both prompts have required fields + # Verify both prompts have GenerateRequest fields for prompt in data: - assert "prompt_token_ids" in prompt - assert "prompt" in prompt - assert len(prompt["prompt_token_ids"]) > 0 + assert "token_ids" in prompt + assert "sampling_params" in prompt + assert "model" in prompt + assert "request_id" in prompt + assert len(prompt["token_ids"]) > 0 + assert prompt["request_id"].startswith("cmpl-") @pytest.mark.asyncio @@ -160,17 +139,49 @@ async def test_chat_completion_render_multi_turn(client): assert response.status_code == 200 data = response.json() - conversation, engine_prompts = data - - # Verify all messages preserved - assert len(conversation) == 3 - assert conversation[0]["role"] == "user" - assert conversation[1]["role"] == "assistant" - assert conversation[2]["role"] == "user" - # Verify tokenization occurred - assert len(engine_prompts) > 0 - assert len(engine_prompts[0]["prompt_token_ids"]) > 0 + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 + + +@pytest.mark.asyncio +async def test_chat_completion_render_with_stream_true(client): + """Render accepts stream params but still returns JSON (non-streamed).""" + + response = await client.post( + "/v1/chat/completions/render", + json={ + "model": MODEL_NAME, + "stream": True, + "stream_options": { + "include_usage": True, + "continuous_usage_stats": True, + }, + "messages": [ + { + "role": "user", + "content": "Stream options should be accepted by /render.", + } + ], + }, + ) + + assert response.status_code == 200 + assert response.headers.get("content-type", "").startswith("application/json") + + data = response.json() + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 + + # /render should preserve stream fields on the returned token-in request. + assert data.get("stream") is True + assert isinstance(data.get("stream_options"), dict) + assert data["stream_options"].get("include_usage") is True + assert data["stream_options"].get("continuous_usage_stats") is True @pytest.mark.asyncio @@ -224,3 +235,31 @@ async def test_completion_render_no_generation(client): assert response.status_code == 200 # Render should be fast (< 1 second) since no generation assert elapsed < 1.0 + + +@pytest.mark.asyncio +async def test_chat_completion_render_with_sampling_params(client): + """Verify sampling params are correctly returned by /render.""" + response = await client.post( + "/v1/chat/completions/render", + json={ + "model": MODEL_NAME, + "messages": [{"role": "user", "content": "Test sampling params"}], + "temperature": 0.123, + "top_p": 0.456, + "frequency_penalty": 1.1, + }, + ) + + assert response.status_code == 200 + data = response.json() + + assert "sampling_params" in data + sampling_params = data["sampling_params"] + + assert sampling_params.get("temperature") == 0.123 + assert sampling_params.get("top_p") == 0.456 + assert sampling_params.get("frequency_penalty") == 1.1 + + # Check that internal fields are not present + assert "_all_stop_token_ids" not in sampling_params diff --git a/tests/entrypoints/openai/cpu/test_render_multimodal.py b/tests/entrypoints/openai/cpu/test_render_multimodal.py new file mode 100644 index 000000000..459a965c0 --- /dev/null +++ b/tests/entrypoints/openai/cpu/test_render_multimodal.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Multimodal tests for the /render endpoints that expose prompt preprocessing.""" + +import httpx +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from vllm.multimodal.utils import encode_image_url + +VISION_MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct" + + +@pytest.fixture(scope="module") +def vision_server(): + """Vision-capable server used for multimodal /render tests.""" + + args = [ + "--enforce-eager", + "--max-model-len", + "100", + "--max-num-seqs", + "1", + "--limit-mm-per-prompt.image", + "1", + "--limit-mm-per-prompt.video", + "0", + ] + + env_overrides: dict[str, str] = {} + + with RemoteOpenAIServer( + VISION_MODEL_NAME, + args, + env_dict=env_overrides, + ) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def vision_client(vision_server): + async with httpx.AsyncClient( + base_url=vision_server.url_for(""), timeout=60.0 + ) as http_client: + yield http_client + + +@pytest.mark.asyncio +async def test_chat_completion_render_with_base64_image_url( + vision_client, + local_asset_server, +): + """Render a multimodal chat request and verify tokens are returned.""" + + image = local_asset_server.get_image_asset("RGBA_comp.png") + data_url = encode_image_url(image, format="PNG") + + assert data_url.startswith("data:image/") + assert ";base64," in data_url + + response = await vision_client.post( + "/v1/chat/completions/render", + json={ + "model": VISION_MODEL_NAME, + "messages": [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ], + }, + ) + + assert response.status_code == 200 + + data = response.json() + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 + + # Verify multimodal features are populated + assert "features" in data + features = data["features"] + assert features is not None + + # mm_hashes: should have an "image" key with a list of hash strings + assert "mm_hashes" in features + assert "image" in features["mm_hashes"] + image_hashes = features["mm_hashes"]["image"] + assert isinstance(image_hashes, list) + assert len(image_hashes) > 0 + assert all(isinstance(h, str) for h in image_hashes) + + # mm_placeholders: should have an "image" key with offset/length dicts + assert "mm_placeholders" in features + assert "image" in features["mm_placeholders"] + image_placeholders = features["mm_placeholders"]["image"] + assert isinstance(image_placeholders, list) + assert len(image_placeholders) > 0 + for p in image_placeholders: + assert "offset" in p + assert "length" in p + assert isinstance(p["offset"], int) + assert isinstance(p["length"], int) + assert p["length"] > 0 + + +@pytest.mark.asyncio +async def test_tokenize_matches_render_for_multimodal_input( + vision_client, + local_asset_server, +): + """`/tokenize` should match `/v1/chat/completions/render` token output.""" + + image = local_asset_server.get_image_asset("RGBA_comp.png") + data_url = encode_image_url(image, format="PNG") + + messages = [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": data_url}}, + {"type": "text", "text": "What's in this image?"}, + ], + } + ] + + render_response = await vision_client.post( + "/v1/chat/completions/render", + json={ + "model": VISION_MODEL_NAME, + "messages": messages, + }, + ) + assert render_response.status_code == 200 + render_data = render_response.json() + + tokenize_response = await vision_client.post( + "/tokenize", + json={ + "model": VISION_MODEL_NAME, + "messages": messages, + }, + ) + assert tokenize_response.status_code == 200 + tokenize_data = tokenize_response.json() + + assert tokenize_data["tokens"] == render_data["token_ids"] + assert tokenize_data["count"] == len(render_data["token_ids"]) diff --git a/tests/entrypoints/openai/test_launch_render.py b/tests/entrypoints/openai/test_launch_render.py index 069e61f84..12e95e219 100644 --- a/tests/entrypoints/openai/test_launch_render.py +++ b/tests/entrypoints/openai/test_launch_render.py @@ -42,21 +42,12 @@ async def test_chat_render_basic(client): assert response.status_code == 200 data = response.json() - assert isinstance(data, list) - assert len(data) == 2 - - conversation, engine_prompts = data - - assert isinstance(conversation, list) - assert conversation[0]["role"] == "user" - - assert isinstance(engine_prompts, list) - assert len(engine_prompts) > 0 - first_prompt = engine_prompts[0] - assert "prompt_token_ids" in first_prompt - assert "prompt" in first_prompt - assert isinstance(first_prompt["prompt_token_ids"], list) - assert all(isinstance(t, int) for t in first_prompt["prompt_token_ids"]) + # Response should be a GenerateRequest dict + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 + assert all(isinstance(t, int) for t in data["token_ids"]) @pytest.mark.asyncio @@ -74,14 +65,12 @@ async def test_chat_render_multi_turn(client): ) assert response.status_code == 200 - conversation, engine_prompts = response.json() + data = response.json() - assert len(conversation) == 3 - assert conversation[0]["role"] == "user" - assert conversation[1]["role"] == "assistant" - assert conversation[2]["role"] == "user" - assert len(engine_prompts) > 0 - assert len(engine_prompts[0]["prompt_token_ids"]) > 0 + assert isinstance(data, dict) + assert "token_ids" in data + assert isinstance(data["token_ids"], list) + assert len(data["token_ids"]) > 0 @pytest.mark.asyncio @@ -118,11 +107,13 @@ async def test_completion_render_basic(client): assert len(data) > 0 first_prompt = data[0] - assert "prompt_token_ids" in first_prompt - assert "prompt" in first_prompt - assert isinstance(first_prompt["prompt_token_ids"], list) - assert len(first_prompt["prompt_token_ids"]) > 0 - assert "Once upon a time" in first_prompt["prompt"] + assert "token_ids" in first_prompt + assert "sampling_params" in first_prompt + assert "model" in first_prompt + assert "request_id" in first_prompt + assert isinstance(first_prompt["token_ids"], list) + assert len(first_prompt["token_ids"]) > 0 + assert first_prompt["request_id"].startswith("cmpl-") @pytest.mark.asyncio @@ -142,9 +133,12 @@ async def test_completion_render_multiple_prompts(client): assert len(data) == 2 for prompt in data: - assert "prompt_token_ids" in prompt - assert "prompt" in prompt - assert len(prompt["prompt_token_ids"]) > 0 + assert "token_ids" in prompt + assert "sampling_params" in prompt + assert "model" in prompt + assert "request_id" in prompt + assert len(prompt["token_ids"]) > 0 + assert prompt["request_id"].startswith("cmpl-") @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2487fe567..002ae62b8 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -368,6 +368,7 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + default_chat_template_kwargs=args.default_chat_template_kwargs, trust_request_chat_template=args.trust_request_chat_template, ) @@ -457,6 +458,9 @@ async def init_render_app_state( state.openai_serving_models = model_registry + # Expose tokenization via the render handler (no engine required). + state.openai_serving_tokenization = state.openai_serving_render + state.vllm_config = vllm_config # Disable stats logging — there is no engine to poll. state.log_stats = False diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index 02dad6c1f..8f6cdb3e6 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -17,7 +17,6 @@ from pydantic import ( from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from vllm.utils.import_utils import resolve_obj_by_qualname @@ -269,53 +268,3 @@ class GenerationError(Exception): def __init__(self, message: str = "Internal server error"): super().__init__(message) self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR - - -####### Tokens IN <> Tokens OUT ####### -class GenerateRequest(BaseModel): - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." - ), - ) - token_ids: list[int] - """The token ids to generate text from.""" - - # features: MultiModalFeatureSpec - # TODO (NickLucche): implement once Renderer work is completed - features: str | None = None - """The processed MM inputs for the model.""" - - sampling_params: SamplingParams - """The sampling parameters for the model.""" - - model: str | None = None - - stream: bool | None = False - stream_options: StreamOptions | None = None - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." - ), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) diff --git a/vllm/entrypoints/openai/server_utils.py b/vllm/entrypoints/openai/server_utils.py index b21126472..1453d8083 100644 --- a/vllm/entrypoints/openai/server_utils.py +++ b/vllm/entrypoints/openai/server_utils.py @@ -11,7 +11,7 @@ from contextlib import asynccontextmanager from http import HTTPStatus import pydantic -from fastapi import FastAPI, HTTPException, Request, Response +from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from starlette.concurrency import iterate_in_threadpool @@ -350,7 +350,8 @@ async def engine_error_handler( server=req.app.state.server, engine=req.app.state.engine_client, ) - return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + err = create_error_response(exc) + return JSONResponse(err.model_dump(), status_code=err.error.code) async def exception_handler(req: Request, exc: Exception): diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py index da13ea0cd..c4d510297 100644 --- a/vllm/entrypoints/serve/disagg/protocol.py +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -2,20 +2,55 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from vllm.config import ModelConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs -from vllm.entrypoints.openai.engine.protocol import ( - SamplingParams, - StreamOptions, -) +from vllm.entrypoints.openai.engine.protocol import StreamOptions from vllm.logprobs import Logprob from vllm.renderers import TokenizeParams +from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid - ####### Tokens IN <> Tokens OUT ####### + + +class PlaceholderRangeInfo(BaseModel): + """Serializable placeholder location for a single multi-modal item.""" + + offset: int + """Start index of the placeholder tokens in the prompt.""" + + length: int + """Number of placeholder tokens.""" + + # TODO: add ``is_embed: list[bool] | None`` once the /generate side + # consumes features — some models (e.g. Qwen-VL) use sparse + # placeholder masks that cannot be recomputed from offset+length alone. + + +class MultiModalFeatures(BaseModel): + """Lightweight multimodal metadata produced by the render step. + + Carries hashes (for cache lookup / identification) and placeholder + positions so the downstream ``/generate`` service knows *where* in + the token sequence each multimodal item lives. + + .. note:: Phase 1 — metadata only. + Phase 2 should add ``mm_kwargs`` (processed tensor data) using a + binary transport so the ``/generate`` side can skip re-processing. + The ``/generate`` endpoint must also be updated to inject these + features into ``ProcessorInputs`` before passing to + ``InputProcessor.process_inputs``. + """ + + mm_hashes: dict[str, list[str]] + """Per-modality item hashes, e.g. ``{"image": ["abc", "def"]}``.""" + + mm_placeholders: dict[str, list[PlaceholderRangeInfo]] + """Per-modality placeholder ranges in the token sequence.""" + + class GenerateRequest(BaseModel): request_id: str = Field( default_factory=lambda: f"{random_uuid()}", @@ -28,10 +63,15 @@ class GenerateRequest(BaseModel): token_ids: list[int] """The token ids to generate text from.""" - # features: MultiModalFeatureSpec - # TODO (NickLucche): implement once Renderer work is completed - features: str | None = None - """The processed MM inputs for the model.""" + @field_validator("token_ids") + @classmethod + def validate_token_ids(cls, v: list[int]) -> list[int]: + if any(t < 0 for t in v): + raise ValueError("token_ids must not contain negative values") + return v + + features: MultiModalFeatures | None = None + """Multimodal hashes and placeholder positions (populated for MM inputs).""" sampling_params: SamplingParams """The sampling parameters for the model.""" diff --git a/vllm/entrypoints/serve/render/api_router.py b/vllm/entrypoints/serve/render/api_router.py index dd782a97f..d8e613070 100644 --- a/vllm/entrypoints/serve/render/api_router.py +++ b/vllm/entrypoints/serve/render/api_router.py @@ -9,6 +9,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque from vllm.entrypoints.openai.completion.protocol import CompletionRequest from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.serve.disagg.protocol import GenerateRequest from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.logger import init_logger @@ -24,7 +25,7 @@ def render(request: Request) -> OpenAIServingRender | None: @router.post( "/v1/chat/completions/render", dependencies=[Depends(validate_json_request)], - response_model=list, + response_model=GenerateRequest, responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -44,13 +45,13 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re if isinstance(result, ErrorResponse): return JSONResponse(content=result.model_dump(), status_code=result.error.code) - return JSONResponse(content=result) + return JSONResponse(content=result.model_dump()) @router.post( "/v1/completions/render", dependencies=[Depends(validate_json_request)], - response_model=list, + response_model=list[GenerateRequest], responses={ HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, @@ -67,7 +68,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request): if isinstance(result, ErrorResponse): return JSONResponse(content=result.model_dump(), status_code=result.error.code) - return JSONResponse(content=result) + return JSONResponse(content=[item.model_dump() for item in result]) def attach_router(app: FastAPI) -> None: diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 0ff737824..86533447c 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -24,14 +24,29 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( parse_chat_inputs_to_harmony_messages, render_for_completion, ) -from vllm.entrypoints.utils import create_error_response +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + MultiModalFeatures, + PlaceholderRangeInfo, +) +from vllm.entrypoints.utils import ( + create_error_response, + get_max_tokens, +) from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt from vllm.logger import init_logger +from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict from vllm.parser import ParserManager from vllm.renderers import BaseRenderer, merge_kwargs -from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq +from vllm.renderers.inputs.preprocess import ( + extract_prompt_components, + extract_prompt_len, + parse_model_prompt, + prompt_to_seq, +) from vllm.tokenizers import TokenizerLike from vllm.tool_parsers import ToolParser +from vllm.utils import random_uuid from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import mt as _mt @@ -83,10 +98,18 @@ class OpenAIServingRender: self.supports_browsing = False self.supports_code_interpreter = False + self.default_sampling_params = model_config.get_diff_sampling_param() + mc = model_config + self.override_max_tokens = ( + self.default_sampling_params.get("max_tokens") + if mc.generation_config not in ("auto", "vllm") + else getattr(mc, "override_generation_config", {}).get("max_new_tokens") + ) + async def render_chat_request( self, request: ChatCompletionRequest, - ) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: + ) -> GenerateRequest | ErrorResponse: """Validate the model and preprocess a chat completion request. This is the authoritative implementation used directly by the @@ -96,7 +119,56 @@ class OpenAIServingRender: if error_check_ret is not None: logger.error("Error with model %s", error_check_ret) return error_check_ret - return await self.render_chat(request) + + if request.use_beam_search: + return self.create_error_response( + "Beam search is not supported by the render endpoint" + ) + + result = await self.render_chat(request) + if isinstance(result, ErrorResponse): + return result + + _, engine_prompts = result + + if len(engine_prompts) != 1: + return self.create_error_response( + f"Expected exactly 1 engine prompt, got {len(engine_prompts)}" + ) + + engine_prompt = engine_prompts[0] + + prompt_components = extract_prompt_components(self.model_config, engine_prompt) + token_ids = prompt_components.token_ids + if not token_ids: + return self.create_error_response("No token_ids rendered") + token_ids = list(token_ids) + + input_length = extract_prompt_len(self.model_config, engine_prompt) + max_tokens = get_max_tokens( + self.model_config.max_model_len, + request.max_completion_tokens + if request.max_completion_tokens is not None + else request.max_tokens, + input_length, + self.default_sampling_params, + self.override_max_tokens, + ) + params = request.to_sampling_params(max_tokens, self.default_sampling_params) + + request_id = f"chatcmpl-{random_uuid()}" + + return GenerateRequest( + request_id=request_id, + token_ids=token_ids, + features=self._extract_mm_features(engine_prompt), + sampling_params=params, + model=request.model, + stream=bool(request.stream), + stream_options=(request.stream_options if request.stream else None), + cache_salt=request.cache_salt, + priority=request.priority, + ) async def render_chat( self, @@ -183,7 +255,7 @@ class OpenAIServingRender: async def render_completion_request( self, request: CompletionRequest, - ) -> list[ProcessorInputs] | ErrorResponse: + ) -> list[GenerateRequest] | ErrorResponse: """Validate the model and preprocess a completion request. This is the authoritative implementation used directly by the @@ -192,7 +264,48 @@ class OpenAIServingRender: error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret - return await self.render_completion(request) + result = await self.render_completion(request) + if isinstance(result, ErrorResponse): + return result + generate_requests: list[GenerateRequest] = [] + for engine_prompt in result: + prompt_components = extract_prompt_components( + self.model_config, engine_prompt + ) + token_ids = prompt_components.token_ids + if not token_ids: + return self.create_error_response("No token_ids rendered") + token_ids = list(token_ids) + + input_length = extract_prompt_len(self.model_config, engine_prompt) + max_tokens = get_max_tokens( + self.model_config.max_model_len, + request.max_tokens, + input_length, + self.default_sampling_params, + self.override_max_tokens, + ) + params = request.to_sampling_params( + max_tokens, self.default_sampling_params + ) + + request_id = f"cmpl-{random_uuid()}" + + generate_requests.append( + GenerateRequest( + request_id=request_id, + token_ids=token_ids, + features=self._extract_mm_features(engine_prompt), + sampling_params=params, + model=request.model, + stream=bool(request.stream), + stream_options=(request.stream_options if request.stream else None), + cache_salt=request.cache_salt, + priority=request.priority, + ) + ) + + return generate_requests async def render_completion( self, @@ -223,6 +336,33 @@ class OpenAIServingRender: return engine_prompts + @staticmethod + def _extract_mm_features( + engine_prompt: ProcessorInputs, + ) -> MultiModalFeatures | None: + """Extract multimodal metadata from a rendered engine prompt. + + Returns ``None`` for text-only prompts. + """ + if engine_prompt.get("type") != "multimodal": + return None + + # At this point engine_prompt is a MultiModalInputs TypedDict. + mm_hashes: MultiModalHashes = engine_prompt["mm_hashes"] # type: ignore[typeddict-item] + raw_placeholders: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"] # type: ignore[typeddict-item] + + mm_placeholders = { + modality: [ + PlaceholderRangeInfo(offset=p.offset, length=p.length) for p in ranges + ] + for modality, ranges in raw_placeholders.items() + } + + return MultiModalFeatures( + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + def _make_request_with_harmony( self, request: ChatCompletionRequest, diff --git a/vllm/entrypoints/serve/tokenize/serving.py b/vllm/entrypoints/serve/tokenize/serving.py index 77ce2787c..233674aff 100644 --- a/vllm/entrypoints/serve/tokenize/serving.py +++ b/vllm/entrypoints/serve/tokenize/serving.py @@ -35,6 +35,7 @@ class OpenAIServingTokenization(OpenAIServing): request_logger: RequestLogger | None, chat_template: str | None, chat_template_content_format: ChatTemplateContentFormatOption, + default_chat_template_kwargs: dict[str, Any] | None = None, trust_request_chat_template: bool = False, ) -> None: super().__init__( @@ -45,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing): self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format + self.default_chat_template_kwargs = default_chat_template_kwargs or {} self.trust_request_chat_template = trust_request_chat_template async def create_tokenize( @@ -79,7 +81,7 @@ class OpenAIServingTokenization(OpenAIServing): request.messages, default_template=self.chat_template, default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, + default_template_kwargs=self.default_chat_template_kwargs, tool_dicts=tool_dicts, ) else: @@ -98,8 +100,9 @@ class OpenAIServingTokenization(OpenAIServing): lora_request=lora_request, ) - if "prompt_token_ids" in engine_prompt: - input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item] + prompt_components = self._extract_prompt_components(engine_prompt) + if prompt_components.token_ids is not None: + input_ids.extend(prompt_components.token_ids) token_strs = None if request.return_token_strs: diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index 0c03de71c..be880bec2 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence from functools import partial from inspect import isclass from types import FunctionType -from typing import Any, TypeAlias, get_type_hints +from typing import Any, ClassVar, TypeAlias, cast, get_type_hints import cloudpickle import msgspec @@ -460,6 +460,19 @@ def run_method( class PydanticMsgspecMixin: + """Make a ``msgspec.Struct`` compatible with Pydantic for both + **validation** (JSON/dict -> Struct) and **serialization** + (Struct -> JSON-safe dict). + + Subclasses may set ``__pydantic_msgspec_exclude__`` (a ``set[str]``) + to list non-underscore field names that should also be stripped from + serialized output. Fields whose names start with ``_`` are always + excluded automatically. + """ + + # Subclasses can override to exclude additional public-but-internal keys. + __pydantic_msgspec_exclude__: ClassVar[set[str]] = set() + @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler @@ -476,32 +489,62 @@ class PydanticMsgspecMixin: # Build the Pydantic typed_dict_field for each msgspec field fields = {} for name, hint in type_hints.items(): + if name not in msgspec_fields: + # Skip ClassVar and other non-struct annotations. + continue + # Skip private fields — they are excluded from serialization + # and should not appear in the generated JSON/OpenAPI schema. + if name.startswith("_"): + continue msgspec_field = msgspec_fields[name] # typed_dict_field using the handler to get the schema field_schema = handler(hint) # Add default value to the schema. + # Mark fields with defaults as not required so the generated + # JSON Schema stays consistent with ``omit_defaults=True`` + # serialization (fields at their default value may be absent). if msgspec_field.default_factory is not msgspec.NODEFAULT: wrapped_schema = core_schema.with_default_schema( schema=field_schema, default_factory=msgspec_field.default_factory, ) - fields[name] = core_schema.typed_dict_field(wrapped_schema) + fields[name] = core_schema.typed_dict_field( + wrapped_schema, required=False + ) elif msgspec_field.default is not msgspec.NODEFAULT: wrapped_schema = core_schema.with_default_schema( schema=field_schema, default=msgspec_field.default, ) - fields[name] = core_schema.typed_dict_field(wrapped_schema) + fields[name] = core_schema.typed_dict_field( + wrapped_schema, required=False + ) else: # No default, so Pydantic will treat it as required fields[name] = core_schema.typed_dict_field(field_schema) - return core_schema.no_info_after_validator_function( + typed_dict_then_convert = core_schema.no_info_after_validator_function( cls._validate_msgspec, core_schema.typed_dict_schema(fields), ) + # Build a serializer that strips private / excluded fields. + serializer = core_schema.plain_serializer_function_ser_schema( + cls._serialize_msgspec, + info_arg=False, + ) + + # Accept either an already-constructed msgspec.Struct instance or a + # JSON/dict-like payload. + return core_schema.union_schema( + [ + core_schema.is_instance_schema(source_type), + typed_dict_then_convert, + ], + serialization=serializer, + ) + @classmethod def _validate_msgspec(cls, value: Any) -> Any: """Validate and convert input to msgspec.Struct instance.""" @@ -510,3 +553,25 @@ class PydanticMsgspecMixin: if isinstance(value, dict): return cls(**value) return msgspec.convert(value, type=cls) + + @staticmethod + def _serialize_msgspec(value: Any) -> Any: + """Serialize a msgspec.Struct to a JSON-compatible dict, stripping + private (``_``-prefixed) and explicitly excluded fields. + + Uses ``msgspec.to_builtins`` which respects ``omit_defaults=True``, + so only fields that differ from their declared defaults are included. + """ + raw = msgspec.to_builtins(value) + if not isinstance(raw, dict): + return raw + + exclude: set[str] = cast( + set[str], + getattr(type(value), "__pydantic_msgspec_exclude__", set()), + ) + for key in list(raw): + if key.startswith("_") or key in exclude: + del raw[key] + + return raw