[Renderer] Move Processor out of AsyncLLM (#24138)

Signed-off-by: Yang <lymailforjob@gmail.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Yang Liu
Date: 2025-10-03 04:29:45 -07:00
Committed by: GitHub
Parent: 5f2cacdb1e
Commit: 812b7f54a8
7 changed files with 215 additions and 125 deletions
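The test changes below outline the shape of the refactor: input processing that previously ran inside AsyncLLM is now driven by the OpenAI serving layer through an awaited `_process_inputs` call, so the tests stop inspecting `mock_engine.generate.call_args` and instead look at what was passed to `_process_inputs`. A minimal sketch of that call order, inferred only from the stub signature in the diff; the real vLLM method names, signatures, and return values may differ:

    # Hypothetical illustration of the serving-side flow implied by the tests;
    # the signature and (prompt, extras) return shape mirror _fake_process_inputs below.
    async def serve_request(serving_layer, engine, request_id, engine_prompt,
                            sampling_params, lora_request=None,
                            trace_headers=None, priority=0):
        # The serving layer now runs input processing itself ...
        engine_prompt, extras = await serving_layer._process_inputs(
            request_id, engine_prompt, sampling_params,
            lora_request=lora_request, trace_headers=trace_headers,
            priority=priority)
        # ... and only the already-processed prompt reaches the engine.
        return engine.generate(engine_prompt, sampling_params, request_id)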


@@ -7,7 +7,7 @@ import asyncio
 from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Optional
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock

 import pytest
 import pytest_asyncio
@@ -230,6 +230,7 @@ class MockHFConfig:
 @dataclass
 class MockModelConfig:
     task = "generate"
+    runner_type = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
@@ -244,11 +245,33 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
+    skip_tokenizer_init = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}


+def _build_serving_chat(engine: AsyncLLM,
+                        model_config: MockModelConfig) -> OpenAIServingChat:
+    models = OpenAIServingModels(engine_client=engine,
+                                 base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=model_config)
+    serving_chat = OpenAIServingChat(engine,
+                                     model_config,
+                                     models,
+                                     response_role="assistant",
+                                     chat_template=CHAT_TEMPLATE,
+                                     chat_template_content_format="auto",
+                                     request_logger=None)
+
+    async def _fake_process_inputs(request_id, engine_prompt, sampling_params,
+                                   *, lora_request, trace_headers, priority):
+        return dict(engine_prompt), {}
+
+    serving_chat._process_inputs = AsyncMock(side_effect=_fake_process_inputs)
+    return serving_chat
+
+
 @dataclass
 class MockEngine:
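The new `_build_serving_chat` helper removes the construction boilerplate that every test repeated and, more importantly, replaces the serving object's `_process_inputs` with an `AsyncMock` whose `side_effect` is an async stub, so no real tokenization or multimodal processing runs while every await is still recorded. A small self-contained example of that `unittest.mock` pattern (standard library behavior, not vLLM-specific):

    import asyncio
    from unittest.mock import AsyncMock

    async def fake_add(a, b):
        # Runs in place of the real coroutine; its return value is what the await yields.
        return a + b

    mock = AsyncMock(side_effect=fake_add)

    async def main():
        assert await mock(1, 2) == 3
        # Each await is recorded, so its arguments can be inspected afterwards.
        assert mock.await_args_list[0].args == (1, 2)

    asyncio.run(main())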
@@ -282,16 +305,7 @@ async def test_serving_chat_returns_correct_model_name():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())
     messages = [{"role": "user", "content": "what is 1+1?"}]

     async def return_model_name(*args):
@@ -318,16 +332,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=MockModelConfig())
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     MockModelConfig(),
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, MockModelConfig())

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -361,16 +366,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test Case 1: No max_tokens specified in request
     req = ChatCompletionRequest(
@@ -415,16 +411,7 @@ async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test case 1: No max_tokens specified, defaults to context_window
     req = ChatCompletionRequest(
@@ -471,16 +458,7 @@ async def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -525,17 +503,7 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

     # Initialize the serving chat
-    models = OpenAIServingModels(engine_client=mock_engine,
-                                 base_model_paths=BASE_MODEL_PATHS,
-                                 model_config=mock_model_config)
-    serving_chat = OpenAIServingChat(mock_engine,
-                                     mock_model_config,
-                                     models,
-                                     response_role="assistant",
-                                     chat_template=CHAT_TEMPLATE,
-                                     chat_template_content_format="auto",
-                                     request_logger=None)
+    serving_chat = _build_serving_chat(mock_engine, mock_model_config)

     # Test cache_salt
     req = ChatCompletionRequest(
@@ -549,10 +517,12 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     # By default, cache_salt in the engine prompt is not set
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert "cache_salt" not in mock_engine.generate.call_args.args[0]
+    engine_prompt = serving_chat._process_inputs.await_args_list[0].args[1]
+    assert "cache_salt" not in engine_prompt

     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
         await serving_chat.create_chat_completion(req)
-    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+    engine_prompt = serving_chat._process_inputs.await_args_list[1].args[1]
+    assert engine_prompt.get("cache_salt") == "test_salt"
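With the processor no longer invoked inside AsyncLLM, the cache_salt checks read the engine prompt from the recorded awaits of the stubbed `_process_inputs` rather than from `mock_engine.generate.call_args`: `await_args_list[i].args[1]` is the second positional argument (the engine prompt) of the i-th await. A minimal illustration of that indexing, using plain `unittest.mock` and made-up request data:

    import asyncio
    from unittest.mock import AsyncMock

    async def main():
        process_inputs = AsyncMock(return_value=({}, {}))

        # Two awaits, mirroring the two create_chat_completion calls in the test.
        await process_inputs("req-0", {"prompt_token_ids": [1, 2]}, None)
        await process_inputs("req-1",
                             {"prompt_token_ids": [1, 2], "cache_salt": "test_salt"},
                             None)

        # args[1] is the engine prompt of each recorded await.
        assert "cache_salt" not in process_inputs.await_args_list[0].args[1]
        assert process_inputs.await_args_list[1].args[1]["cache_salt"] == "test_salt"

    asyncio.run(main())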