[Frontend] Introduce Renderer for processing chat messages (using ModelConfig) (#30200)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
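At a high level, this change moves chat-message preprocessing behind a Renderer owned by the engine client: instead of fetching the tokenizer and resolving the chat template per request, callers use `EngineClient.renderer` and its `render_messages(...)` / `render_messages_async(...)` methods, which return the parsed conversation together with an engine prompt. A minimal sketch of the new call path, following the usage shown in the `LLM.chat` hunk below (the model name and message contents are illustrative only):

    from vllm import LLM

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # illustrative model name

    messages = [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Say hello."},
    ]

    # The engine exposes a renderer that wraps the tokenizer and the resolved
    # chat template (see EngineClient.renderer in the first hunk).
    renderer = llm.llm_engine.renderer

    # render_messages() returns the parsed conversation and an engine prompt
    # (text or token IDs plus any multi-modal data); LLM.chat() forwards the
    # chat template options as keyword arguments, as in the diff below.
    conversation, prompt = renderer.render_messages(
        messages,
        chat_template_content_format="auto",
        chat_template=None,
        add_generation_prompt=True,
        continue_final_message=False,
        tools=None,
    )
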
@@ -11,9 +11,9 @@ from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.input_processor import InputProcessor

@@ -26,6 +26,10 @@ class EngineClient(ABC):
    input_processor: InputProcessor
    io_processor: IOProcessor | None

    @property
    @abstractmethod
    def renderer(self) -> RendererLike: ...

    @property
    @abstractmethod
    def is_running(self) -> bool: ...

@@ -88,11 +92,6 @@ class EngineClient(ABC):
        """
        ...

    @abstractmethod
    async def get_tokenizer(self) -> TokenizerLike:
        """Get the tokenizer"""
        ...

    @abstractmethod
    async def is_tracing_enabled(self) -> bool: ...

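Server-side entrypoints follow the same pattern: the per-request `await engine_client.get_tokenizer()` calls below are replaced by the renderer attached to the engine client, and the tokenizer, where still needed (decoding, tool parsing, tokenizer info), is taken from the renderer. A minimal sketch of the new access pattern, assuming an `EngineClient` as declared above:

    renderer = engine_client.renderer      # RendererLike
    tokenizer = renderer.get_tokenizer()   # TokenizerLike, used for decoding and tool parsing
    maybe_tokenizer = renderer.tokenizer   # TokenizerLike | None; None when
                                           # skip_tokenizer_init=True
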
@@ -2,22 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict, deque
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
from functools import cached_property, lru_cache, partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
@@ -39,7 +32,6 @@ from openai.types.responses import ResponseInputImageParam
from openai_harmony import Message as OpenAIHarmonyMessage
from PIL import Image
from pydantic import BaseModel, ConfigDict, TypeAdapter
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin

# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict
@@ -50,24 +42,35 @@ from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import random_uuid
from vllm.utils.collection_utils import is_list_of
from vllm.utils.func_utils import supports_kw
from vllm.utils.import_utils import LazyLoader

if TYPE_CHECKING:
    import torch

    from vllm.tokenizers.mistral import MistralTokenizer
else:
    torch = LazyLoader("torch", globals(), "torch")

logger = init_logger(__name__)


def __getattr__(name: str):
    if name == "resolve_hf_chat_template":
        from vllm.renderers.hf import resolve_chat_template

        warnings.warn(
            "`vllm.entrypoints.chat_utils.resolve_hf_chat_template` has been moved to "
            "`vllm.renderers.hf.resolve_chat_template`. "
            "The old name will be removed in v0.16.",
            DeprecationWarning,
            stacklevel=2,
        )

        return resolve_chat_template

    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


class ChatTemplateResolutionError(ValueError):
    """Raised when chat template resolution fails.

@@ -320,325 +323,8 @@ class ConversationMessage(TypedDict, total=False):
|
||||
# Passed in by user
|
||||
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
|
||||
|
||||
# Used internally
|
||||
_ChatTemplateContentFormat = Literal["string", "openai"]
|
||||
|
||||
|
||||
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Name):
|
||||
return node.ctx == "load" and node.name == varname
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Getitem):
|
||||
return (
|
||||
_is_var_access(node.node, varname)
|
||||
and isinstance(node.arg, jinja2.nodes.Const)
|
||||
and node.arg.value == key
|
||||
)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getattr):
|
||||
return _is_var_access(node.node, varname) and node.attr == key
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _is_var_or_elems_access(
|
||||
node: jinja2.nodes.Node,
|
||||
varname: str,
|
||||
key: str | None = None,
|
||||
) -> bool:
|
||||
if isinstance(node, jinja2.nodes.Filter):
|
||||
return node.node is not None and _is_var_or_elems_access(
|
||||
node.node, varname, key
|
||||
)
|
||||
if isinstance(node, jinja2.nodes.Test):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
if isinstance(node, jinja2.nodes.Getitem) and isinstance(
|
||||
node.arg, jinja2.nodes.Slice
|
||||
):
|
||||
return _is_var_or_elems_access(node.node, varname, key)
|
||||
|
||||
return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)
|
||||
|
||||
|
||||
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
|
||||
# Global variable that is implicitly defined at the root
|
||||
yield root, varname
|
||||
|
||||
# Iterative BFS
|
||||
related_varnames = deque([varname])
|
||||
while related_varnames:
|
||||
related_varname = related_varnames.popleft()
|
||||
|
||||
for assign_ast in root.find_all(jinja2.nodes.Assign):
|
||||
lhs = assign_ast.target
|
||||
rhs = assign_ast.node
|
||||
|
||||
if _is_var_or_elems_access(rhs, related_varname):
|
||||
assert isinstance(lhs, jinja2.nodes.Name)
|
||||
yield assign_ast, lhs.name
|
||||
|
||||
# Avoid infinite looping for self-assignment
|
||||
if lhs.name != related_varname:
|
||||
related_varnames.append(lhs.name)
|
||||
|
||||
|
||||
# NOTE: The proper way to handle this is to build a CFG so that we can handle
|
||||
# the scope in which each variable is defined, but that is too complicated
|
||||
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
|
||||
messages_varnames = [
|
||||
varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
|
||||
]
|
||||
|
||||
# Search for {%- for message in messages -%} loops
|
||||
for loop_ast in root.find_all(jinja2.nodes.For):
|
||||
loop_iter = loop_ast.iter
|
||||
loop_target = loop_ast.target
|
||||
|
||||
for varname in messages_varnames:
|
||||
if _is_var_or_elems_access(loop_iter, varname):
|
||||
assert isinstance(loop_target, jinja2.nodes.Name)
|
||||
yield loop_ast, loop_target.name
|
||||
break
|
||||
|
||||
|
||||
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
|
||||
message_varnames = [
|
||||
varname for _, varname in _iter_nodes_assign_messages_item(root)
|
||||
]
|
||||
|
||||
# Search for {%- for content in message['content'] -%} loops
|
||||
for loop_ast in root.find_all(jinja2.nodes.For):
|
||||
loop_iter = loop_ast.iter
|
||||
loop_target = loop_ast.target
|
||||
|
||||
for varname in message_varnames:
|
||||
if _is_var_or_elems_access(loop_iter, varname, "content"):
|
||||
assert isinstance(loop_target, jinja2.nodes.Name)
|
||||
yield loop_ast, loop_target.name
|
||||
break
|
||||
|
||||
|
||||
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None


@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"
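

# For context, a minimal sketch of the two content formats this detector
# distinguishes (the message shapes follow the OpenAI Chat Completions API;
# the values below are illustrative only):
#
#   "string" format - the template indexes message["content"] as plain text:
#       {"role": "user", "content": "What is in this image?"}
#
#   "openai" format - the template iterates over message["content"] parts:
#       {"role": "user", "content": [
#           {"type": "text", "text": "What is in this image?"},
#           {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
#       ]}
#
# _detect_content_format() returns "openai" only when the template's AST
# contains a loop over a message's "content" attribute; otherwise it falls
# back to "string", or to the caller-supplied default when the template
# cannot be compiled or parsed.

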
def resolve_mistral_chat_template(
|
||||
chat_template: str | None,
|
||||
**kwargs: Any,
|
||||
) -> str | None:
|
||||
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
|
||||
raise ValueError(
|
||||
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
|
||||
"for mistral tokenizer."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
|
||||
"""
|
||||
Used in `_try_get_processor_chat_template` to avoid calling
|
||||
`cached_get_processor` again if the processor fails to be loaded.
|
||||
|
||||
This is needed because `lru_cache` does not cache when an exception happens.
|
||||
"""
|
||||
|
||||
|
||||
def _try_get_processor_chat_template(
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
model_config: ModelConfig,
|
||||
) -> str | None:
|
||||
cache_key = (tokenizer.name_or_path, model_config.trust_remote_code)
|
||||
if cache_key in _PROCESSOR_CHAT_TEMPLATES:
|
||||
return _PROCESSOR_CHAT_TEMPLATES[cache_key]
|
||||
|
||||
try:
|
||||
processor = cached_get_processor(
|
||||
tokenizer.name_or_path,
|
||||
processor_cls=(
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
ProcessorMixin,
|
||||
),
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
if (
|
||||
isinstance(processor, ProcessorMixin)
|
||||
and hasattr(processor, "chat_template")
|
||||
and (chat_template := processor.chat_template) is not None
|
||||
):
|
||||
_PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
|
||||
return chat_template
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to load AutoProcessor chat template for %s",
|
||||
tokenizer.name_or_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
_PROCESSOR_CHAT_TEMPLATES[cache_key] = None
|
||||
return None
|
||||
|
||||
|
||||
def resolve_hf_chat_template(
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> str | None:
|
||||
# 1st priority: The given chat template
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
|
||||
# 2nd priority: AutoProcessor chat template, unless tool calling is enabled
|
||||
if tools is None:
|
||||
chat_template = _try_get_processor_chat_template(tokenizer, model_config)
|
||||
if chat_template is not None:
|
||||
return chat_template
|
||||
|
||||
# 3rd priority: AutoTokenizer chat template
|
||||
try:
|
||||
return tokenizer.get_chat_template(chat_template, tools=tools)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Failed to load AutoTokenizer chat template for %s",
|
||||
tokenizer.name_or_path,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# 4th priority: Predefined fallbacks
|
||||
path = get_chat_template_fallback_path(
|
||||
model_type=model_config.hf_config.model_type,
|
||||
tokenizer_name_or_path=model_config.tokenizer,
|
||||
)
|
||||
if path is not None:
|
||||
logger.info_once(
|
||||
"Loading chat template fallback for %s as there isn't one "
|
||||
"defined on HF Hub.",
|
||||
tokenizer.name_or_path,
|
||||
)
|
||||
chat_template = load_chat_template(path)
|
||||
else:
|
||||
logger.debug_once(
|
||||
"There is no chat template fallback for %s", tokenizer.name_or_path
|
||||
)
|
||||
|
||||
return chat_template
|
||||
|
||||
|
||||
def _resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
tokenizer: TokenizerLike | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
|
||||
hf_chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
hf_chat_template = None
|
||||
|
||||
jinja_text = (
|
||||
hf_chat_template
|
||||
if isinstance(hf_chat_template, str)
|
||||
else load_chat_template(chat_template, is_literal=True)
|
||||
)
|
||||
|
||||
detected_format = (
|
||||
"string"
|
||||
if jinja_text is None
|
||||
else _detect_content_format(jinja_text, default="string")
|
||||
)
|
||||
|
||||
return detected_format
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _log_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
detected_format: ChatTemplateContentFormatOption,
|
||||
):
|
||||
logger.info(
|
||||
"Detected the chat template content format to be '%s'. "
|
||||
"You can set `--chat-template-content-format` to override this.",
|
||||
detected_format,
|
||||
)
|
||||
|
||||
if given_format != "auto" and given_format != detected_format:
|
||||
logger.warning(
|
||||
"You specified `--chat-template-content-format %s` "
|
||||
"which is different from the detected format '%s'. "
|
||||
"If our automatic detection is incorrect, please consider "
|
||||
"opening a GitHub issue so that we can improve it: "
|
||||
"https://github.com/vllm-project/vllm/issues/new/choose",
|
||||
given_format,
|
||||
detected_format,
|
||||
)
|
||||
|
||||
|
||||
def resolve_chat_template_content_format(
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
given_format: ChatTemplateContentFormatOption,
|
||||
tokenizer: TokenizerLike | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
) -> _ChatTemplateContentFormat:
|
||||
if given_format != "auto":
|
||||
return given_format
|
||||
|
||||
detected_format = _resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tools,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
_log_chat_template_content_format(
|
||||
chat_template,
|
||||
given_format=given_format,
|
||||
detected_format=detected_format,
|
||||
)
|
||||
|
||||
return detected_format
|
||||
# After resolving "auto"
|
||||
ChatTemplateContentFormat = Literal["string", "openai"]
|
||||
|
||||
|
||||
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
|
||||
@@ -1593,7 +1279,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
|
||||
def _parse_chat_message_content(
|
||||
message: ChatCompletionMessageParam,
|
||||
mm_tracker: BaseMultiModalItemTracker,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
interleave_strings: bool,
|
||||
) -> list[ConversationMessage]:
|
||||
role = message["role"]
|
||||
@@ -1669,7 +1355,7 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
|
||||
def parse_chat_messages(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
MultiModalDataDict | None,
|
||||
@@ -1697,13 +1383,13 @@ def parse_chat_messages(
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def parse_chat_messages_futures(
|
||||
async def parse_chat_messages_async(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model_config: ModelConfig,
|
||||
content_format: _ChatTemplateContentFormat,
|
||||
content_format: ChatTemplateContentFormat,
|
||||
) -> tuple[
|
||||
list[ConversationMessage],
|
||||
Awaitable[MultiModalDataDict | None],
|
||||
MultiModalDataDict | None,
|
||||
MultiModalUUIDDict | None,
|
||||
]:
|
||||
conversation: list[ConversationMessage] = []
|
||||
@@ -1725,174 +1411,7 @@ def parse_chat_messages_futures(
|
||||
|
||||
_postprocess_messages(conversation)
|
||||
|
||||
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
|
||||
# only preserve the parse function used to resolve chat template kwargs
|
||||
class AssistantTracker(jinja2.ext.Extension):
|
||||
tags = {"generation"}
|
||||
|
||||
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
|
||||
lineno = next(parser.stream).lineno
|
||||
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
|
||||
call = self.call_method("_generation_support")
|
||||
call_block = jinja2.nodes.CallBlock(call, [], [], body)
|
||||
return call_block.set_lineno(lineno)
|
||||
|
||||
|
||||
def _resolve_chat_template_kwargs(
|
||||
chat_template: str,
|
||||
):
|
||||
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
|
||||
)
|
||||
parsed_content = env.parse(chat_template)
|
||||
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
|
||||
return template_vars
|
||||
|
||||
|
||||
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _get_hf_base_chat_template_params() -> frozenset[str]:
|
||||
# Get standard parameters from HuggingFace's base tokenizer class.
|
||||
# This dynamically extracts parameters from PreTrainedTokenizer's
|
||||
# apply_chat_template method, ensuring compatibility with tokenizers
|
||||
# that use **kwargs to receive standard parameters.
|
||||
|
||||
# Read signature from HF's base class - the single source of truth
|
||||
base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)
|
||||
# Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
|
||||
return frozenset(
|
||||
p.name
|
||||
for p in base_sig.parameters.values()
|
||||
if p.kind
|
||||
not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
|
||||
)
|
||||
|
||||
|
||||
def resolve_chat_template_kwargs(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template", "tokenize"}
    if raise_on_unexpected and (
        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
    ):
        raise ValueError(
            "Found unexpected chat template kwargs from request: "
            f"{unexpected_in_kwargs}"
        )

    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
    hf_base_params = _get_hf_base_chat_template_params()

    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}
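

# A sketch of how this filter behaves (hypothetical values; the exact set of
# accepted keys depends on the tokenizer's apply_chat_template signature and
# on which variables the template itself reads):
#
#   resolve_chat_template_kwargs(
#       tokenizer,
#       chat_template=tmpl,                  # a template that reads `enable_thinking`
#       chat_template_kwargs={
#           "enable_thinking": True,         # kept: referenced by the template
#           "unknown_option": 1,             # dropped: not accepted anywhere
#       },
#   )
#   # -> {"enable_thinking": True}
#
#   # Passing "chat_template" or "tokenize" inside chat_template_kwargs raises
#   # ValueError when raise_on_unexpected=True (the default).

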
def apply_hf_chat_template(
|
||||
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
|
||||
conversation: list[ConversationMessage],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
hf_chat_template = resolve_hf_chat_template(
|
||||
tokenizer,
|
||||
chat_template=chat_template,
|
||||
tools=tools,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if hf_chat_template is None:
|
||||
raise ChatTemplateResolutionError(
|
||||
"As of transformers v4.44, default chat template is no longer "
|
||||
"allowed, so you must provide a chat template if the tokenizer "
|
||||
"does not define one."
|
||||
)
|
||||
|
||||
resolved_kwargs = resolve_chat_template_kwargs(
|
||||
tokenizer=tokenizer,
|
||||
chat_template=hf_chat_template,
|
||||
chat_template_kwargs=kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
return tokenizer.apply_chat_template(
|
||||
conversation=conversation, # type: ignore[arg-type]
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
chat_template=hf_chat_template,
|
||||
tokenize=False,
|
||||
**resolved_kwargs,
|
||||
)
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
# internal exception management capabilities.
|
||||
except Exception as e:
|
||||
# Log and report any library-related exceptions for further
|
||||
# investigation.
|
||||
logger.exception(
|
||||
"An error occurred in `transformers` while applying chat template"
|
||||
)
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
|
||||
def apply_mistral_chat_template(
|
||||
tokenizer: "MistralTokenizer",
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None,
|
||||
tools: list[dict[str, Any]] | None,
|
||||
**kwargs: Any,
|
||||
) -> list[int]:
|
||||
from mistral_common.exceptions import MistralCommonException
|
||||
|
||||
# The return value of resolve_mistral_chat_template is always None,
|
||||
# and we won't use it.
|
||||
resolve_mistral_chat_template(
|
||||
chat_template=chat_template,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
return tokenizer.apply_chat_template(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
# mistral-common uses assert statements to stop processing of input
|
||||
# if input does not comply with the expected format.
|
||||
# We convert those assertion errors to ValueErrors so they can be
|
||||
# properly caught in the preprocessing_input step
|
||||
except (AssertionError, MistralCommonException) as e:
|
||||
raise ValueError(str(e)) from e
|
||||
|
||||
# External library exceptions can sometimes occur despite the framework's
|
||||
# internal exception management capabilities.
|
||||
except Exception as e:
|
||||
# Log and report any library-related exceptions for further
|
||||
# investigation.
|
||||
logger.exception(
|
||||
"An error occurred in `mistral_common` while applying chat template"
|
||||
)
|
||||
raise ValueError(str(e)) from e
|
||||
return conversation, await mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
|
||||
|
||||
|
||||
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
|
||||
|
||||
@@ -37,10 +37,6 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages,
|
||||
resolve_chat_template_content_format,
|
||||
)
|
||||
from vllm.entrypoints.pooling.score.utils import (
|
||||
ScoreContentPartParam,
|
||||
@@ -786,7 +782,7 @@ class LLM:
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
chat_template_kwargs: dict[str, Any] | None = None,
|
||||
mm_processor_kwargs: dict[str, Any] | None = None,
|
||||
) -> list[TokensPrompt]:
|
||||
) -> list[TextPrompt | TokensPrompt]:
|
||||
"""
|
||||
Generate prompt for a chat conversation. The pre-processed
|
||||
prompt can then be used as input for the other LLM methods.
|
||||
@@ -807,63 +803,27 @@ class LLM:
|
||||
# messages is list[...]
|
||||
list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]
|
||||
|
||||
tokenizer = self.get_tokenizer()
|
||||
model_config = self.model_config
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tools,
|
||||
chat_template_content_format,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
renderer = self.llm_engine.renderer
|
||||
|
||||
_chat_template_kwargs: dict[str, Any] = dict(
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tools,
|
||||
)
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
chat_template_kwargs = {
|
||||
"chat_template": chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
"continue_final_message": continue_final_message,
|
||||
"tools": tools,
|
||||
**(chat_template_kwargs or {}),
|
||||
}
|
||||
|
||||
prompts: list[TokensPrompt] = []
|
||||
prompts = list[TextPrompt | TokensPrompt]()
|
||||
|
||||
for msgs in list_of_messages:
|
||||
# NOTE: _parse_chat_message_content_parts() currently doesn't
|
||||
# NOTE: renderer.render_messages() currently doesn't
|
||||
# handle mm_processor_kwargs, since there is no implementation in
|
||||
# the chat message parsing for it.
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
_, prompt = renderer.render_messages(
|
||||
msgs,
|
||||
model_config,
|
||||
content_format=resolved_content_format,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
prompt_token_ids = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=msgs,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
else:
|
||||
prompt_str = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
conversation=conversation,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
# Special tokens are already included in chat templates so
|
||||
# should not be added by the tokenizer in this case.
|
||||
prompt_token_ids = tokenizer.encode(
|
||||
prompt_str, add_special_tokens=False
|
||||
)
|
||||
|
||||
prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
|
||||
if mm_data is not None:
|
||||
prompt["multi_modal_data"] = mm_data
|
||||
|
||||
if mm_uuids is not None:
|
||||
prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
if mm_processor_kwargs is not None:
|
||||
prompt["mm_processor_kwargs"] = mm_processor_kwargs
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ import vllm.envs as envs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.launcher import serve_http
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
|
||||
@@ -62,7 +63,6 @@ from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
|
||||
from vllm.entrypoints.utils import (
|
||||
cli_env_setup,
|
||||
log_non_default_args,
|
||||
process_chat_template,
|
||||
process_lora_modules,
|
||||
sanitize_message,
|
||||
)
|
||||
@@ -662,9 +662,7 @@ async def init_app_state(
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
logger.info("Supported tasks: %s", supported_tasks)
|
||||
|
||||
resolved_chat_template = await process_chat_template(
|
||||
args.chat_template, engine_client, vllm_config.model_config
|
||||
)
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
if args.tool_server == "demo":
|
||||
tool_server: ToolServer | None = DemoToolServer()
|
||||
|
||||
@@ -186,8 +186,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
# Get the tokenizer from the engine
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self.engine_client.renderer
|
||||
|
||||
# Create a minimal dummy request
|
||||
dummy_request = ChatCompletionRequest(
|
||||
@@ -203,7 +202,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# 3. Tokenizer initialization for chat
|
||||
await self._preprocess_chat(
|
||||
dummy_request,
|
||||
tokenizer,
|
||||
renderer,
|
||||
dummy_request.messages,
|
||||
chat_template=self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
@@ -247,7 +246,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raise self.engine_client.dead_error
|
||||
|
||||
try:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.tokenizer
|
||||
|
||||
tool_parser = self.tool_parser
|
||||
|
||||
@@ -308,7 +308,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
conversation, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
renderer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
@@ -365,8 +365,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
model_name = self.models.model_name(lora_request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
logger.exception("Error preparing request components")
|
||||
return self.create_error_response(e)
|
||||
@@ -463,6 +461,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
(result_generator,) = generators
|
||||
|
||||
# Streaming response
|
||||
tokenizer = self.renderer.tokenizer
|
||||
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
request,
|
||||
@@ -1784,7 +1784,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
"Tokenizer not available when `skip_tokenizer_init=True`"
|
||||
"Unable to get tokenizer because `skip_tokenizer_init=True`"
|
||||
)
|
||||
|
||||
token = tokenizer.decode(token_id)
|
||||
|
||||
@@ -117,12 +117,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
)
|
||||
|
||||
try:
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt_and_embeds(
|
||||
prompt_or_prompts=request.prompt,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
@@ -163,11 +158,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
except (ValueError, TypeError, RuntimeError) as e:
|
||||
logger.exception("Error preparing request components")
|
||||
return self.create_error_response(e)
|
||||
@@ -280,6 +270,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
stream = request.stream and not request.use_beam_search
|
||||
|
||||
# Streaming response
|
||||
tokenizer = self.renderer.tokenizer
|
||||
|
||||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
|
||||
@@ -6,10 +6,9 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
@@ -26,10 +25,6 @@ from vllm.entrypoints.chat_utils import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
apply_mistral_chat_template,
|
||||
parse_chat_messages_futures,
|
||||
resolve_chat_template_content_format,
|
||||
)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
@@ -113,10 +108,9 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.tracing import (
|
||||
contains_trace_headers,
|
||||
@@ -127,10 +121,8 @@ from vllm.utils import random_uuid
|
||||
from vllm.utils.async_utils import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
make_async,
|
||||
merge_async_iterators,
|
||||
)
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
|
||||
@@ -215,7 +207,6 @@ class ResponseGenerationMixin:
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[RequestT]):
|
||||
# Shared across all requests
|
||||
request: RequestT
|
||||
raw_request: Request | None = None
|
||||
model_name: str
|
||||
@@ -223,9 +214,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, Generic[Requ
|
||||
created_time: int = field(default_factory=lambda: int(time.time()))
|
||||
lora_request: LoRARequest | None = None
|
||||
|
||||
# Shared across most requests
|
||||
tokenizer: TokenizerLike | None = None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ClassificationServeContext(ServeContext[ClassificationRequest]):
|
||||
@@ -261,16 +249,13 @@ class OpenAIServing:
|
||||
|
||||
self.request_logger = request_logger
|
||||
self.return_tokens_as_token_ids = return_tokens_as_token_ids
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
self._apply_mistral_chat_template_async = make_async(
|
||||
apply_mistral_chat_template, executor=self._tokenizer_executor
|
||||
)
|
||||
|
||||
self._async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer] = {}
|
||||
self.log_error_stack = log_error_stack
|
||||
|
||||
self.input_processor = self.models.input_processor
|
||||
self.io_processor = self.models.io_processor
|
||||
self.renderer = self.models.renderer
|
||||
self.model_config = self.models.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
|
||||
@@ -557,14 +542,14 @@ class OpenAIServing:
|
||||
prompt_logprobs=None,
|
||||
)
|
||||
|
||||
def _get_renderer(self, tokenizer: TokenizerLike | None) -> BaseRenderer:
|
||||
def _get_completion_renderer(self) -> BaseRenderer:
|
||||
"""
|
||||
Get a Renderer instance with the provided tokenizer.
|
||||
Uses shared async tokenizer pool for efficiency.
|
||||
"""
|
||||
return CompletionRenderer(
|
||||
model_config=self.model_config,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=self.renderer.tokenizer,
|
||||
async_tokenizer_pool=self._async_tokenizer_pool,
|
||||
)
|
||||
|
||||
@@ -1183,7 +1168,7 @@ class OpenAIServing:
|
||||
async def _preprocess_chat(
|
||||
self,
|
||||
request: ChatLikeRequest | ResponsesRequest,
|
||||
tokenizer: TokenizerLike | None,
|
||||
renderer: RendererLike,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
chat_template: str | None,
|
||||
chat_template_content_format: ChatTemplateContentFormatOption,
|
||||
@@ -1196,59 +1181,58 @@ class OpenAIServing:
|
||||
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
|
||||
add_special_tokens: bool = False,
|
||||
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
|
||||
model_config = self.model_config
|
||||
|
||||
resolved_content_format = resolve_chat_template_content_format(
|
||||
chat_template,
|
||||
tool_dicts,
|
||||
chat_template_content_format,
|
||||
tokenizer,
|
||||
model_config=model_config,
|
||||
)
|
||||
conversation, mm_data_future, mm_uuids = parse_chat_messages_futures(
|
||||
messages,
|
||||
model_config,
|
||||
content_format=resolved_content_format,
|
||||
)
|
||||
|
||||
_chat_template_kwargs: dict[str, Any] = dict(
|
||||
chat_template=chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
tools=tool_dicts,
|
||||
documents=documents,
|
||||
)
|
||||
_chat_template_kwargs |= self._prepare_extra_chat_template_kwargs(
|
||||
chat_template_kwargs = {
|
||||
"chat_template": chat_template,
|
||||
"add_generation_prompt": add_generation_prompt,
|
||||
"continue_final_message": continue_final_message,
|
||||
"tools": tool_dicts,
|
||||
"documents": documents,
|
||||
**(chat_template_kwargs or {}),
|
||||
}
|
||||
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
|
||||
chat_template_kwargs,
|
||||
default_chat_template_kwargs,
|
||||
)
|
||||
|
||||
request_prompt: str | list[int]
|
||||
# Use the async tokenizer in `OpenAIServing` if possible.
|
||||
# Later we can move it into the renderer so that we can return both
|
||||
# text and token IDs in the same prompt from `render_messages_async`
|
||||
# which is used for logging and `enable_response_messages`.
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
if tokenizer is None:
|
||||
request_prompt = "placeholder"
|
||||
elif isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = await self._apply_mistral_chat_template_async(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
**_chat_template_kwargs,
|
||||
)
|
||||
elif isinstance(tokenizer, DeepseekV32Tokenizer):
|
||||
request_prompt = tokenizer.apply_chat_template(
|
||||
conversation=conversation,
|
||||
messages=messages,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
conversation, engine_prompt = await renderer.render_messages_async(
|
||||
messages,
|
||||
chat_template_content_format=chat_template_content_format,
|
||||
tokenize=(
|
||||
chat_template_kwargs.pop("tokenize", False)
|
||||
or isinstance(renderer.tokenizer, MistralTokenizer)
|
||||
),
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
|
||||
if "prompt_token_ids" not in engine_prompt:
|
||||
extra_data = engine_prompt
|
||||
engine_prompt = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
renderer.get_tokenizer(),
|
||||
engine_prompt["prompt"],
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
# Fill in other keys like MM data
|
||||
engine_prompt.update(extra_data) # type: ignore
|
||||
else:
|
||||
request_prompt = apply_hf_chat_template(
|
||||
tokenizer=tokenizer,
|
||||
conversation=conversation,
|
||||
model_config=model_config,
|
||||
**_chat_template_kwargs,
|
||||
self._validate_input(
|
||||
request=request,
|
||||
input_ids=engine_prompt["prompt_token_ids"], # type: ignore
|
||||
input_text="",
|
||||
)
|
||||
|
||||
mm_data = await mm_data_future
|
||||
engine_prompt = cast(TokensPrompt, engine_prompt)
|
||||
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
if (cache_salt := getattr(request, "cache_salt", None)) is not None:
|
||||
engine_prompt["cache_salt"] = cache_salt
|
||||
|
||||
# tool parsing is done only if a tool_parser has been set and if
|
||||
# tool_choice is not "none" (if tool_choice is "none" but a tool_parser
|
||||
@@ -1264,49 +1248,10 @@ class OpenAIServing:
|
||||
"or Responses API requests."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore
|
||||
|
||||
if tokenizer is None:
|
||||
assert isinstance(request_prompt, str), (
|
||||
"Prompt has to be a string",
|
||||
"when the tokenizer is not initialised",
|
||||
)
|
||||
prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
|
||||
elif isinstance(request_prompt, str):
|
||||
prompt_inputs = await self._tokenize_prompt_input_async(
|
||||
request,
|
||||
tokenizer,
|
||||
request_prompt,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For MistralTokenizer
|
||||
assert is_list_of(request_prompt, int), (
|
||||
"Prompt has to be either a string or a list of token ids"
|
||||
)
|
||||
input_text = tokenizer.decode(request_prompt)
|
||||
prompt_inputs = self._validate_input(
|
||||
request=request,
|
||||
input_ids=request_prompt,
|
||||
input_text=input_text,
|
||||
)
|
||||
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if "prompt" in prompt_inputs:
|
||||
engine_prompt["prompt"] = prompt_inputs["prompt"]
|
||||
|
||||
if mm_data is not None:
|
||||
engine_prompt["multi_modal_data"] = mm_data
|
||||
|
||||
if mm_uuids is not None:
|
||||
engine_prompt["multi_modal_uuids"] = mm_uuids
|
||||
|
||||
if request.mm_processor_kwargs is not None:
|
||||
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
|
||||
|
||||
if hasattr(request, "cache_salt") and request.cache_salt is not None:
|
||||
engine_prompt["cache_salt"] = request.cache_salt
|
||||
|
||||
return conversation, [engine_prompt]
|
||||
|
||||
async def _process_inputs(
|
||||
@@ -1341,7 +1286,7 @@ class OpenAIServing:
|
||||
async def _render_next_turn(
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
tokenizer: TokenizerLike | None,
|
||||
renderer: RendererLike,
|
||||
messages: list[ResponseInputOutputItem],
|
||||
tool_dicts: list[dict[str, Any]] | None,
|
||||
tool_parser,
|
||||
@@ -1354,7 +1299,7 @@ class OpenAIServing:
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
renderer,
|
||||
new_messages,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=tool_parser,
|
||||
@@ -1431,7 +1376,7 @@ class OpenAIServing:
|
||||
elif isinstance(context, ParsableContext):
|
||||
engine_prompts = await self._render_next_turn(
|
||||
context.request,
|
||||
context.tokenizer,
|
||||
context.renderer,
|
||||
context.parser.response_messages,
|
||||
context.tool_dicts,
|
||||
context.tool_parser_cls,
|
||||
|
||||
@@ -61,6 +61,7 @@ class OpenAIServingModels:
|
||||
|
||||
self.input_processor = self.engine_client.input_processor
|
||||
self.io_processor = self.engine_client.io_processor
|
||||
self.renderer = self.engine_client.renderer
|
||||
self.model_config = self.engine_client.model_config
|
||||
self.max_model_len = self.model_config.max_model_len
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ from vllm.entrypoints.openai.responses.protocol import (
|
||||
from vllm.entrypoints.openai.responses.utils import construct_tool_dicts
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.utils import random_uuid
|
||||
@@ -260,7 +261,7 @@ class ParsableContext(ConversationContext):
|
||||
self,
|
||||
*,
|
||||
response_messages: list[ResponseInputOutputItem],
|
||||
tokenizer: TokenizerLike,
|
||||
renderer: RendererLike,
|
||||
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
|
||||
request: ResponsesRequest,
|
||||
available_tools: list[str] | None,
|
||||
@@ -279,6 +280,7 @@ class ParsableContext(ConversationContext):
|
||||
if reasoning_parser_cls is None:
|
||||
raise ValueError("reasoning_parser_cls must be provided.")
|
||||
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
self.parser = get_responses_parser_for_simple_context(
|
||||
tokenizer=tokenizer,
|
||||
reasoning_parser_cls=reasoning_parser_cls,
|
||||
@@ -288,6 +290,7 @@ class ParsableContext(ConversationContext):
|
||||
)
|
||||
self.tool_parser_cls = tool_parser_cls
|
||||
self.request = request
|
||||
self.renderer = renderer
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.available_tools = available_tools or []
|
||||
|
||||
@@ -121,6 +121,7 @@ from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob as SampleLogprob
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.outputs import CompletionOutput
|
||||
from vllm.renderers import RendererLike
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.utils import random_uuid
|
||||
@@ -380,7 +381,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
model_name = self.models.model_name(lora_request)
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self.engine_client.renderer
|
||||
tokenizer = renderer.get_tokenizer()
|
||||
|
||||
if self.use_harmony:
|
||||
messages, engine_prompts = self._make_request_with_harmony(
|
||||
@@ -388,7 +390,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
)
|
||||
else:
|
||||
messages, engine_prompts = await self._make_request(
|
||||
request, prev_response, tokenizer
|
||||
request, prev_response, renderer
|
||||
)
|
||||
|
||||
except (
|
||||
@@ -454,7 +456,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# tokens during generation instead of at the end
|
||||
context = ParsableContext(
|
||||
response_messages=messages,
|
||||
tokenizer=tokenizer,
|
||||
renderer=renderer,
|
||||
reasoning_parser_cls=self.reasoning_parser,
|
||||
request=request,
|
||||
tool_parser_cls=self.tool_parser,
|
||||
@@ -585,7 +587,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
self,
|
||||
request: ResponsesRequest,
|
||||
prev_response: ResponsesResponse | None,
|
||||
tokenizer: TokenizerLike,
|
||||
renderer: RendererLike,
|
||||
):
|
||||
tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
|
||||
# Construct the input messages.
|
||||
@@ -607,7 +609,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
renderer,
|
||||
messages,
|
||||
tool_dicts=tool_dicts,
|
||||
tool_parser=self.tool_parser,
|
||||
@@ -631,6 +633,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
raise NotImplementedError(
|
||||
"Only 'auto' tool_choice is supported in response API with Harmony"
|
||||
)
|
||||
|
||||
messages = self._construct_input_messages_with_harmony(request, prev_response)
|
||||
prompt_token_ids = render_for_completion(messages)
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
|
||||
@@ -28,21 +28,17 @@ def register_pooling_api_routers(app: FastAPI):
|
||||
async def init_pooling_state(
|
||||
engine_client: "EngineClient", state: "State", args: "Namespace"
|
||||
):
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.pooling.classify.serving import ServingClassification
|
||||
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
|
||||
from vllm.entrypoints.pooling.score.serving import ServingScores
|
||||
from vllm.entrypoints.utils import process_chat_template
|
||||
from vllm.tasks import POOLING_TASKS
|
||||
|
||||
supported_tasks = await engine_client.get_supported_tasks()
|
||||
|
||||
vllm_config = engine_client.vllm_config
|
||||
|
||||
resolved_chat_template = await process_chat_template(
|
||||
args.chat_template, engine_client, vllm_config.model_config
|
||||
)
|
||||
resolved_chat_template = load_chat_template(args.chat_template)
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
@@ -54,8 +54,6 @@ class ClassificationMixin(OpenAIServing):
|
||||
"""
|
||||
ctx = cast(ClassificationServeContext, ctx)
|
||||
try:
|
||||
ctx.tokenizer = await self.engine_client.get_tokenizer()
|
||||
|
||||
request_obj = ctx.request
|
||||
|
||||
if isinstance(request_obj, ClassificationChatRequest):
|
||||
@@ -76,7 +74,7 @@ class ClassificationMixin(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
cast(ChatCompletionRequest, chat_request),
|
||||
ctx.tokenizer,
|
||||
self.renderer,
|
||||
messages,
|
||||
chat_template=(
|
||||
chat_request.chat_template
|
||||
@@ -104,7 +102,7 @@ class ClassificationMixin(OpenAIServing):
|
||||
ctx.engine_prompts = []
|
||||
return None
|
||||
|
||||
renderer = self._get_renderer(ctx.tokenizer)
|
||||
renderer = self._get_completion_renderer()
|
||||
prompt_input = cast(str | list[str], input_data)
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=prompt_input,
|
||||
|
||||
@@ -78,13 +78,10 @@ class EmbeddingMixin(OpenAIServing):
|
||||
try:
|
||||
ctx.lora_request = self._maybe_get_adapters(ctx.request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(ctx.request, EmbeddingChatRequest):
|
||||
_, ctx.engine_prompts = await self._preprocess_chat(
|
||||
ctx.request,
|
||||
tokenizer,
|
||||
self.renderer,
|
||||
ctx.request.messages,
|
||||
chat_template=ctx.request.chat_template or ctx.chat_template,
|
||||
chat_template_content_format=ctx.chat_template_content_format,
|
||||
@@ -93,6 +90,7 @@ class EmbeddingMixin(OpenAIServing):
|
||||
add_special_tokens=ctx.request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
renderer = self._get_completion_renderer()
|
||||
ctx.engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=ctx.request.input,
|
||||
config=self._build_render_config(ctx.request),
|
||||
|
||||
@@ -94,12 +94,6 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported"
|
||||
@@ -140,7 +134,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
self.renderer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
@@ -149,6 +143,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.input,
|
||||
config=self._build_render_config(request),
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
@@ -63,6 +64,8 @@ class ServingScores(OpenAIServing):
|
||||
)
|
||||
self.score_template = score_template
|
||||
|
||||
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
async def _embedding_score(
|
||||
self,
|
||||
tokenizer: TokenizerLike,
|
||||
@@ -283,8 +286,7 @@ class ServingScores(OpenAIServing):
|
||||
raw_request: Request | None = None,
|
||||
) -> list[PoolingRequestOutput] | ErrorResponse:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
|
||||
|
||||
|
||||
@@ -16,12 +16,12 @@ from vllm.entrypoints.chat_utils import (
|
||||
MultiModalItemTracker,
|
||||
_ContentPart,
|
||||
_parse_chat_message_content_part,
|
||||
apply_hf_chat_template,
|
||||
)
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.model_executor.models.interfaces import supports_score_template
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.renderers.hf import safe_apply_chat_template
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
ScoreContentPartParam: TypeAlias = (
|
||||
@@ -224,15 +224,16 @@ def get_score_prompt(
|
||||
# If that fails because there is no such template,
|
||||
# fall back to the default implementation.
|
||||
try:
|
||||
full_prompt = apply_hf_chat_template(
|
||||
full_prompt = safe_apply_chat_template(
|
||||
model_config,
|
||||
tokenizer,
|
||||
[
|
||||
{"role": "query", "content": prompt_1},
|
||||
{"role": "document", "content": prompt_2},
|
||||
],
|
||||
score_template,
|
||||
chat_template=score_template,
|
||||
tools=None,
|
||||
model_config=model_config,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs)
|
||||
except ChatTemplateResolutionError:
|
||||
|
||||
@@ -67,9 +67,6 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
try:
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
renderer = self._get_renderer(tokenizer)
|
||||
|
||||
if isinstance(request, TokenizeChatRequest):
|
||||
tool_dicts = (
|
||||
None
|
||||
@@ -86,7 +83,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
_, engine_prompts = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
self.renderer,
|
||||
request.messages,
|
||||
tool_dicts=tool_dicts,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
@@ -97,6 +94,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
renderer = self._get_completion_renderer()
|
||||
engine_prompts = await renderer.render_prompt(
|
||||
prompt_or_prompts=request.prompt,
|
||||
config=self._build_render_config(request),
|
||||
@@ -116,6 +114,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
|
||||
token_strs = None
|
||||
if request.return_token_strs:
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
token_strs = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
return TokenizeResponse(
|
||||
@@ -137,8 +136,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
request_id = f"tokenize-{self._base_request_id(raw_request)}"
|
||||
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
|
||||
self._log_inputs(
|
||||
request_id,
|
||||
@@ -161,7 +159,7 @@ class OpenAIServingTokenization(OpenAIServing):
|
||||
) -> TokenizerInfoResponse | ErrorResponse:
|
||||
"""Get comprehensive tokenizer information."""
|
||||
try:
|
||||
tokenizer = await self.engine_client.get_tokenizer()
|
||||
tokenizer = self.renderer.get_tokenizer()
|
||||
info = TokenizerInfo(tokenizer, self.chat_template).to_dict()
|
||||
return TokenizerInfoResponse(**info)
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,7 +6,6 @@ import dataclasses
import functools
import os
from argparse import Namespace
from pathlib import Path
from typing import TYPE_CHECKING, Any

import regex as re
@@ -14,17 +13,9 @@ from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.background import BackgroundTask, BackgroundTasks

from vllm.config import ModelConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
    load_chat_template,
    resolve_hf_chat_template,
    resolve_mistral_chat_template,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser

if TYPE_CHECKING:
@@ -301,40 +292,6 @@ def process_lora_modules(
    return lora_modules


async def process_chat_template(
    args_chat_template: Path | str | None,
    engine_client: EngineClient,
    model_config: ModelConfig,
) -> str | None:
    resolved_chat_template = load_chat_template(args_chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template
            )
        else:
            hf_chat_template = resolve_hf_chat_template(
                tokenizer=tokenizer,
                chat_template=None,
                tools=None,
                model_config=model_config,
            )

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from official chat template '%s'. "
                    "This discrepancy may lead to performance degradation.",
                    resolved_chat_template,
                    model_config.model,
                )
    return resolved_chat_template


def sanitize_message(message: str) -> str:
    # Avoid leaking memory address from object reprs
    return re.sub(r" at 0x[0-9a-f]+>", ">", message)

@@ -17,6 +17,7 @@ from vllm.multimodal.inputs import (
    MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.renderers import renderer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
@@ -46,7 +47,6 @@ class InputPreprocessor:
    def __init__(
        self,
        model_config: ModelConfig,
        tokenizer: TokenizerLike | None,
        observability_config: ObservabilityConfig | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: BaseMultiModalProcessorCache | None = None,
@@ -54,20 +54,19 @@ class InputPreprocessor:
        super().__init__()

        self.model_config = model_config
        self.tokenizer = tokenizer
        self.observability_config = observability_config
        self.renderer = renderer_from_config(model_config)
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None

    def get_tokenizer(self) -> TokenizerLike:
        if self.tokenizer is None:
            raise ValueError(
                "You cannot pass text prompts when `skip_tokenizer_init=True`"
            )
    @property
    def tokenizer(self) -> TokenizerLike | None:
        return self.renderer.tokenizer

        return self.tokenizer
    def get_tokenizer(self) -> TokenizerLike:
        return self.renderer.get_tokenizer()

    def get_bos_token_id(self) -> int | None:
        if self.tokenizer is None:

7  vllm/renderers/__init__.py  Normal file
@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .protocol import RendererLike
from .registry import RendererRegistry, renderer_from_config

__all__ = ["RendererLike", "RendererRegistry", "renderer_from_config"]
119  vllm/renderers/deepseek_v32.py  Normal file
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer

from .protocol import RendererLike

logger = init_logger(__name__)


class DeepseekV32Renderer(RendererLike):
    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=DeepseekV32Tokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

    @property
    def tokenizer(self) -> DeepseekV32Tokenizer | None:
        return self._tokenizer

    def get_tokenizer(self) -> DeepseekV32Tokenizer:
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]
119  vllm/renderers/grok2.py  Normal file
@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.grok2 import Grok2Tokenizer

from .protocol import RendererLike

logger = init_logger(__name__)


class Grok2Renderer(RendererLike):
    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=Grok2Tokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

    @property
    def tokenizer(self) -> Grok2Tokenizer | None:
        return self._tokenizer

    def get_tokenizer(self) -> Grok2Tokenizer:
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = tokenizer.apply_chat_template(
            conversation=conversation,
            messages=messages,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]
600  vllm/renderers/hf.py  Normal file
@@ -0,0 +1,600 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections import deque
from collections.abc import Set
from functools import lru_cache
from typing import Any, cast

import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormat,
    ChatTemplateContentFormatOption,
    ChatTemplateResolutionError,
    ConversationMessage,
    load_chat_template,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.hf import CachedHfTokenizer, HfTokenizer
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.func_utils import supports_kw

from .protocol import RendererLike

logger = init_logger(__name__)


_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
"""
Used in `_try_get_processor_chat_template` to avoid calling
`cached_get_processor` again if the processor fails to be loaded.

This is needed because `lru_cache` does not cache when an exception happens.
"""


def _try_get_processor_chat_template(
    tokenizer: HfTokenizer,
    *,
    trust_remote_code: bool,
) -> str | None:
    cache_key = (tokenizer.name_or_path, trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    from transformers import (
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        ProcessorMixin,
    )

    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        if (
            isinstance(processor, ProcessorMixin)
            and hasattr(processor, "chat_template")
            and (chat_template := processor.chat_template) is not None
        ):
            _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
            return chat_template
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    _PROCESSOR_CHAT_TEMPLATES[cache_key] = None
    return None


def resolve_chat_template(
    tokenizer: HfTokenizer,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: "ModelConfig",
) -> str | None:
    # 1st priority: The given chat template
    if chat_template is not None:
        return chat_template

    # 2nd priority: AutoProcessor chat template, unless tool calling is enabled
    if tools is None:
        chat_template = _try_get_processor_chat_template(
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
        if chat_template is not None:
            return chat_template

    # 3rd priority: AutoTokenizer chat template
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: Predefined fallbacks
    path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=tokenizer.name_or_path,
    )
    if path is not None:
        logger.info_once(
            "Loading chat template fallback for %s as there isn't one "
            "defined on HF Hub.",
            tokenizer.name_or_path,
        )
        chat_template = load_chat_template(path)
    else:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )

    return chat_template


def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    if isinstance(node, jinja2.nodes.Name):
        return node.ctx == "load" and node.name == varname

    return False


def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    if isinstance(node, jinja2.nodes.Getitem):
        return (
            _is_var_access(node.node, varname)
            and isinstance(node.arg, jinja2.nodes.Const)
            and node.arg.value == key
        )

    if isinstance(node, jinja2.nodes.Getattr):
        return _is_var_access(node.node, varname) and node.attr == key

    return False


def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    if isinstance(node, jinja2.nodes.Filter):
        return node.node is not None and _is_var_or_elems_access(
            node.node, varname, key
        )
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    if isinstance(node, jinja2.nodes.Getitem) and isinstance(
        node.arg, jinja2.nodes.Slice
    ):
        return _is_var_or_elems_access(node.node, varname, key)

    return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname)


def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    # Global variable that is implicitly defined at the root
    yield root, varname

    # Iterative BFS
    related_varnames = deque([varname])
    while related_varnames:
        related_varname = related_varnames.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if _is_var_or_elems_access(rhs, related_varname):
                assert isinstance(lhs, jinja2.nodes.Name)
                yield assign_ast, lhs.name

                # Avoid infinite looping for self-assignment
                if lhs.name != related_varname:
                    related_varnames.append(lhs.name)


# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    messages_varnames = [
        varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in messages_varnames:
            if _is_var_or_elems_access(loop_iter, varname):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                break


def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    message_varnames = [
        varname for _, varname in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        loop_iter = loop_ast.iter
        loop_target = loop_ast.target

        for varname in message_varnames:
            if _is_var_or_elems_access(loop_iter, varname, "content"):
                assert isinstance(loop_target, jinja2.nodes.Name)
                yield loop_ast, loop_target.name
                break


def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    import transformers.utils.chat_template_utils as hf_chat_utils

    try:
        jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return jinja_compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None


@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: ChatTemplateContentFormat,
) -> ChatTemplateContentFormat:
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    else:
        return "openai"


def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    resolved_chat_template = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    jinja_text = (
        resolved_chat_template
        if isinstance(resolved_chat_template, str)
        else load_chat_template(chat_template, is_literal=True)
    )

    detected_format = (
        "string"
        if jinja_text is None
        else _detect_content_format(jinja_text, default="string")
    )

    return detected_format


@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,  # For caching purposes
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    if given_format != "auto" and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )


def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    if given_format != "auto":
        return given_format

    detected_format = _resolve_chat_template_content_format(
        chat_template,
        tools,
        tokenizer,
        model_config=model_config,
    )

    _log_chat_template_content_format(
        chat_template,
        given_format=given_format,
        detected_format=detected_format,
    )

    return detected_format


# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
        lineno = next(parser.stream).lineno
        body = parser.parse_statements(("name:endgeneration",), drop_needle=True)
        call = self.call_method("_generation_support")
        call_block = jinja2.nodes.CallBlock(call, [], [], body)
        return call_block.set_lineno(lineno)


def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
    env = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    parsed_content = env.parse(chat_template)
    template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
    return template_vars


_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)


@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    from transformers import PreTrainedTokenizer

    # Get standard parameters from HuggingFace's base tokenizer class.
    # This dynamically extracts parameters from PreTrainedTokenizer's
    # apply_chat_template method, ensuring compatibility with tokenizers
    # that use **kwargs to receive standard parameters.

    # Read signature from HF's base class - the single source of truth
    base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template)

    # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders
    return frozenset(
        p.name
        for p in base_sig.parameters.values()
        if p.kind
        not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
    )


def resolve_chat_template_kwargs(
    tokenizer: HfTokenizer,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    # We exclude chat_template from kwargs here, because
    # chat template has been already resolved at this stage
    unexpected_vars = {"chat_template", "tokenize"}
    if raise_on_unexpected and (
        unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys()
    ):
        raise ValueError(
            "Found unexpected chat template kwargs from request: "
            f"{unexpected_in_kwargs}"
        )

    fn_kw = {
        k
        for k in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
    }
    template_vars = _cached_resolve_chat_template_kwargs(chat_template)

    # Allow standard HF parameters even if tokenizer uses **kwargs to receive them
    hf_base_params = _get_hf_base_chat_template_params()

    accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars
    return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars}


def safe_apply_chat_template(
    model_config: "ModelConfig",
    tokenizer: HfTokenizer,
    conversation: list[ConversationMessage],
    *,
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = True,
    **kwargs,
) -> str | list[int]:
    chat_template = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )
    if chat_template is None:
        raise ChatTemplateResolutionError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    resolved_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=chat_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=chat_template,
            tokenize=tokenize,
            **resolved_kwargs,
        )
    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e


class HfRenderer(RendererLike):
    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cast(
                HfTokenizer,
                cached_get_tokenizer(
                    tokenizer_cls=CachedHfTokenizer,  # type: ignore[type-abstract]
                    **tokenizer_kwargs,
                ),
            )

        self._tokenizer = tokenizer

    @property
    def tokenizer(self) -> HfTokenizer | None:
        return self._tokenizer

    def get_tokenizer(self) -> HfTokenizer:
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        model_config = self.config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            model_config,
            content_format=resolve_chat_template_content_format(
                chat_template=kwargs.get("chat_template"),
                tools=kwargs.get("tools"),
                given_format=chat_template_content_format,
                tokenizer=tokenizer,
                model_config=model_config,
            ),
        )

        prompt_raw = safe_apply_chat_template(
            model_config,
            tokenizer,
            conversation,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        model_config = self.config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            model_config,
            content_format=resolve_chat_template_content_format(
                chat_template=kwargs.get("chat_template"),
                tools=kwargs.get("tools"),
                given_format=chat_template_content_format,
                tokenizer=tokenizer,
                model_config=model_config,
            ),
        )

        prompt_raw = safe_apply_chat_template(
            model_config,
            tokenizer,
            conversation,
            **kwargs,
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]
147  vllm/renderers/mistral.py  Normal file
@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async

from .protocol import RendererLike

logger = init_logger(__name__)


def safe_apply_chat_template(
    tokenizer: MistralTokenizer,
    messages: list[ChatCompletionMessageParam],
    **kwargs,
) -> str | list[int]:
    from mistral_common.exceptions import MistralCommonException

    try:
        return tokenizer.apply_chat_template(messages, **kwargs)
    # mistral-common uses assert statements to stop processing of input
    # if input does not comply with the expected format.
    # We convert those assertion errors to ValueErrors so they can be
    # properly caught in the preprocessing_input step
    except (AssertionError, MistralCommonException) as e:
        raise ValueError(str(e)) from e

    # External library exceptions can sometimes occur despite the framework's
    # internal exception management capabilities.
    except Exception as e:
        # Log and report any library-related exceptions for further
        # investigation.
        logger.exception(
            "An error occurred in `mistral_common` while applying chat template"
        )
        raise ValueError(str(e)) from e


class MistralRenderer(RendererLike):
    @classmethod
    def from_config(
        cls,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        return cls(config, tokenizer_kwargs)

    def __init__(
        self,
        config: ModelConfig,
        tokenizer_kwargs: dict[str, Any],
    ) -> None:
        super().__init__()

        self.config = config

        if config.skip_tokenizer_init:
            tokenizer = None
        else:
            tokenizer = cached_get_tokenizer(
                tokenizer_cls=MistralTokenizer,
                **tokenizer_kwargs,
            )

        self._tokenizer = tokenizer

        self._apply_chat_template_executor = ThreadPoolExecutor(max_workers=1)
        self._apply_chat_template_async = make_async(
            safe_apply_chat_template, executor=self._apply_chat_template_executor
        )

    @property
    def tokenizer(self) -> MistralTokenizer | None:
        return self._tokenizer

    def get_tokenizer(self) -> MistralTokenizer:
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = safe_apply_chat_template(tokenizer, messages, **kwargs)

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        tokenizer = self.get_tokenizer()
        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            self.config,
            content_format="string",
        )

        prompt_raw = await self._apply_chat_template_async(
            tokenizer, messages, **kwargs
        )

        prompt = (
            TextPrompt(prompt=prompt_raw)
            if isinstance(prompt_raw, str)
            else TokensPrompt(prompt_token_ids=prompt_raw)
        )
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt  # type: ignore[return-value]
48  vllm/renderers/protocol.py  Normal file
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Protocol

from vllm.inputs import TextPrompt, TokensPrompt
from vllm.tokenizers import TokenizerLike

if TYPE_CHECKING:
    from vllm.config import ModelConfig
    from vllm.entrypoints.chat_utils import (
        ChatCompletionMessageParam,
        ConversationMessage,
    )


class RendererLike(Protocol):
    @classmethod
    def from_config(
        cls,
        config: "ModelConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        raise NotImplementedError

    @property
    def tokenizer(self) -> TokenizerLike | None:
        raise NotImplementedError

    def get_tokenizer(self) -> TokenizerLike:
        tokenizer = self.tokenizer
        if tokenizer is None:
            raise ValueError("Tokenizer not available when `skip_tokenizer_init=True`")

        return tokenizer

    def render_messages(
        self,
        messages: list["ChatCompletionMessageParam"],
        **kwargs,
    ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
        raise NotImplementedError

    async def render_messages_async(
        self,
        messages: list["ChatCompletionMessageParam"],
        **kwargs,
    ) -> tuple[list["ConversationMessage"], TextPrompt | TokensPrompt]:
        return self.render_messages(messages, **kwargs)
88  vllm/renderers/registry.py  Normal file
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from vllm.logger import init_logger
from vllm.tokenizers.registry import tokenizer_args_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname

from .protocol import RendererLike

if TYPE_CHECKING:
    from vllm.config import ModelConfig

logger = init_logger(__name__)


_VLLM_RENDERERS = {
    "deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
    "hf": ("hf", "HfRenderer"),
    "grok2": ("grok2", "Grok2Renderer"),
    "mistral": ("mistral", "MistralRenderer"),
    "terratorch": ("terratorch", "TerratorchRenderer"),
}


@dataclass
class RendererRegistry:
    # Renderer mode -> (renderer module, renderer class)
    renderers: dict[str, tuple[str, str]] = field(default_factory=dict)

    def register(self, renderer_mode: str, module: str, class_name: str) -> None:
        if renderer_mode in self.renderers:
            logger.warning(
                "%s.%s is already registered for renderer_mode=%r. "
                "It is overwritten by the new one.",
                module,
                class_name,
                renderer_mode,
            )

        self.renderers[renderer_mode] = (module, class_name)

        return None

    def load_renderer_cls(self, renderer_mode: str) -> type[RendererLike]:
        if renderer_mode not in self.renderers:
            raise ValueError(f"No renderer registered for {renderer_mode=!r}.")

        module, class_name = self.renderers[renderer_mode]
        logger.debug_once(f"Loading {class_name} for {renderer_mode=!r}")

        return resolve_obj_by_qualname(f"{module}.{class_name}")

    def load_renderer(
        self,
        renderer_mode: str,
        config: "ModelConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> RendererLike:
        renderer_cls = self.load_renderer_cls(renderer_mode)
        return renderer_cls.from_config(config, tokenizer_kwargs)


RENDERER_REGISTRY = RendererRegistry(
    {
        mode: (f"vllm.renderers.{mod_relname}", cls_name)
        for mode, (mod_relname, cls_name) in _VLLM_RENDERERS.items()
    }
)
"""The global `RendererRegistry` instance."""


def renderer_from_config(config: "ModelConfig", **kwargs):
    tokenizer_mode, tokenizer_name, args, kwargs = tokenizer_args_from_config(
        config, **kwargs
    )

    if config.tokenizer_mode == "auto" and config.model_impl == "terratorch":
        renderer_mode = "terratorch"
    else:
        renderer_mode = tokenizer_mode

    return RENDERER_REGISTRY.load_renderer(
        renderer_mode,
        config,
        tokenizer_kwargs={**kwargs, "tokenizer_name": tokenizer_name},
    )
85  vllm/renderers/terratorch.py  Normal file
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ConversationMessage,
    parse_chat_messages,
    parse_chat_messages_async,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike

from .protocol import RendererLike

logger = init_logger(__name__)


class TerratorchRenderer(RendererLike):
    @classmethod
    def from_config(
        cls,
        config: "ModelConfig",
        tokenizer_kwargs: dict[str, Any],
    ) -> "RendererLike":
        return cls(config)

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()

        self.config = config

        if not config.skip_tokenizer_init:
            raise ValueError("Terratorch renderer requires `skip_tokenizer_init=True`")

    @property
    def tokenizer(self) -> TokenizerLike | None:
        return None

    def get_tokenizer(self) -> TokenizerLike:
        raise ValueError("Tokenizer not available for Terratorch renderer")

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        model_config = self.config

        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            model_config,
            content_format="string",
        )

        prompt = TokensPrompt(prompt_token_ids=[1])
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        **kwargs,
    ) -> tuple[list[ConversationMessage], TextPrompt | TokensPrompt]:
        model_config = self.config

        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            model_config,
            content_format="string",
        )

        prompt = TokensPrompt(prompt_token_ids=[1])  # Dummy token IDs
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids

        return conversation, prompt
@@ -23,9 +23,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.usage.usage_lib import UsageContext
@@ -106,9 +107,7 @@ class AsyncLLM(EngineClient):
                "enabling logging without default stat loggers."
            )

        tokenizer = cached_tokenizer_from_config(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
@@ -709,13 +708,12 @@ class AsyncLLM(EngineClient):
    def tokenizer(self) -> TokenizerLike | None:
        return self.input_processor.tokenizer

    async def get_tokenizer(self) -> TokenizerLike:
        if self.tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )
    def get_tokenizer(self) -> TokenizerLike:
        return self.input_processor.get_tokenizer()

        return self.tokenizer
    @property
    def renderer(self) -> RendererLike:
        return self.input_processor.renderer

    async def is_tracing_enabled(self) -> bool:
        return self.observability_config.otlp_traces_endpoint is not None  # type: ignore

@@ -19,6 +19,7 @@ from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing.context import set_request_id
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import _SAMPLING_EPS, SamplingParams
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
@@ -45,7 +46,6 @@ class InputProcessor:
    def __init__(
        self,
        vllm_config: VllmConfig,
        tokenizer: TokenizerLike | None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ) -> None:
        self.vllm_config = vllm_config
@@ -61,8 +61,7 @@ class InputProcessor:

        self.input_preprocessor = InputPreprocessor(
            self.model_config,
            tokenizer,
            self.vllm_config.observability_config,
            vllm_config.observability_config,
            mm_registry,
            mm_processor_cache=self.mm_processor_cache,
        )
@@ -71,6 +70,13 @@ class InputProcessor:
    def tokenizer(self) -> TokenizerLike | None:
        return self.input_preprocessor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        return self.input_preprocessor.get_tokenizer()

    @property
    def renderer(self) -> RendererLike:
        return self.input_preprocessor.renderer

    def _validate_logprobs(
        self,
        params: SamplingParams,

@@ -21,9 +21,10 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.renderers import RendererLike
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
@@ -84,9 +85,7 @@ class LLMEngine:
        self.dp_group = None
        self.should_execute_dummy_batch = False

        tokenizer = cached_tokenizer_from_config(self.model_config)

        self.input_processor = InputProcessor(self.vllm_config, tokenizer)
        self.input_processor = InputProcessor(self.vllm_config)
        self.io_processor = get_io_processor(
            self.vllm_config,
            self.model_config.io_processor_plugin,
@@ -357,12 +356,11 @@ class LLMEngine:
        return self.input_processor.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        if self.tokenizer is None:
            raise ValueError(
                "Unable to get tokenizer because `skip_tokenizer_init=True`"
            )
        return self.input_processor.get_tokenizer()

        return self.tokenizer
    @property
    def renderer(self) -> RendererLike:
        return self.input_processor.renderer

    def do_log_stats(self) -> None:
        """Log stats if logging is enabled."""
