[Frontend] Introduce Renderer for processing chat messages (using ModelConfig) (#30200)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
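This commit migrates the chat-serving tests from tokenizer-level mocking to the new Renderer abstraction. A minimal sketch of the contract the tests rely on, built only from what this diff shows (`render_messages_async` resolving to a conversation plus a `TokensPrompt`); the mock setup and the message payload are illustrative, not the real `Renderer` implementation:

```python
import asyncio
from unittest.mock import AsyncMock, MagicMock

from vllm.inputs import TokensPrompt


async def main() -> None:
    # Stand-in for an engine's renderer, mirroring the mocks in this diff.
    renderer = MagicMock()
    renderer.render_messages_async = AsyncMock(
        return_value=([], TokensPrompt(prompt_token_ids=[1, 2, 3]))
    )

    # The serving layer awaits the renderer to turn chat messages into an
    # engine prompt instead of calling the tokenizer directly.
    _conversation, engine_prompt = await renderer.render_messages_async(
        [{"role": "user", "content": "what is 1+1?"}]
    )
    assert engine_prompt["prompt_token_ids"] == [1, 2, 3]


asyncio.run(main())
```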
@@ -11,7 +11,7 @@ import pytest_asyncio
 from openai import OpenAI

 from vllm._aiter_ops import is_aiter_found_and_supported
-from vllm.config.multimodal import MultiModalConfig
+from vllm.config import MultiModalConfig
 from vllm.entrypoints.openai.chat_completion.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -23,8 +23,13 @@ from vllm.entrypoints.openai.engine.protocol import (
 )
 from vllm.entrypoints.openai.models.serving import BaseModelPath, OpenAIServingModels
 from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
+from vllm.inputs import TokensPrompt
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.renderers.hf import HfRenderer
+from vllm.renderers.mistral import MistralRenderer
 from vllm.tokenizers import get_tokenizer
+from vllm.tokenizers.mistral import MistralTokenizer
+from vllm.tokenizers.registry import tokenizer_args_from_config
 from vllm.tool_parsers import ToolParserManager
 from vllm.v1.engine.async_llm import AsyncLLM

@@ -103,15 +108,16 @@ def gptoss_server(default_server_args: list[str]):

 @pytest.fixture(scope="class")
 def gptoss_speculative_server(default_server_args: list[str]):
+    attention_backend = (
+        "TRITON_ATTN"
+        if not is_aiter_found_and_supported()
+        else "ROCM_AITER_UNIFIED_ATTN"
+    )
     server_args = default_server_args + [
         "--speculative-config",
         f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
         f'"method": "eagle3", "num_speculative_tokens": 3}}',
-        f"--attention-backend={
-            'TRITON_ATTN'
-            if not is_aiter_found_and_supported()
-            else 'ROCM_AITER_UNIFIED_ATTN'
-        }",
+        f"--attention-backend={attention_backend}",
     ]
     # gpt-oss requires AITER unified attention on ROCm
     # TODO: Remove after fixing TRITON_ATTN issue on ROCm
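The escaped braces in the fixture above can be easy to misread; this standalone snippet (with a placeholder speculator name, since `GPT_OSS_SPECULATOR_NAME` is defined outside this diff) shows what the `--speculative-config` value expands to at runtime:

```python
GPT_OSS_SPECULATOR_NAME = "example/gpt-oss-speculator"  # placeholder

# In an f-string, "{{" and "}}" are literal braces, so this builds a JSON object.
spec_config = (
    f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
    f'"method": "eagle3", "num_speculative_tokens": 3}}'
)
print(spec_config)
# {"model": "example/gpt-oss-speculator", "method": "eagle3", "num_speculative_tokens": 3}
```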
@@ -520,12 +526,21 @@ class MockModelConfig:
     encoder_config = None
     generation_config: str = "auto"
     media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
-    skip_tokenizer_init = False
+    skip_tokenizer_init: bool = False

     def get_diff_sampling_param(self):
         return self.diff_sampling_param or {}


+def _build_renderer(model_config: MockModelConfig):
+    _, tokenizer_name, _, kwargs = tokenizer_args_from_config(model_config)
+
+    return HfRenderer(
+        model_config,
+        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
+    )
+
+
 def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
     models = OpenAIServingModels(
         engine_client=engine,
@@ -561,6 +576,7 @@ class MockEngine:
     model_config: MockModelConfig = field(default_factory=MockModelConfig)
     input_processor: MagicMock = field(default_factory=MagicMock)
     io_processor: MagicMock = field(default_factory=MagicMock)
+    renderer: MagicMock = field(default_factory=MagicMock)


 async def _async_serving_chat_init():
@@ -586,11 +602,11 @@ def test_async_serving_chat_init():
 @pytest.mark.asyncio
 async def test_serving_chat_returns_correct_model_name():
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)
     messages = [{"role": "user", "content": "what is 1+1?"}]
@@ -616,11 +632,11 @@ async def test_serving_chat_returns_correct_model_name():
 @pytest.mark.asyncio
 async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)

@@ -649,11 +665,11 @@ async def test_serving_chat_should_set_correct_max_tokens():

     # Reinitialize the engine with new settings
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
     serving_chat = _build_serving_chat(mock_engine)
@@ -694,11 +710,11 @@ async def test_serving_chat_should_set_correct_max_tokens():

     # Reinitialize the engine with new settings
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
     serving_chat = _build_serving_chat(mock_engine)
@@ -732,42 +748,32 @@ async def test_serving_chat_should_set_correct_max_tokens():


 @pytest.mark.asyncio
-async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):
+async def test_serving_chat_mistral_token_ids_prompt_is_validated():
     """Regression test: when the Mistral tokenizer path returns token IDs
     directly, we must still apply input length + max_tokens validation.
     """

     mock_engine = MagicMock(spec=AsyncLLM)
     mock_engine.errored = False
-    mock_engine.model_config = MockModelConfig()
+    mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

-    class DummyMistralTokenizer:
-        def decode(self, token_ids):
-            # Only used for logging/validation error messages.
-            return "dummy"
-
-    dummy_tokenizer = DummyMistralTokenizer()
-    mock_engine.get_tokenizer.return_value = dummy_tokenizer
-
-    # Patch the OpenAI engine serving module to treat our dummy tokenizer
-    # as a MistralTokenizer. This forces the code path where chat template
-    # rendering can return a list[int] (token IDs).
-    import vllm.entrypoints.openai.engine.serving as engine_serving
-
-    monkeypatch_module.setattr(
-        engine_serving, "MistralTokenizer", DummyMistralTokenizer
-    )
-
-    serving_chat = _build_serving_chat(mock_engine)
-
+    mock_tokenizer = MagicMock(spec=MistralTokenizer)
+    mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
+    mock_renderer._tokenizer = mock_tokenizer
     # Force the Mistral chat template renderer to return token IDs.
     # Choose a prompt length that is < max_model_len, but large enough that
     # adding max_tokens should exceed the model context window.
-    serving_chat._apply_mistral_chat_template_async = AsyncMock(
-        return_value=list(range(95))
+    mock_renderer.render_messages_async = AsyncMock(
+        return_value=(
+            [],
+            TokensPrompt(prompt_token_ids=list(range(95))),
+        )
     )
+    mock_engine.renderer = mock_renderer
+
+    serving_chat = _build_serving_chat(mock_engine)

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -781,39 +787,33 @@ async def test_serving_chat_mistral_token_ids_prompt_is_validated(monkeypatch_module):


 @pytest.mark.asyncio
-async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected(
-    monkeypatch_module,
-):
+async def test_serving_chat_mistral_token_ids_prompt_too_long_is_rejected():
     """Regression test: MistralTokenizer token-id prompts must still enforce
     the max context length for the input itself (token_num >= max_model_len).
     """

     mock_engine = MagicMock(spec=AsyncLLM)
     mock_engine.errored = False
-    mock_engine.model_config = MockModelConfig()
+    mock_engine.model_config = MockModelConfig(skip_tokenizer_init=True)
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()

-    class DummyMistralTokenizer:
-        def decode(self, token_ids):
-            return "dummy"
-
-    dummy_tokenizer = DummyMistralTokenizer()
-    mock_engine.get_tokenizer.return_value = dummy_tokenizer
-
-    import vllm.entrypoints.openai.engine.serving as engine_serving
-
-    monkeypatch_module.setattr(
-        engine_serving, "MistralTokenizer", DummyMistralTokenizer
-    )
-
-    serving_chat = _build_serving_chat(mock_engine)
-
+    mock_tokenizer = MagicMock(spec=MistralTokenizer)
+    mock_renderer = MistralRenderer(mock_engine.model_config, tokenizer_kwargs={})
+    mock_renderer._tokenizer = mock_tokenizer
     # prompt_token_ids length == max_model_len should be rejected for
     # completion-like requests (ChatCompletionRequest).
-    serving_chat._apply_mistral_chat_template_async = AsyncMock(
-        return_value=list(range(mock_engine.model_config.max_model_len))
+    mock_renderer.render_messages_async = AsyncMock(
+        return_value=(
+            [],
+            TokensPrompt(
+                prompt_token_ids=list(range(mock_engine.model_config.max_model_len))
+            ),
+        )
     )
+    mock_engine.renderer = mock_renderer
+
+    serving_chat = _build_serving_chat(mock_engine)

     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -835,11 +835,11 @@ async def test_serving_chat_could_load_correct_generation_config():
     }

     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Initialize the serving chat
     serving_chat = _build_serving_chat(mock_engine)
@@ -881,11 +881,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_model_config.hf_config.model_type = model_type

     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = mock_model_config
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     serving_chat = _build_serving_chat(mock_engine)

@@ -914,11 +914,11 @@ async def test_serving_chat_data_parallel_rank_extraction():
     """Test that data_parallel_rank is properly extracted from header and
     passed to engine."""
     mock_engine = MagicMock(spec=AsyncLLM)
-    mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
     mock_engine.model_config = MockModelConfig()
     mock_engine.input_processor = MagicMock()
     mock_engine.io_processor = MagicMock()
+    mock_engine.renderer = _build_renderer(mock_engine.model_config)

     # Mock the generate method to return an async generator
     async def mock_generate(*args, **kwargs):
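The two Mistral regression tests above pin down the same length rule from both sides. A hedged recap of that arithmetic; the concrete `max_model_len` value is an assumption here (the `MockModelConfig` default is not visible in this diff), chosen to match the 95-token prompt in the first test:

```python
max_model_len = 100  # assumed context window of MockModelConfig
prompt_len = 95      # list(range(95)) in the first regression test

# The input alone must fit: prompt_len >= max_model_len is rejected
# outright, which is what the "too long" test asserts.
assert prompt_len < max_model_len

# Prompt plus completion must also fit: asking for more than
# max_model_len - prompt_len new tokens (here, more than 5) must
# trigger a validation error even though the prompt itself fits.
max_tokens = 10
assert prompt_len + max_tokens > max_model_len  # -> request rejected
```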