[Frontend] Introduce Renderer for processing chat messages (using ModelConfig) (#30200)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author: Cyrus Leung
Date: 2026-01-22 20:44:22 +08:00
Committed by: GitHub
Parent: 421012b63a
Commit: d117a4d1a9
48 changed files with 2141 additions and 1585 deletions
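
At a glance, the serving layer no longer calls `EngineClient.get_tokenizer()` for chat preprocessing; it uses the new `renderer` property instead. A minimal sketch of the intended call flow, assuming only the names and signatures visible in the hunks below (`RendererLike`, `render_messages_async`, `TextPrompt`/`TokensPrompt`):

    # Sketch only: the renderer turns OpenAI-style chat messages into an
    # engine prompt in one step, instead of the old parse / apply-template /
    # tokenize sequence spread across chat_utils and the serving classes.
    renderer = engine_client.renderer  # RendererLike

    conversation, engine_prompt = await renderer.render_messages_async(
        messages,                               # list[ChatCompletionMessageParam]
        chat_template_content_format="auto",
        chat_template=None,
        add_generation_prompt=True,
    )
    # engine_prompt is a TextPrompt or TokensPrompt and may already carry
    # multi_modal_data / multi_modal_uuids parsed out of the messages.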

View File

@@ -11,9 +11,9 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor
@@ -26,6 +26,10 @@ class EngineClient(ABC):
input_processor: InputProcessor
io_processor: IOProcessor | None
@property
@abstractmethod
def renderer(self) -> RendererLike: ...
@property
@abstractmethod
def is_running(self) -> bool: ...
@@ -88,11 +92,6 @@ class EngineClient(ABC):
"""
...
@abstractmethod
async def get_tokenizer(self) -> TokenizerLike:
"""Get the tokenizer"""
...
@abstractmethod
async def is_tracing_enabled(self) -> bool: ...

View File

@@ -2,22 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionContentPartImageParam,
@@ -39,7 +32,6 @@ from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict
@@ -50,24 +42,35 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import random_uuid
from vllm.utils.collection_utils import is_list_of
from vllm.utils.func_utils import supports_kw
from vllm.utils.import_utils import LazyLoader
if TYPE_CHECKING:
import torch
from vllm.tokenizers.mistral import MistralTokenizer
else:
torch = LazyLoader("torch", globals(), "torch")
logger = init_logger(__name__)
def __getattr__(name: str):
if name == "resolve_hf_chat_template":
from vllm.renderers.hf import resolve_chat_template
warnings.warn(
"`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
"`vllm.renderers.hf.resolve_chat_template`. "
"The old name will be removed in v0.16.",
DeprecationWarning,
stacklevel=2,
)
return resolve_chat_template
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
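For downstream code that still imports the helper from its old location, the migration is a one-line change; the shim above keeps the old name working (with a DeprecationWarning) until v0.16:

    # Old import path (deprecated by this PR):
    from vllm.entrypoints.chat_utils import resolve_hf_chat_template

    # New import path:
    from vllm.renderers.hf import resolve_chat_template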
class ChatTemplateResolutionError(ValueError):
"""Raised when chat template resolution fails.
@@ -320,325 +323,8 @@ class ConversationMessage(TypedDict, total=False):
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
# Used internally
_ChatTemplateContentFormat = Literal["string", "openai"]
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str | None = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def resolve_mistral_chat_template(
chat_template: str | None,
**kwargs: Any,
) -> str | None:
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
raise ValueError(
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
"for mistral tokenizer."
)
return None
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.
This is needed because `lru_cache` does not cache when an exception happens.
"""
def _try_get_processor_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
model_config: ModelConfig,
) -> str | None:
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=model_config.trust_remote_code,
)
if (
isinstance(processor, ProcessorMixin)
and hasattr(processor, "chat_template")
and (chat_template := processor.chat_template) is not None
):
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
return chat_template
except Exception:
logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
return None
def resolve_hf_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
) -> str | None:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
chat_template = _try_get_processor_chat_template(tokenizer, model_config)
if chat_template is not None:
return chat_template
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug(
"Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info_once(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug_once(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
else:
hf_chat_template = None
jinja_text = (
hf_chat_template
if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = (
"string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format
@lru_cache
def _log_chat_template_content_format(
chat_template: str | None,
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: TokenizerLike | None,
*,
model_config: ModelConfig,
) -> _ChatTemplateContentFormat:
if given_format != "auto":
return given_format
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
tokenizer,
model_config=model_config,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
# After resolving "auto"
ChatTemplateContentFormat = Literal["string", "openai"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
@@ -1593,7 +1279,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
interleave_strings: bool,
) -> list[ConversationMessage]:
role = message["role"]
@@ -1669,7 +1355,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
def parse_chat_messages(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
MultiModalDataDict | None,
@@ -1697,13 +1383,13 @@ def parse_chat_messages(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def parse_chat_messages_futures(
async def parse_chat_messages_async(
messages: list[ChatCompletionMessageParam],
model_config: ModelConfig,
content_format: _ChatTemplateContentFormat,
content_format: ChatTemplateContentFormat,
) -> tuple[
list[ConversationMessage],
Awaitable[MultiModalDataDict | None],
MultiModalDataDict | None,
MultiModalUUIDDict | None,
]:
conversation: list[ConversationMessage] = []
@@ -1725,174 +1411,7 @@ def parse_chat_messages_futures(
_postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def _resolve_chat_template_kwargs(
chat_template: str,
):
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
return template_vars
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
# Get standard parameters from HuggingFace's base tokenizer class.
# This dynamically extracts parameters from PreTrainedTokenizer's
# apply_chat_template method, ensuring compatibility with tokenizers
# that use **kwargs to receive standard parameters.
# Read signature from HF's base class - the single source of truth
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
return frozenset(
p.name
for p in base_sig.parameters.values()
if p.kind
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
)
def resolve_chat_template_kwargs(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
chat_template: str,
chat_template_kwargs: dict[str, Any],
raise_on_unexpected: bool = True,
) -> dict[str, Any]:
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template", "tokenize"}
if raise_on_unexpected and (
unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
):
raise ValueError(
"Found unexpected chat template kwargs from request: "
f"{unexpected_in_kwargs}"
)
fn_kw = {
k
for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
template_vars = _cached_resolve_chat_template_kwargs(chat_template)
# Allow standard HF parameters even if tokenizer uses **kwargs to receive them
hf_base_params = _get_hf_base_chat_template_params()
accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
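In effect, only request kwargs that the template references, that the tokenizer's apply_chat_template accepts explicitly, or that appear in HF's base signature survive the filtering. A hypothetical call (the kwarg names are illustrative):

    resolved = resolve_chat_template_kwargs(
        tokenizer,
        chat_template=template_str,
        chat_template_kwargs={"enable_thinking": True, "unused_flag": 1},
    )
    # -> {"enable_thinking": True}, assuming the template references
    # `enable_thinking`; "unused_flag" is silently dropped, while passing
    # "chat_template" or "tokenize" would raise a ValueError.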
def apply_hf_chat_template(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
conversation: list[ConversationMessage],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: ModelConfig,
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
if hf_chat_template is None:
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=False,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e
def apply_mistral_chat_template(
tokenizer: "MistralTokenizer",
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
tools: list[dict[str, Any]] | None,
**kwargs: Any,
) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None,
# and we won't use it.
resolve_mistral_chat_template(
chat_template=chat_template,
**kwargs,
)
try:
return tokenizer.apply_chat_template(
messages=messages,
tools=tools,
**kwargs,
)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):

View File

@@ -37,10 +37,6 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages,
resolve_chat_template_content_format,
)
from vllm.entrypoints.pooling.score.utils import (
ScoreContentPartParam,
@@ -786,7 +782,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None,
) -> list[TokensPrompt]:
) -> list[TextPrompt | TokensPrompt]:
"""
Generate prompt for a chat conversation. The pre-processed
prompt can then be used as input for the other LLM methods.
@@ -807,63 +803,27 @@ class LLM:
# messages is list[...]
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
tokenizer = self.get_tokenizer()
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tools,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
renderer = self.llm_engine.renderer
_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
)
_chat_template_kwargs.update(chat_template_kwargs or {})
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tools,
**(chat_template_kwargs or {}),
}
prompts: list[TokensPrompt] = []
prompts = list[TextPrompt | TokensPrompt]()
for msgs in list_of_messages:
# NOTE: _parse_chat_message_content_parts() currently doesn't
# NOTE: renderer.render_messages() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data, mm_uuids = parse_chat_messages(
_, prompt = renderer.render_messages(
msgs,
model_config,
content_format=resolved_content_format,
chat_template_content_format=chat_template_content_format,
**chat_template_kwargs,
)
if isinstance(tokenizer, MistralTokenizer):
prompt_token_ids = apply_mistral_chat_template(
tokenizer,
messages=msgs,
**_chat_template_kwargs,
)
else:
prompt_str = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
prompt_token_ids = tokenizer.encode(
prompt_str, add_special_tokens=False
)
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs

View File

@@ -34,6 +34,7 @@ import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
@@ -62,7 +63,6 @@ from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
cli_env_setup,
log_non_default_args,
process_chat_template,
process_lora_modules,
sanitize_message,
)
@@ -662,9 +662,7 @@ async def init_app_state(
supported_tasks = await engine_client.get_supported_tasks()
logger.info("Supported tasks: %s", supported_tasks)
resolved_chat_template = await process_chat_template(
args.chat_template, engine_client, vllm_config.model_config
)
resolved_chat_template = load_chat_template(args.chat_template)
if args.tool_server == "demo":
tool_server: ToolServer | None = DemoToolServer()

View File

@@ -186,8 +186,7 @@ class OpenAIServingChat(OpenAIServing):
start_time = time.perf_counter()
try:
# Get the tokenizer from the engine
tokenizer = await self.engine_client.get_tokenizer()
renderer = self.engine_client.renderer
# Create a minimal dummy request
dummy_request = ChatCompletionRequest(
@@ -203,7 +202,7 @@ class OpenAIServingChat(OpenAIServing):
# 3. Tokenizer initialization for chat
await self._preprocess_chat(
dummy_request,
tokenizer,
renderer,
dummy_request.messages,
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
@@ -247,7 +246,8 @@ class OpenAIServingChat(OpenAIServing):
raise self.engine_client.dead_error
try:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self.engine_client.renderer
tokenizer = renderer.tokenizer
tool_parser = self.tool_parser
@@ -308,7 +308,7 @@ class OpenAIServingChat(OpenAIServing):
conversation, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
@@ -365,8 +365,6 @@ class OpenAIServingChat(OpenAIServing):
)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
@@ -463,6 +461,8 @@ class OpenAIServingChat(OpenAIServing):
(result_generator,) = generators
# Streaming response
tokenizer = self.renderer.tokenizer
if request.stream:
return self.chat_completion_stream_generator(
request,
@@ -1784,7 +1784,7 @@ class OpenAIServingChat(OpenAIServing):
else:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)

View File

@@ -117,12 +117,7 @@ class OpenAIServingCompletion(OpenAIServing):
)
try:
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
@@ -163,11 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
except (ValueError, TypeError, RuntimeError) as e:
logger.exception("Error preparing request components")
return self.create_error_response(e)
@@ -280,6 +270,8 @@ class OpenAIServingCompletion(OpenAIServing):
stream = request.stream and not request.use_beam_search
# Streaming response
tokenizer = self.renderer.tokenizer
if stream:
return self.completion_stream_generator(
request,

View File

@@ -6,10 +6,9 @@ import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
@@ -26,10 +25,6 @@ from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures,
resolve_chat_template_content_format,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -113,10 +108,9 @@ from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.renderers import RendererLike
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tracing import (
contains_trace_headers,
@@ -127,10 +121,8 @@ from vllm.utils import random_uuid
from vllm.utils.async_utils import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
make_async,
merge_async_iterators,
)
from vllm.utils.collection_utils import is_list_of
from vllm.v1.engine import EngineCoreRequest
@@ -215,7 +207,6 @@ class ResponseGenerationMixin:
@dataclass(kw_only=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
# Shared across all requests
request: RequestT
raw_request: Request | None = None
model_name: str
@@ -223,9 +214,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[Requ
created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None
# Shared across most requests
tokenizer: TokenizerLike | None = None
@dataclass(kw_only=True)
class ClassificationServeContext(ServeContext[ClassificationRequest]):
@@ -261,16 +249,13 @@ class OpenAIServing:
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._apply_mistral_chat_template_async = make_async(
apply_mistral_chat_template, executor=self._tokenizer_executor
)
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
self.input_processor = self.models.input_processor
self.io_processor = self.models.io_processor
self.renderer = self.models.renderer
self.model_config = self.models.model_config
self.max_model_len = self.model_config.max_model_len
@@ -557,14 +542,14 @@ class OpenAIServing:
prompt_logprobs=None,
)
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
def _get_completion_renderer(self) -> BaseRenderer:
"""
Get a Renderer instance with the provided tokenizer.
Uses shared async tokenizer pool for efficiency.
"""
return CompletionRenderer(
model_config=self.model_config,
tokenizer=tokenizer,
tokenizer=self.renderer.tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool,
)
@@ -1183,7 +1168,7 @@ class OpenAIServing:
async def _preprocess_chat(
self,
request: ChatLikeRequest | ResponsesRequest,
tokenizer: TokenizerLike | None,
renderer: RendererLike,
messages: list[ChatCompletionMessageParam],
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
@@ -1196,59 +1181,58 @@ class OpenAIServing:
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
model_config=model_config,
)
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
messages,
model_config,
content_format=resolved_content_format,
)
_chat_template_kwargs: dict[str, Any] = dict(
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
)
_chat_template_kwargs |= self._prepare_extra_chat_template_kwargs(
chat_template_kwargs = {
"chat_template": chat_template,
"add_generation_prompt": add_generation_prompt,
"continue_final_message": continue_final_message,
"tools": tool_dicts,
"documents": documents,
**(chat_template_kwargs or {}),
}
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
chat_template_kwargs,
default_chat_template_kwargs,
)
request_prompt: str | list[int]
# Use the async tokenizer in `OpenAIServing` if possible.
# Later we can move it into the renderer so that we can return both
# text and token IDs in the same prompt from `render_messages_async`
# which is used for logging and `enable_response_messages`.
from vllm.tokenizers.mistral import MistralTokenizer
if tokenizer is None:
request_prompt = "placeholder"
elif isinstance(tokenizer, MistralTokenizer):
request_prompt = await self._apply_mistral_chat_template_async(
tokenizer,
messages=messages,
**_chat_template_kwargs,
)
elif isinstance(tokenizer, DeepseekV32Tokenizer):
request_prompt = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
model_config=model_config,
**_chat_template_kwargs,
conversation, engine_prompt = await renderer.render_messages_async(
messages,
chat_template_content_format=chat_template_content_format,
tokenize=(
chat_template_kwargs.pop("tokenize", False)
or isinstance(renderer.tokenizer, MistralTokenizer)
),
**chat_template_kwargs,
)
if "prompt_token_ids" not in engine_prompt:
extra_data = engine_prompt
engine_prompt = await self._tokenize_prompt_input_async(
request,
renderer.get_tokenizer(),
engine_prompt["prompt"],
add_special_tokens=add_special_tokens,
)
# Fill in other keys like MM data
engine_prompt.update(extra_data) # type: ignore
else:
request_prompt = apply_hf_chat_template(
tokenizer=tokenizer,
conversation=conversation,
model_config=model_config,
**_chat_template_kwargs,
self._validate_input(
request=request,
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
input_text="",
)
mm_data = await mm_data_future
engine_prompt = cast(TokensPrompt, engine_prompt)
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if (cache_salt := getattr(request, "cache_salt", None)) is not None:
engine_prompt["cache_salt"] = cache_salt
# tool parsing is done only if a tool_parser has been set and if
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
@@ -1264,49 +1248,10 @@ class OpenAIServing:
"or Responses API requests."
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore
if tokenizer is None:
assert isinstance(request_prompt, str), (
"Prompt has to be a string",
"when the tokenizer is not initialised",
)
prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
elif isinstance(request_prompt, str):
prompt_inputs = await self._tokenize_prompt_input_async(
request,
tokenizer,
request_prompt,
add_special_tokens=add_special_tokens,
)
else:
# For MistralTokenizer
assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids"
)
input_text = tokenizer.decode(request_prompt)
prompt_inputs = self._validate_input(
request=request,
input_ids=request_prompt,
input_text=input_text,
)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
if "prompt" in prompt_inputs:
engine_prompt["prompt"] = prompt_inputs["prompt"]
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return conversation, [engine_prompt]
async def _process_inputs(
@@ -1341,7 +1286,7 @@ class OpenAIServing:
async def _render_next_turn(
self,
request: ResponsesRequest,
tokenizer: TokenizerLike | None,
renderer: RendererLike,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
@@ -1354,7 +1299,7 @@ class OpenAIServing:
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
renderer,
new_messages,
tool_dicts=tool_dicts,
tool_parser=tool_parser,
@@ -1431,7 +1376,7 @@ class OpenAIServing:
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
context.tokenizer,
context.renderer,
context.parser.response_messages,
context.tool_dicts,
context.tool_parser_cls,

View File

@@ -61,6 +61,7 @@ class OpenAIServingModels:
self.input_processor = self.engine_client.input_processor
self.io_processor = self.engine_client.io_processor
self.renderer = self.engine_client.renderer
self.model_config = self.engine_client.model_config
self.max_model_len = self.model_config.max_model_len

View File

@@ -43,6 +43,7 @@ from vllm.entrypoints.openai.responses.protocol import (
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.renderers import RendererLike
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.utils import random_uuid
@@ -260,7 +261,7 @@ class ParsableContext(ConversationContext):
self,
*,
response_messages: list[ResponseInputOutputItem],
tokenizer: TokenizerLike,
renderer: RendererLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
@@ -279,6 +280,7 @@ class ParsableContext(ConversationContext):
if reasoning_parser_cls is None:
raise ValueError("reasoning_parser_cls must be provided.")
tokenizer = renderer.get_tokenizer()
self.parser = get_responses_parser_for_simple_context(
tokenizer=tokenizer,
reasoning_parser_cls=reasoning_parser_cls,
@@ -288,6 +290,7 @@ class ParsableContext(ConversationContext):
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.renderer = renderer
self.tokenizer = tokenizer
self.available_tools = available_tools or []

View File

@@ -121,6 +121,7 @@ from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.outputs import CompletionOutput
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
@@ -380,7 +381,8 @@ class OpenAIServingResponses(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self.engine_client.renderer
tokenizer = renderer.get_tokenizer()
if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony(
@@ -388,7 +390,7 @@ class OpenAIServingResponses(OpenAIServing):
)
else:
messages, engine_prompts = await self._make_request(
request, prev_response, tokenizer
request, prev_response, renderer
)
except (
@@ -454,7 +456,7 @@ class OpenAIServingResponses(OpenAIServing):
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
tokenizer=tokenizer,
renderer=renderer,
reasoning_parser_cls=self.reasoning_parser,
request=request,
tool_parser_cls=self.tool_parser,
@@ -585,7 +587,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
request: ResponsesRequest,
prev_response: ResponsesResponse | None,
tokenizer: TokenizerLike,
renderer: RendererLike,
):
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
# Construct the input messages.
@@ -607,7 +609,7 @@ class OpenAIServingResponses(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
renderer,
messages,
tool_dicts=tool_dicts,
tool_parser=self.tool_parser,
@@ -631,6 +633,7 @@ class OpenAIServingResponses(OpenAIServing):
raise NotImplementedError(
"Only 'auto' tool_choice is supported in response API with Harmony"
)
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

View File

@@ -28,21 +28,17 @@ def register_pooling_api_routers(app: FastAPI):
async def init_pooling_state(
engine_client: "EngineClient", state: "State", args: "Namespace"
):
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.utils import process_chat_template
from vllm.tasks import POOLING_TASKS
supported_tasks = await engine_client.get_supported_tasks()
vllm_config = engine_client.vllm_config
resolved_chat_template = await process_chat_template(
args.chat_template, engine_client, vllm_config.model_config
)
resolved_chat_template = load_chat_template(args.chat_template)
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)

View File

@@ -54,8 +54,6 @@ class ClassificationMixin(OpenAIServing):
"""
ctx = cast(ClassificationServeContext, ctx)
try:
ctx.tokenizer = await self.engine_client.get_tokenizer()
request_obj = ctx.request
if isinstance(request_obj, ClassificationChatRequest):
@@ -76,7 +74,7 @@ class ClassificationMixin(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.tokenizer,
self.renderer,
messages,
chat_template=(
chat_request.chat_template
@@ -104,7 +102,7 @@ class ClassificationMixin(OpenAIServing):
ctx.engine_prompts = []
return None
renderer = self._get_renderer(ctx.tokenizer)
renderer = self._get_completion_renderer()
prompt_input = cast(str | list[str], input_data)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=prompt_input,

View File

@@ -78,13 +78,10 @@ class EmbeddingMixin(OpenAIServing):
try:
ctx.lora_request = self._maybe_get_adapters(ctx.request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
tokenizer,
self.renderer,
ctx.request.messages,
chat_template=ctx.request.chat_template or ctx.chat_template,
chat_template_content_format=ctx.chat_template_content_format,
@@ -93,6 +90,7 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=ctx.request.add_special_tokens,
)
else:
renderer = self._get_completion_renderer()
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
config=self._build_render_config(ctx.request),

View File

@@ -94,12 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported"
@@ -140,7 +134,7 @@ class OpenAIServingPooling(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
self.renderer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
@@ -149,6 +143,7 @@ class OpenAIServingPooling(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
elif isinstance(request, PoolingCompletionRequest):
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
config=self._build_render_config(request),

View File

@@ -3,6 +3,7 @@
import asyncio
import time
from collections.abc import AsyncGenerator, Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from fastapi import Request
@@ -63,6 +64,8 @@ class ServingScores(OpenAIServing):
)
self.score_template = score_template
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
async def _embedding_score(
self,
tokenizer: TokenizerLike,
@@ -283,8 +286,7 @@ class ServingScores(OpenAIServing):
raw_request: Request | None = None,
) -> list[PoolingRequestOutput] | ErrorResponse:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
tokenizer = self.renderer.get_tokenizer()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)

View File

@@ -16,12 +16,12 @@ from vllm.entrypoints.chat_utils import (
MultiModalItemTracker,
_ContentPart,
_parse_chat_message_content_part,
apply_hf_chat_template,
)
from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
ScoreContentPartParam: TypeAlias = (
@@ -224,15 +224,16 @@ def get_score_prompt(
# If that fails because there is no such template,
# fall back to the default implementation.
try:
full_prompt = apply_hf_chat_template(
full_prompt = safe_apply_chat_template(
model_config,
tokenizer,
[
{"role": "query", "content": prompt_1},
{"role": "document", "content": prompt_2},
],
score_template,
chat_template=score_template,
tools=None,
model_config=model_config,
tokenize=False,
)
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
except ChatTemplateResolutionError:

View File

@@ -67,9 +67,6 @@ class OpenAIServingTokenization(OpenAIServing):
try:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
renderer = self._get_renderer(tokenizer)
if isinstance(request, TokenizeChatRequest):
tool_dicts = (
None
@@ -86,7 +83,7 @@ class OpenAIServingTokenization(OpenAIServing):
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
self.renderer,
request.messages,
tool_dicts=tool_dicts,
chat_template=request.chat_template or self.chat_template,
@@ -97,6 +94,7 @@ class OpenAIServingTokenization(OpenAIServing):
add_special_tokens=request.add_special_tokens,
)
else:
renderer = self._get_completion_renderer()
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
config=self._build_render_config(request),
@@ -116,6 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
token_strs = None
if request.return_token_strs:
tokenizer = self.renderer.get_tokenizer()
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
return TokenizeResponse(
@@ -137,8 +136,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokenize-{self._base_request_id(raw_request)}"
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer()
tokenizer = self.renderer.get_tokenizer()
self._log_inputs(
request_id,
@@ -161,7 +159,7 @@ class OpenAIServingTokenization(OpenAIServing):
) -> TokenizerInfoResponse | ErrorResponse:
"""Get comprehensive tokenizer information."""
try:
tokenizer = await self.engine_client.get_tokenizer()
tokenizer = self.renderer.get_tokenizer()
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
return TokenizerInfoResponse(**info)
except Exception as e:

View File

@@ -6,7 +6,6 @@ import dataclasses
import functools
import os
from argparse import Namespace
from pathlib import Path
from typing import TYPE_CHECKING, Any
import regex as re
@@ -14,17 +13,9 @@ from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks
from vllm.config import ModelConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
load_chat_template,
resolve_hf_chat_template,
resolve_mistral_chat_template,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
@@ -301,40 +292,6 @@ def process_lora_modules(
return lora_modules
async def process_chat_template(
args_chat_template: Path | str | None,
engine_client: EngineClient,
model_config: ModelConfig,
) -> str | None:
resolved_chat_template = load_chat_template(args_chat_template)
if resolved_chat_template is not None:
# Get the tokenizer to check official template
tokenizer = await engine_client.get_tokenizer()
if isinstance(tokenizer, MistralTokenizer):
# The warning is logged in resolve_mistral_chat_template.
resolved_chat_template = resolve_mistral_chat_template(
chat_template=resolved_chat_template
)
else:
hf_chat_template = resolve_hf_chat_template(
tokenizer=tokenizer,
chat_template=None,
tools=None,
model_config=model_config,
)
if hf_chat_template != resolved_chat_template:
logger.warning(
"Using supplied chat template: %s\n"
"It is different from official chat template '%s'. "
"This discrepancy may lead to performance degradation.",
resolved_chat_template,
model_config.model,
)
return resolved_chat_template
def sanitize_message(message: str) -> str:
# Avoid leaking memory address from object reprs
return re.sub(r" at 0x[0-9a-f]+>", ">", message)

View File

@@ -17,6 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -46,7 +47,6 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: TokenizerLike | None,
observability_config: ObservabilityConfig | None = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
@@ -54,20 +54,19 @@ class InputPreprocessor:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
self.observability_config = observability_config
self.renderer = renderer_from_config(model_config)
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init=True`"
)
@property
def tokenizer(self) -> TokenizerLike | None:
return self.renderer.tokenizer
return self.tokenizer
def get_tokenizer(self) -> TokenizerLike:
return self.renderer.get_tokenizer()
def get_bos_token_id(self) -> int | None:
if self.tokenizer is None:

View File

@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .protocol import RendererLike
from .registry import RendererRegistry, renderer_from_config
__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"]
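Outside of the OpenAI serving layer, a renderer can also be constructed directly from the model config; a minimal sketch, assuming the registry resolves to one of the concrete renderers in this PR (hf, DeepseekV32Renderer, Grok2Renderer, ...) and that the kwargs below are accepted by its chat template:

    from vllm.renderers import renderer_from_config

    renderer = renderer_from_config(model_config)

    conversation, prompt = renderer.render_messages(
        [{"role": "user", "content": "Hello!"}],
        add_generation_prompt=True,
    )
    # `prompt` is a TextPrompt or TokensPrompt ready for the engine,
    # mirroring what InputPreprocessor now does internally.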

View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from .protocol import RendererLike
logger = init_logger(__name__)
class DeepseekV32Renderer(RendererLike):
@classmethod
def from_config(
cls,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__()
self.config = config
if config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=DeepseekV32Tokenizer,
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
@property
def tokenizer(self) -> DeepseekV32Tokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> DeepseekV32Tokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]

vllm/renderers/grok2.py (new file)
View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer
from .protocol import RendererLike
logger = init_logger(__name__)
class Grok2Renderer(RendererLike):
@classmethod
def from_config(
cls,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__()
self.config = config
if config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=Grok2Tokenizer,
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
@property
def tokenizer(self) -> Grok2Tokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> Grok2Tokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.config,
content_format="string",
)
prompt_raw = tokenizer.apply_chat_template(
conversation=conversation,
messages=messages,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]

vllm/renderers/hf.py (new file)
View File

@@ -0,0 +1,600 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections import deque
from collections.abc import Set
from functools import lru_cache
from typing import Any, cast
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormat,
ChatTemplateContentFormatOption,
ChatTemplateResolutionError,
ConversationMessage,
load_chat_template,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw
from .protocol import RendererLike
logger = init_logger(__name__)
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.
This is needed because `lru_cache` does not cache when an exception happens.
"""
def _try_get_processor_chat_template(
tokenizer: HfTokenizer,
*,
trust_remote_code: bool,
) -> str | None:
cache_key = (tokenizer.name_or_path, trust_remote_code)
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
from transformers import (
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
)
try:
processor = cached_get_processor(
tokenizer.name_or_path,
processor_cls=(
PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=trust_remote_code,
)
if (
isinstance(processor, ProcessorMixin)
and hasattr(processor, "chat_template")
and (chat_template := processor.chat_template) is not None
):
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
return chat_template
except Exception:
logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
return None
def resolve_chat_template(
tokenizer: HfTokenizer,
chat_template: str | None,
tools: list[dict[str, Any]] | None,
*,
model_config: "ModelConfig",
) -> str | None:
# 1st priority: The given chat template
if chat_template is not None:
return chat_template
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
if tools is None:
chat_template = _try_get_processor_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
if chat_template is not None:
return chat_template
# 3rd priority: AutoTokenizer chat template
try:
return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception:
logger.debug(
"Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=tokenizer.name_or_path,
)
if path is not None:
logger.info_once(
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path)
else:
logger.debug_once(
"There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (
_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key
)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: str | None = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return node.node is not None and _is_var_or_elems_access(
node.node, varname, key
)
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
node.arg, jinja2.nodes.Slice
):
return _is_var_or_elems_access(node.node, varname, key)
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
import transformers.utils.chat_template_utils as hf_chat_utils
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
@lru_cache(maxsize=32)
def _detect_content_format(
chat_template: str,
*,
default: ChatTemplateContentFormat,
) -> ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def _resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
tokenizer: HfTokenizer,
*,
model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
resolved_chat_template = resolve_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
jinja_text = (
resolved_chat_template
if isinstance(resolved_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = (
"string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format
@lru_cache
def _log_chat_template_content_format(
chat_template: str | None, # For caching purposes
given_format: ChatTemplateContentFormatOption,
detected_format: ChatTemplateContentFormatOption,
):
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
def resolve_chat_template_content_format(
chat_template: str | None,
tools: list[dict[str, Any]] | None,
given_format: ChatTemplateContentFormatOption,
tokenizer: HfTokenizer,
*,
model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
if given_format != "auto":
return given_format
detected_format = _resolve_chat_template_content_format(
chat_template,
tools,
tokenizer,
model_config=model_config,
)
_log_chat_template_content_format(
chat_template,
given_format=given_format,
detected_format=detected_format,
)
return detected_format
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
lineno = next(parser.stream).lineno
body = parser.parse_statements(("name:endgeneration",), drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
return template_vars
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
from transformers import PreTrainedTokenizer
# Get standard parameters from HuggingFace's base tokenizer class.
# This dynamically extracts parameters from PreTrainedTokenizer's
# apply_chat_template method, ensuring compatibility with tokenizers
# that use **kwargs to receive standard parameters.
# Read signature from HF's base class - the single source of truth
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
return frozenset(
p.name
for p in base_sig.parameters.values()
if p.kind
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
)
def resolve_chat_template_kwargs(
tokenizer: HfTokenizer,
chat_template: str,
chat_template_kwargs: dict[str, Any],
raise_on_unexpected: bool = True,
) -> dict[str, Any]:
# We exclude chat_template from kwargs here, because
# the chat template has already been resolved at this stage
unexpected_vars = {"chat_template", "tokenize"}
if raise_on_unexpected and (
unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
):
raise ValueError(
"Found unexpected chat template kwargs from request: "
f"{unexpected_in_kwargs}"
)
fn_kw = {
k
for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
template_vars = _cached_resolve_chat_template_kwargs(chat_template)
# Allow standard HF parameters even if tokenizer uses **kwargs to receive them
hf_base_params = _get_hf_base_chat_template_params()
accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
def safe_apply_chat_template(
model_config: "ModelConfig",
tokenizer: HfTokenizer,
conversation: list[ConversationMessage],
*,
tools: list[dict[str, Any]] | None = None,
chat_template: str | None = None,
tokenize: bool = True,
**kwargs,
) -> str | list[int]:
chat_template = resolve_chat_template(
tokenizer,
chat_template=chat_template,
tools=tools,
model_config=model_config,
)
if chat_template is None:
raise ChatTemplateResolutionError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=chat_template,
chat_template_kwargs=kwargs,
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=chat_template,
tokenize=tokenize,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e
class HfRenderer(RendererLike):
@classmethod
def from_config(
cls,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__()
self.config = config
if config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cast(
HfTokenizer,
cached_get_tokenizer(
tokenizer_cls=CachedHfTokenizer, # type: ignore[type-abstract]
**tokenizer_kwargs,
),
)
self._tokenizer = tokenizer
@property
def tokenizer(self) -> HfTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> HfTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
)
prompt_raw = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
model_config,
content_format=resolve_chat_template_content_format(
chat_template=kwargs.get("chat_template"),
tools=kwargs.get("tools"),
given_format=chat_template_content_format,
tokenizer=tokenizer,
model_config=model_config,
),
)
prompt_raw = safe_apply_chat_template(
model_config,
tokenizer,
conversation,
**kwargs,
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
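For reference, a minimal usage sketch (not part of this commit) of the renderer above. The model id is an arbitrary placeholder, and constructing ModelConfig directly like this is assumed to fetch the HF config, so network access is required.

# Sketch only: build an HfRenderer by hand and render one chat turn.
from vllm.config import ModelConfig
from vllm.renderers.hf import HfRenderer

config = ModelConfig(model="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id (assumption)
renderer = HfRenderer.from_config(
    config,
    tokenizer_kwargs={"tokenizer_name": config.tokenizer},
)
conversation, prompt = renderer.render_messages(
    [{"role": "user", "content": "What does a Renderer do?"}],
    chat_template_content_format="auto",  # exercises the detection logic above
    add_generation_prompt=True,  # standard HF kwarg, kept by resolve_chat_template_kwargs
)
print(conversation)  # parsed ConversationMessage list
print(prompt)  # TextPrompt or TokensPrompt, depending on what apply_chat_template returns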

147
vllm/renderers/mistral.py Normal file
View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async
from .protocol import RendererLike
logger = init_logger(__name__)
def safe_apply_chat_template(
tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> str | list[int]:
from mistral_common.exceptions import MistralCommonException
try:
return tokenizer.apply_chat_template(messages, **kwargs)
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat template"
)
raise ValueError(str(e)) from e
class MistralRenderer(RendererLike):
@classmethod
def from_config(
cls,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config, tokenizer_kwargs)
def __init__(
self,
config: ModelConfig,
tokenizer_kwargs: dict[str, Any],
) -> None:
super().__init__()
self.config = config
if config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=MistralTokenizer,
**tokenizer_kwargs,
)
self._tokenizer = tokenizer
self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
self._apply_chat_template_async = make_async(
safe_apply_chat_template, executor=self._apply_chat_template_executor
)
@property
def tokenizer(self) -> MistralTokenizer | None:
return self._tokenizer
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
self.config,
content_format="string",
)
prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
tokenizer = self.get_tokenizer()
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
self.config,
content_format="string",
)
prompt_raw = await self._apply_chat_template_async(
tokenizer, messages, **kwargs
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
else TokensPrompt(prompt_token_ids=prompt_raw)
)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt # type: ignore[return-value]
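A small sketch (not part of this commit) of the error mapping in safe_apply_chat_template above: tokenizer-side assertion failures surface as ValueError. The _FailingTokenizer stand-in is an assumption used only to trigger the path without loading a real Mistral model, and mistral_common must be installed because the helper imports it internally.

from vllm.renderers.mistral import safe_apply_chat_template

class _FailingTokenizer:
    # Stand-in for a MistralTokenizer whose validation rejects the input (assumption).
    def apply_chat_template(self, messages, **kwargs):
        raise AssertionError("last message must come from the user")

try:
    safe_apply_chat_template(
        _FailingTokenizer(),
        [{"role": "assistant", "content": "hi"}],
    )
except ValueError as e:
    print("rejected:", e)  # the AssertionError is re-raised as ValueError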

48
vllm/renderers/protocol.py Normal file
View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Protocol
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
)
class RendererLike(Protocol):
@classmethod
def from_config(
cls,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
raise NotImplementedError
@property
def tokenizer(self) -> TokenizerLike | None:
raise NotImplementedError
def get_tokenizer(self) -> TokenizerLike:
tokenizer = self.tokenizer
if tokenizer is None:
raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")
return tokenizer
def render_messages(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
raise NotImplementedError
async def render_messages_async(
self,
messages: list["ChatCompletionMessageParam"],
**kwargs,
) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
return self.render_messages(messages, **kwargs)
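To illustrate the contract (not part of this commit), a minimal renderer that satisfies RendererLike by subclassing the protocol; the EchoRenderer class and its one-token prompt are assumptions for demonstration only.

from typing import Any

from vllm.inputs import TokensPrompt
from vllm.renderers import RendererLike

class EchoRenderer(RendererLike):
    # Toy renderer: no tokenizer, every conversation maps to a single dummy token.
    @classmethod
    def from_config(cls, config, tokenizer_kwargs: dict[str, Any]) -> "RendererLike":
        return cls()

    @property
    def tokenizer(self):
        return None

    def render_messages(self, messages, **kwargs):
        conversation = [
            {"role": m["role"], "content": m.get("content")} for m in messages
        ]
        # get_tokenizer and render_messages_async fall back to the protocol defaults.
        return conversation, TokensPrompt(prompt_token_ids=[0])

conversation, prompt = EchoRenderer().render_messages(
    [{"role": "user", "content": "ping"}]
)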

88
vllm/renderers/registry.py Normal file
View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import RendererLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
_VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"mistral": ("mistral", "MistralRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
}
@dataclass
class RendererRegistry:
# Renderer mode -> (renderer module, renderer class)
renderers: dict[str, tuple[str, str]] = field(default_factory=dict)
def register(self, renderer_mode: str, module: str, class_name: str) -> None:
if renderer_mode in self.renderers:
logger.warning(
"%s.%s is already registered for renderer_mode=%r. "
"It is overwritten by the new one.",
module,
class_name,
renderer_mode,
)
self.renderers[renderer_mode] = (module, class_name)
return None
def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
if renderer_mode not in self.renderers:
raise ValueError(f"No renderer registered for {renderer_mode=!r}.")
module, class_name = self.renderers[renderer_mode]
logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")
return resolve_obj_by_qualname(f"{module}.{class_name}")
def load_renderer(
self,
renderer_mode: str,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> RendererLike:
renderer_cls = self.load_renderer_cls(renderer_mode)
return renderer_cls.from_config(config, tokenizer_kwargs)
RENDERER_REGISTRY = RendererRegistry(
{
mode: (f"vllm.renderers.{mod_relname}", cls_name)
for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
}
)
"""The global `RendererRegistry` instance."""
def renderer_from_config(config: "ModelConfig", **kwargs):
tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
config, **kwargs
)
if config.tokenizer_mode == "auto" and config.model_impl == "terratorch":
renderer_mode = "terratorch"
else:
renderer_mode = tokenizer_mode
return RENDERER_REGISTRY.load_renderer(
renderer_mode,
config,
tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
)
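A sketch (not part of this commit) of resolving and extending the registry above, assuming the module shown here is importable as vllm.renderers.registry; the "mymode" name and my_pkg.renderers module are hypothetical.

from vllm.renderers.registry import RENDERER_REGISTRY

# Resolve a built-in renderer class through the registry.
renderer_cls = RENDERER_REGISTRY.load_renderer_cls("hf")
print(renderer_cls.__name__)  # "HfRenderer"

# Out-of-tree renderers can be added the same way; registration only stores the
# qualified name, the module is imported lazily by load_renderer_cls.
RENDERER_REGISTRY.register("mymode", "my_pkg.renderers", "MyRenderer")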

85
vllm/renderers/terratorch.py Normal file
View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from .protocol import RendererLike
logger = init_logger(__name__)
class TerratorchRenderer(RendererLike):
@classmethod
def from_config(
cls,
config: "ModelConfig",
tokenizer_kwargs: dict[str, Any],
) -> "RendererLike":
return cls(config)
def __init__(self, config: ModelConfig) -> None:
super().__init__()
self.config = config
if not config.skip_tokenizer_init:
raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")
@property
def tokenizer(self) -> TokenizerLike | None:
return None
def get_tokenizer(self) -> TokenizerLike:
raise ValueError("Tokenizer not available for Terratorch renderer")
def render_messages(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = parse_chat_messages(
messages,
model_config,
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1])
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt
async def render_messages_async(
self,
messages: list[ChatCompletionMessageParam],
**kwargs,
) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
model_config = self.config
conversation, mm_data, mm_uuids = await parse_chat_messages_async(
messages,
model_config,
content_format="string",
)
prompt = TokensPrompt(prompt_token_ids=[1]) # Dummy token IDs
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
return conversation, prompt

vllm/v1/engine/async_llm.py
View File

@@ -23,9 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.usage.usage_lib import UsageContext
@@ -106,9 +107,7 @@ class AsyncLLM(EngineClient):
"enabling logging without default stat loggers."
)
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.input_processor = InputProcessor(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
@@ -709,13 +708,12 @@ class AsyncLLM(EngineClient):
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
async def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer
def get_tokenizer(self) -> TokenizerLike:
return self.input_processor.get_tokenizer()
@property
def renderer(self) -> RendererLike:
return self.input_processor.renderer
async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None # type: ignore

vllm/v1/engine/input_processor.py
View File

@@ -19,6 +19,7 @@ from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
@@ -45,7 +46,6 @@ class InputProcessor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
@@ -61,8 +61,7 @@ class InputProcessor:
self.input_preprocessor = InputPreprocessor(
self.model_config,
tokenizer,
self.vllm_config.observability_config,
vllm_config.observability_config,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
@@ -71,6 +70,13 @@ class InputProcessor:
def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer
def get_tokenizer(self) -> TokenizerLike:
return self.input_preprocessor.get_tokenizer()
@property
def renderer(self) -> RendererLike:
return self.input_preprocessor.renderer
def _validate_logprobs(
self,
params: SamplingParams,

vllm/v1/engine/llm_engine.py
View File

@@ -21,9 +21,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
@@ -84,9 +85,7 @@ class LLMEngine:
self.dp_group = None
self.should_execute_dummy_batch = False
tokenizer = cached_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.input_processor = InputProcessor(self.vllm_config)
self.io_processor = get_io_processor(
self.vllm_config,
self.model_config.io_processor_plugin,
@@ -357,12 +356,11 @@ class LLMEngine:
return self.input_processor.tokenizer
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer
return self.input_processor.get_tokenizer()
@property
def renderer(self) -> RendererLike:
return self.input_processor.renderer
def do_log_stats(self) -> None:
"""Log stats if logging is enabled."""