[Frontend] Add render endpoints for prompt preprocessing (#32473)
Signed-off-by: HyunKyun Moon <mhg5303@gmail.com>
Signed-off-by: Hyunkyun Moon <mhg5303@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
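A minimal sketch of how a client could exercise the new render endpoints once a server is running. This is not part of the commit; it assumes an OpenAI-compatible vLLM server at http://localhost:8000 serving the model named below, and relies only on the request/response shapes exercised by the tests in this change.

# Sketch: inspect how vLLM preprocesses a prompt, without generating any tokens.
# Assumes a local vLLM server at http://localhost:8000 serving MODEL (placeholder name).
import httpx

MODEL = "hmellor/tiny-random-LlamaForCausalLM"  # placeholder; use your served model name

with httpx.Client(base_url="http://localhost:8000", timeout=30.0) as client:
    # Completions render: returns a list of engine prompts.
    prompts = client.post(
        "/v1/completions/render",
        json={"model": MODEL, "prompt": "Hello world"},
    ).json()
    print(prompts[0]["prompt_token_ids"])

    # Chat render: returns [conversation, engine_prompts].
    conversation, engine_prompts = client.post(
        "/v1/chat/completions/render",
        json={"model": MODEL, "messages": [{"role": "user", "content": "Hello"}]},
    ).json()
    print(engine_prompts[0]["prompt"])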
tests/entrypoints/openai/test_render.py (new file, 226 additions)
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Tests for the /render endpoints that expose prompt preprocessing."""

import httpx
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"


@pytest.fixture(scope="module")
def server():
    args: list[str] = []

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with httpx.AsyncClient(
        base_url=server.url_for(""), timeout=30.0
    ) as http_client:
        yield http_client


@pytest.mark.asyncio
async def test_completion_render_basic(client):
    """Test basic completion render endpoint."""
    # Make request to render endpoint
    response = await client.post(
        "/v1/completions/render",
        json={
            "model": MODEL_NAME,
            "prompt": "When should a chat-completions handler return an empty string?",
        },
    )

    assert response.status_code == 200
    data = response.json()

    # Verify response structure
    assert isinstance(data, list)
    assert len(data) > 0

    # Verify first prompt
    first_prompt = data[0]
    assert "prompt_token_ids" in first_prompt
    assert "prompt" in first_prompt
    assert isinstance(first_prompt["prompt_token_ids"], list)
    assert len(first_prompt["prompt_token_ids"]) > 0
    assert isinstance(first_prompt["prompt"], str)

    # Verify prompt text is preserved
    assert (
        "When should a chat-completions handler return an empty string?"
        in first_prompt["prompt"]
    )


@pytest.mark.asyncio
async def test_chat_completion_render_basic(client):
    """Test basic chat completion render endpoint."""
    # Make request to render endpoint
    response = await client.post(
        "/v1/chat/completions/render",
        json={
            "model": MODEL_NAME,
            "messages": [
                {
                    "role": "user",
                    "content": (
                        "Returning an empty string for the prompt may be confusing."
                    ),
                }
            ],
        },
    )

    assert response.status_code == 200
    data = response.json()

    # Verify response structure - should be [conversation, engine_prompts]
    assert isinstance(data, list)
    assert len(data) == 2

    conversation, engine_prompts = data

    # Verify conversation
    assert isinstance(conversation, list)
    assert len(conversation) > 0
    assert conversation[0]["role"] == "user"
    assert "empty string" in conversation[0]["content"]

    # Verify engine_prompts
    assert isinstance(engine_prompts, list)
    assert len(engine_prompts) > 0

    first_prompt = engine_prompts[0]
    assert "prompt_token_ids" in first_prompt
    assert "prompt" in first_prompt
    assert isinstance(first_prompt["prompt_token_ids"], list)
    assert len(first_prompt["prompt_token_ids"]) > 0

    # Verify chat template was applied (should have instruction markers)
    assert "[INST]" in first_prompt["prompt"]
    assert "[/INST]" in first_prompt["prompt"]

    # Verify token IDs are correctly preserved as integers
    token_ids = first_prompt["prompt_token_ids"]
    assert all(isinstance(tid, int) for tid in token_ids)
    # Verify BOS token (usually 1 for LLaMA models)
    assert token_ids[0] == 1


@pytest.mark.asyncio
async def test_completion_render_multiple_prompts(client):
    """Test completion render with multiple prompts."""
    response = await client.post(
        "/v1/completions/render",
        json={
            "model": MODEL_NAME,
            "prompt": ["Hello world", "Goodbye world"],
        },
    )

    assert response.status_code == 200
    data = response.json()

    # Should return two prompts
    assert isinstance(data, list)
    assert len(data) == 2

    # Verify both prompts have required fields
    for prompt in data:
        assert "prompt_token_ids" in prompt
        assert "prompt" in prompt
        assert len(prompt["prompt_token_ids"]) > 0


@pytest.mark.asyncio
async def test_chat_completion_render_multi_turn(client):
    """Test chat completion render with multi-turn conversation."""
    response = await client.post(
        "/v1/chat/completions/render",
        json={
            "model": MODEL_NAME,
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there!"},
                {"role": "user", "content": "How are you?"},
            ],
        },
    )

    assert response.status_code == 200
    data = response.json()

    conversation, engine_prompts = data

    # Verify all messages preserved
    assert len(conversation) == 3
    assert conversation[0]["role"] == "user"
    assert conversation[1]["role"] == "assistant"
    assert conversation[2]["role"] == "user"

    # Verify tokenization occurred
    assert len(engine_prompts) > 0
    assert len(engine_prompts[0]["prompt_token_ids"]) > 0


@pytest.mark.asyncio
async def test_completion_render_error_invalid_model(client):
    """Test completion render with invalid model returns error."""
    response = await client.post(
        "/v1/completions/render",
        json={
            "model": "invalid-model-name",
            "prompt": "Hello",
        },
    )

    assert response.status_code == 404
    data = response.json()
    assert "error" in data


@pytest.mark.asyncio
async def test_chat_completion_render_error_invalid_model(client):
    """Test chat completion render with invalid model returns error."""
    response = await client.post(
        "/v1/chat/completions/render",
        json={
            "model": "invalid-model-name",
            "messages": [{"role": "user", "content": "Hello"}],
        },
    )

    assert response.status_code == 404
    data = response.json()
    assert "error" in data


@pytest.mark.asyncio
async def test_completion_render_no_generation(client):
    """Verify render endpoint does not generate text."""
    # This test verifies that calling render is fast (no generation)
    import time

    start = time.perf_counter()
    response = await client.post(
        "/v1/completions/render",
        json={
            "model": MODEL_NAME,
            "prompt": "Tell me a very long story about " * 10,
        },
    )
    elapsed = time.perf_counter() - start

    assert response.status_code == 200
    # Render should be fast (< 1 second) since no generation
    assert elapsed < 1.0
@@ -73,5 +73,36 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
+@router.post(
+    "/v1/chat/completions/render",
+    dependencies=[Depends(validate_json_request)],
+    response_model=list,
+    responses={
+        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
+    },
+)
+async def render_chat_completion(request: ChatCompletionRequest, raw_request: Request):
+    """Render chat completion request and return conversation and engine
+    prompts without generating."""
+    handler = chat(raw_request)
+    if handler is None:
+        base_server = raw_request.app.state.openai_serving_tokenization
+        return base_server.create_error_response(
+            message="The model does not support Chat Completions API"
+        )
+
+    try:
+        result = await handler.render_chat_request(request)
+    except Exception as e:
+        return handler.create_error_response(e)
+
+    if isinstance(result, ErrorResponse):
+        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
+
+    return JSONResponse(content=result)
+
+
 def attach_router(app: FastAPI):
     app.include_router(router)
@@ -224,17 +224,16 @@ class OpenAIServingChat(OpenAIServing):
             # Log but don't fail server startup if warmup fails
             logger.exception("Chat template warmup failed")
 
-    async def create_chat_completion(
+    async def render_chat_request(
         self,
         request: ChatCompletionRequest,
-        raw_request: Request | None = None,
-    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
+    ) -> tuple[list[ConversationMessage], list[Any]] | ErrorResponse:
         """
-        Chat Completion API similar to OpenAI's API.
+        render chat request by validating and preprocessing inputs.
 
-        See https://platform.openai.com/docs/api-reference/chat/create
-        for the API specification. This API mimics the OpenAI
-        Chat Completion API.
+        Returns:
+            A tuple of (conversation, engine_prompts) on success,
+            or an ErrorResponse on failure.
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -248,12 +247,6 @@ class OpenAIServingChat(OpenAIServing):
             raise self.engine_client.dead_error
 
         try:
-            lora_request = self._maybe_get_adapters(
-                request, supports_default_mm_loras=True
-            )
-
-            model_name = self.models.model_name(lora_request)
-
             tokenizer = await self.engine_client.get_tokenizer()
 
             tool_parser = self.tool_parser
@@ -336,7 +329,27 @@ class OpenAIServingChat(OpenAIServing):
             )
         except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(f"{e} {e.__cause__}")
+            return self.create_error_response(e)
+
+        return conversation, engine_prompts
+
+    async def create_chat_completion(
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Request | None = None,
+    ) -> AsyncGenerator[str, None] | ChatCompletionResponse | ErrorResponse:
+        """
+        Chat Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/chat/create
+        for the API specification. This API mimics the OpenAI
+        Chat Completion API.
+        """
+        result = await self.render_chat_request(request)
+        if isinstance(result, ErrorResponse):
+            return result
+
+        conversation, engine_prompts = result
 
         request_id = (
             f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
@@ -346,6 +359,18 @@ class OpenAIServingChat(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
+        try:
+            lora_request = self._maybe_get_adapters(
+                request, supports_default_mm_loras=True
+            )
+
+            model_name = self.models.model_name(lora_request)
+
+            tokenizer = await self.engine_client.get_tokenizer()
+        except (ValueError, TypeError, RuntimeError) as e:
+            logger.exception("Error preparing request components")
+            return self.create_error_response(e)
+
         # Extract data_parallel_rank from header (router can inject it)
         data_parallel_rank = self._get_data_parallel_rank(raw_request)
@@ -72,5 +72,35 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
 
+@router.post(
+    "/v1/completions/render",
+    dependencies=[Depends(validate_json_request)],
+    response_model=list,
+    responses={
+        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
+    },
+)
+async def render_completion(request: CompletionRequest, raw_request: Request):
+    """render completion request and return engine prompts without generating."""
+    handler = completion(raw_request)
+    if handler is None:
+        base_server = raw_request.app.state.openai_serving_tokenization
+        return base_server.create_error_response(
+            message="The model does not support Completions API"
+        )
+
+    try:
+        result = await handler.render_completion_request(request)
+    except Exception as e:
+        return handler.create_error_response(e)
+
+    if isinstance(result, ErrorResponse):
+        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
+
+    return JSONResponse(content=result)
+
+
 def attach_router(app: FastAPI):
     app.include_router(router)
@@ -83,19 +83,16 @@ class OpenAIServingCompletion(OpenAIServing):
             self.default_sampling_params,
         )
 
-    async def create_completion(
+    async def render_completion_request(
         self,
         request: CompletionRequest,
-        raw_request: Request | None = None,
-    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
-        """Completion API similar to OpenAI's API.
+    ) -> list[TokensPrompt | EmbedsPrompt] | ErrorResponse:
+        """
+        render completion request by validating and preprocessing inputs.
 
-        See https://platform.openai.com/docs/api-reference/completions/create
-        for the API specification. This API mimics the OpenAI Completion API.
-
-        NOTE: Currently we do not support the following feature:
-            - suffix (the language models we currently support do not support
-              suffix)
+        Returns:
+            A list of engine_prompts on success,
+            or an ErrorResponse on failure.
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -119,6 +116,44 @@ class OpenAIServingCompletion(OpenAIServing):
                 "prompt_logprobs is not compatible with prompt embeds."
             )
 
+        try:
+            if self.model_config.skip_tokenizer_init:
+                tokenizer = None
+            else:
+                tokenizer = await self.engine_client.get_tokenizer()
+            renderer = self._get_renderer(tokenizer)
+
+            engine_prompts = await renderer.render_prompt_and_embeds(
+                prompt_or_prompts=request.prompt,
+                prompt_embeds=request.prompt_embeds,
+                config=self._build_render_config(request),
+            )
+        except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(e)
+
+        return engine_prompts
+
+    async def create_completion(
+        self,
+        request: CompletionRequest,
+        raw_request: Request | None = None,
+    ) -> AsyncGenerator[str, None] | CompletionResponse | ErrorResponse:
+        """Completion API similar to OpenAI's API.
+
+        See https://platform.openai.com/docs/api-reference/completions/create
+        for the API specification. This API mimics the OpenAI Completion API.
+
+        NOTE: Currently we do not support the following feature:
+            - suffix (the language models we currently support do not support
+              suffix)
+        """
+        result = await self.render_completion_request(request)
+        if isinstance(result, ErrorResponse):
+            return result
+
+        engine_prompts = result
+
         request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
         created_time = int(time.time())
@@ -133,24 +168,8 @@ class OpenAIServingCompletion(OpenAIServing):
                 tokenizer = None
             else:
                 tokenizer = await self.engine_client.get_tokenizer()
-            renderer = self._get_renderer(tokenizer)
-
-            engine_prompts = await renderer.render_prompt_and_embeds(
-                prompt_or_prompts=request.prompt,
-                prompt_embeds=request.prompt_embeds,
-                config=self._build_render_config(request),
-            )
-        except ValueError as e:
-            logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(e)
-        except TypeError as e:
-            logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(e)
-        except RuntimeError as e:
-            logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(e)
-        except jinja2.TemplateError as e:
-            logger.exception("Error in preprocessing prompt inputs")
+        except (ValueError, TypeError, RuntimeError) as e:
+            logger.exception("Error preparing request components")
             return self.create_error_response(e)
 
         # Extract data_parallel_rank from header (router can inject it)