diff --git a/Dockerfile b/Dockerfile index b403eee..9b65dd5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,10 @@ ARG BASE_IMAGE=vllm/vllm-openai:glm51-cu130 FROM ${BASE_IMAGE} +# Patch tool parser for GLM regex fix COPY glm4_moe_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/glm4_moe_tool_parser.py COPY utils.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/utils.py + +# Patch hf renderer to force string content format for GLM models +# This fixes the issue where tool response content is dropped +COPY vllm_patches/hf.py /usr/local/lib/python3.12/dist-packages/vllm/renderers/hf.py diff --git a/README.md b/README.md index 1e3e30f..311fba8 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,11 @@ Patches vLLM's GLM-4/GLM-5.1 tool parser to fix multiple issues with tool call h **Symptom:** When the model makes a tool call and receives a response, it would act as if the response was empty ("The function returned no output") even though valid content was provided. -**Root Cause:** The `func_detail_regex` required a newline between the function name and first argument tag, but GLM-5.1's chat template does NOT include that newline. The regex silently failed to match, tool call extraction failed, and somewhere in that failure path the tool response content got lost. +**Root Cause:** Two bugs working together: + +1. **Tool parser regex mismatch** (`glm4_moe_tool_parser.py`): The `func_detail_regex` required a newline between the function name and first argument tag, but GLM-5.1's chat template doesn't include that newline. The regex silently failed to match. + +2. **Content format detection wrong** (`vllm/renderers/hf.py`): vLLM detected "openai" content format because the GLM template has `{% for tr in m.content %}` for tool responses. But the template then checks `m.content is string` which is False for OpenAI format arrays, causing content to be dropped. 
**Model output format (no newline after name):** ``` @@ -25,10 +29,8 @@ r"\[TOOL_CALL_START\]([^\n]*)\n(.*)\[TOOL_CALL_END\]" # Requires \n after name r"\[TOOL_CALL_START\]\s*([\w.\-]+)\s*((?:\[ARG_KEY\].*)?)\s*\[TOOL_CALL_END\]" ``` -The fix: -- Uses `\s*` instead of mandatory `\n` -- Makes the arguments group optional for zero-argument calls -- Accepts word chars, dots, and hyphens in function names +**Content format fix:** +Added `_is_glm_model()` detection to force "string" content format for GLM models, bypassing the incorrect auto-detection. ### Issue 2: Zero-Argument Tool Calls Crash @@ -44,8 +46,9 @@ Both paths now use the same robust extraction helpers for consistency. | File | Description | |------|-------------| -| `glm4_moe_tool_parser.py` | Fixed tool parser | +| `glm4_moe_tool_parser.py` | Fixed tool parser (regex fix) | | `utils.py` | Utility functions for partial JSON/tag handling | +| `vllm_patches/hf.py` | Patched renderer (content format fix) | | `Dockerfile` | Overlays patched files onto base image | | `Jenkinsfile` | CI/CD pipeline for building and pushing | | `tests/` | Test suite for tool call validation | diff --git a/tests/test_tool_debug.py b/tests/test_tool_debug.py new file mode 100644 index 0000000..1ae018c --- /dev/null +++ b/tests/test_tool_debug.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +""" +Debug test to see what prompt the model actually receives. +""" + +import httpx +import json, os + +API_BASE = "https://api.vultrinference.com/v1" +API_KEY = os.environ.get("VULTR_API_KEY", "")  # never commit real keys; rotate the leaked one +MODEL = "zai-org/GLM-5.1-FP8" + + +def test_with_echo(): + """ + Test with echo=True to see the prompt tokens. 
+ """ + + messages = [ + {"role": "user", "content": "Call the test function"}, + { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "test_func", "arguments": "{}"} + }] + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "VALUE_42" + } + ] + + tools = [{ + "type": "function", + "function": { + "name": "test_func", + "description": "A test function", + "parameters": {"type": "object", "properties": {}} + } + }] + + with httpx.Client(timeout=60.0) as client: + # Try to get prompt logprobs which might show us the prompt + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "tools": tools, + "stream": False, + "max_tokens": 100, + "logprobs": True, + "top_logprobs": 1, + "echo": True # Return prompt tokens + } + ) + + result = response.json() + + print("Full response:") + print(json.dumps(result, indent=2, ensure_ascii=False)) + + +def test_tool_only_message(): + """ + Test if a tool-only message (no tools param) works. + This is what worked in the previous test. 
+ """ + + messages = [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "calc", "arguments": "{}"} + }], + "content": None + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "The answer is 42" + } + ] + + # NO tools param - this worked before + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + # NO tools param + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + if "choices" in result: + content = result["choices"][0]["message"]["content"] + print(f"\nNo tools param - Response: {content}") + print(f"Contains 42: {'42' in content}") + else: + print(f"\nNo tools param - Error: {result}") + + +def test_with_tools_param(): + """ + Test WITH tools param - this is what fails. 
+ """ + + messages = [ + {"role": "user", "content": "What is 2+2?"}, + { + "role": "assistant", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "calc", "arguments": "{}"} + }], + "content": None + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "The answer is 42" + } + ] + + tools = [{ + "type": "function", + "function": { + "name": "calc", + "description": "Calculator", + "parameters": {"type": "object", "properties": {}} + } + }] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "tools": tools, # WITH tools param + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + content = result["choices"][0]["message"]["content"] + print(f"\nWith tools param - Response: {content}") + print(f"Contains 42: {'42' in content}") + + +def test_without_assistant_tool_calls(): + """ + Test if the issue is the assistant message with tool_calls. + What if we just send user -> tool response? 
+ """ + + messages = [ + {"role": "user", "content": "The calculator returned this result"}, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "VALUE_IS_42" + } + ] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + if "choices" in result: + content = result["choices"][0]["message"]["content"] + print(f"\nNo assistant tool_calls - Response: {content}") + print(f"Contains 42: {'42' in content}") + else: + print(f"\nError: {result}") + + +if __name__ == "__main__": + print("=" * 60) + print("Debugging tool response visibility") + print("=" * 60) + + test_tool_only_message() + test_with_tools_param() + test_without_assistant_tool_calls() diff --git a/tests/test_tool_visibility.py b/tests/test_tool_visibility.py new file mode 100644 index 0000000..540455b --- /dev/null +++ b/tests/test_tool_visibility.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +Minimal test - is the tool response content being passed to the model? +""" + +import httpx +import json, os + +API_BASE = "https://api.vultrinference.com/v1" +API_KEY = os.environ.get("VULTR_API_KEY", "")  # never commit real keys; rotate the leaked one +MODEL = "zai-org/GLM-5.1-FP8" + + +def test_direct_prompt(): + """ + If we could send a direct prompt, what would it look like? + + GLM-5.1 expects tool responses in tags: + {"result": "42"} + + Let's test if the model can see content in that format. 
+ """ + + # Simulate what the prompt SHOULD look like after chat template + messages = [ + {"role": "user", "content": "What did the function return?"}, + { + "role": "assistant", + "content": "I'll call the function.", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "get_value", "arguments": "{}"} + }] + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "UNIQUE_MARKER_42" + } + ] + + tools = [{ + "type": "function", + "function": { + "name": "get_value", + "description": "Get a value", + "parameters": {"type": "object", "properties": {}} + } + }] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "tools": tools, + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + + if "choices" in result: + content = result["choices"][0]["message"]["content"] + print(f"Model response: {content}") + print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}") + else: + print(f"Error: {result}") + + +def test_fake_tool_response_in_user_message(): + """ + Test: What if we put the tool response in a user message instead? + This bypasses the role="tool" handling entirely. 
+ """ + + messages = [ + {"role": "user", "content": "What did the function return?"}, + { + "role": "assistant", + "content": "I called the function.", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "get_value", "arguments": "{}"} + }] + }, + # Instead of role="tool", use user message + {"role": "user", "content": "The function returned: UNIQUE_MARKER_42"} + ] + + tools = [{ + "type": "function", + "function": { + "name": "get_value", + "description": "Get a value", + "parameters": {"type": "object", "properties": {}} + } + }] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "tools": tools, + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + + if "choices" in result: + content = result["choices"][0]["message"]["content"] + print(f"\nUser message hack - Model response: {content}") + print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}") + else: + print(f"Error: {result}") + + +def test_tool_response_as_observation_format(): + """ + Test: What if we format the tool response in the GLM expected format? 
+ GLM expects: content + """ + + # Try putting the observations tag in the content + messages = [ + {"role": "user", "content": "What did the function return?"}, + { + "role": "assistant", + "content": "I called the function.", + "tool_calls": [{ + "id": "call_123", + "type": "function", + "function": {"name": "get_value", "arguments": "{}"} + }] + }, + { + "role": "tool", + "tool_call_id": "call_123", + "content": "UNIQUE_MARKER_42" + } + ] + + tools = [{ + "type": "function", + "function": { + "name": "get_value", + "description": "Get a value", + "parameters": {"type": "object", "properties": {}} + } + }] + + with httpx.Client(timeout=60.0) as client: + response = client.post( + f"{API_BASE}/chat/completions", + headers={ + "Authorization": f"Bearer {API_KEY}", + "Content-Type": "application/json" + }, + json={ + "model": MODEL, + "messages": messages, + "tools": tools, + "stream": False, + "max_tokens": 100 + } + ) + + result = response.json() + + if "choices" in result: + content = result["choices"][0]["message"]["content"] + print(f"\nWith tags - Model response: {content}") + print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}") + else: + print(f"Error: {result}") + + +if __name__ == "__main__": + print("Testing tool response visibility") + print("=" * 60) + + test_direct_prompt() + test_fake_tool_response_in_user_message() + test_tool_response_as_observation_format() diff --git a/vllm_patches/hf.py b/vllm_patches/hf.py new file mode 100644 index 0000000..5879563 --- /dev/null +++ b/vllm_patches/hf.py @@ -0,0 +1,771 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import inspect +import itertools +from collections import defaultdict, deque +from collections.abc import Set +from functools import lru_cache +from typing import Any, Literal, cast, overload + +import jinja2 +import jinja2.ext +import jinja2.meta +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox + +from 
vllm.config import ModelConfig, VllmConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateContentFormat, + ChatTemplateContentFormatOption, + ChatTemplateResolutionError, + ConversationMessage, + load_chat_template, + parse_chat_messages, + parse_chat_messages_async, +) +from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict +from vllm.logger import init_logger +from vllm.tokenizers.hf import HfTokenizer +from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils.async_utils import make_async +from vllm.utils.func_utils import supports_kw + +from .base import BaseRenderer +from .inputs import DictPrompt +from .inputs.preprocess import parse_dec_only_prompt +from .params import ChatParams + +logger = init_logger(__name__) + + +_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]() +""" +Used in `_try_get_processor_chat_template` to avoid calling +`cached_get_processor` again if the processor fails to be loaded. + +This is needed because `lru_cache` does not cache when an exception happens. 
+""" + + +def _try_get_processor_chat_template( + tokenizer: HfTokenizer, + *, + trust_remote_code: bool, +) -> str | None: + cache_key = (tokenizer.name_or_path, trust_remote_code) + if cache_key in _PROCESSOR_CHAT_TEMPLATES: + return _PROCESSOR_CHAT_TEMPLATES[cache_key] + + from transformers import ( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ) + + try: + processor = cached_get_processor( + tokenizer.name_or_path, + processor_cls=( + PreTrainedTokenizer, + PreTrainedTokenizerFast, + ProcessorMixin, + ), + trust_remote_code=trust_remote_code, + ) + if ( + isinstance(processor, ProcessorMixin) + and hasattr(processor, "chat_template") + and (chat_template := processor.chat_template) is not None + ): + _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template + return chat_template + except Exception: + logger.debug( + "Failed to load AutoProcessor chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + _PROCESSOR_CHAT_TEMPLATES[cache_key] = None + return None + + +def resolve_chat_template( + tokenizer: HfTokenizer, + chat_template: str | None, + tools: list[dict[str, Any]] | None, + *, + model_config: "ModelConfig", +) -> str | None: + # 1st priority: The given chat template + if chat_template is not None: + # Resolve template names (e.g. "tool_use") to actual Jinja content + # so that downstream kwargs detection can parse template variables. 
+ return tokenizer.get_chat_template(chat_template, tools=tools) + + # 2nd priority: AutoProcessor chat template, unless tool calling is enabled + if tools is None: + chat_template = _try_get_processor_chat_template( + tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) + if chat_template is not None: + return chat_template + + # 3rd priority: AutoTokenizer chat template + try: + return tokenizer.get_chat_template(chat_template, tools=tools) + except Exception: + logger.debug( + "Failed to load AutoTokenizer chat template for %s", + tokenizer.name_or_path, + exc_info=True, + ) + + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=tokenizer.name_or_path, + ) + if path is not None: + logger.info_once( + "Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", + tokenizer.name_or_path, + ) + chat_template = load_chat_template(path) + else: + logger.debug_once( + "There is no chat template fallback for %s", tokenizer.name_or_path + ) + + return chat_template + + +def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool: + if isinstance(node, jinja2.nodes.Name): + return node.ctx == "load" and node.name == varname + + return False + + +def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool: + if isinstance(node, jinja2.nodes.Getitem): + return ( + _is_var_access(node.node, varname) + and isinstance(node.arg, jinja2.nodes.Const) + and node.arg.value == key + ) + + if isinstance(node, jinja2.nodes.Getattr): + return _is_var_access(node.node, varname) and node.attr == key + + return False + + +def _is_var_or_elems_access( + node: jinja2.nodes.Node, + varname: str, + key: str | None = None, +) -> bool: + if isinstance(node, jinja2.nodes.Filter): + return node.node is not None and _is_var_or_elems_access( + node.node, varname, key + ) + if isinstance(node, jinja2.nodes.Test): + return 
_is_var_or_elems_access(node.node, varname, key) + + if isinstance(node, jinja2.nodes.Getitem) and isinstance( + node.arg, jinja2.nodes.Slice + ): + return _is_var_or_elems_access(node.node, varname, key) + + return _is_attr_access(node, varname, key) if key else _is_var_access(node, varname) + + +def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str): + # Global variable that is implicitly defined at the root + yield root, varname + + # Iterative BFS + related_varnames = deque([varname]) + while related_varnames: + related_varname = related_varnames.popleft() + + for assign_ast in root.find_all(jinja2.nodes.Assign): + lhs = assign_ast.target + rhs = assign_ast.node + + if _is_var_or_elems_access(rhs, related_varname): + assert isinstance(lhs, jinja2.nodes.Name) + yield assign_ast, lhs.name + + # Avoid infinite looping for self-assignment + if lhs.name != related_varname: + related_varnames.append(lhs.name) + + +# NOTE: The proper way to handle this is to build a CFG so that we can handle +# the scope in which each variable is defined, but that is too complicated +def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node): + messages_varnames = [ + varname for _, varname in _iter_nodes_assign_var_or_elems(root, "messages") + ] + + # Search for {%- for message in messages -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in messages_varnames: + if _is_var_or_elems_access(loop_iter, varname): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _iter_nodes_assign_content_item(root: jinja2.nodes.Node): + message_varnames = [ + varname for _, varname in _iter_nodes_assign_messages_item(root) + ] + + # Search for {%- for content in message['content'] -%} loops + for loop_ast in root.find_all(jinja2.nodes.For): + loop_iter = loop_ast.iter + loop_target = loop_ast.target + + for varname in message_varnames: + if 
_is_var_or_elems_access(loop_iter, varname, "content"): + assert isinstance(loop_target, jinja2.nodes.Name) + yield loop_ast, loop_target.name + break + + +def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None: + import transformers.utils.chat_template_utils as hf_chat_utils + + try: + jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template) + return jinja_compiled.environment.parse(chat_template) + except Exception: + logger.exception("Error when compiling Jinja template") + return None + + +@lru_cache(maxsize=32) +def _detect_content_format( + chat_template: str, + *, + default: ChatTemplateContentFormat, +) -> ChatTemplateContentFormat: + jinja_ast = _try_extract_ast(chat_template) + if jinja_ast is None: + return default + + try: + next(_iter_nodes_assign_content_item(jinja_ast)) + except StopIteration: + return "string" + except Exception: + logger.exception("Error when parsing AST of Jinja template") + return default + else: + return "openai" + + +def _is_glm_model(tokenizer: HfTokenizer, model_config: "ModelConfig") -> bool: + """Check if this is a GLM model that requires string content format. + + GLM models (GLM-4, GLM-4.5, GLM-5.x) have a chat template that incorrectly + triggers "openai" content format detection because they iterate over + m.content for tool responses. However, the template expects string content + for tool messages (checking `m.content is string`). + + This detection ensures we force "string" format for GLM models. 
+ """ + # Check tokenizer name/path for GLM indicators + name_or_path = tokenizer.name_or_path.lower() + glm_indicators = ["glm-4", "glm-5", "glm4", "glm5", "zai-org/glm"] + if any(ind in name_or_path for ind in glm_indicators): + return True + + # Check model type in config + if hasattr(model_config, "hf_config") and hasattr(model_config.hf_config, "model_type"): + model_type = model_config.hf_config.model_type.lower() + if "glm" in model_type: + return True + + return False + + +def _resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + # GLM models require "string" content format for tool responses to work + # The template has `{% for tr in m.content %}` which triggers "openai" + # detection, but then checks `m.content is string` which fails for arrays. + if _is_glm_model(tokenizer, model_config): + logger.debug( + "Forcing 'string' content format for GLM model: %s", + tokenizer.name_or_path, + ) + return "string" + + resolved_chat_template = resolve_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + + jinja_text = ( + resolved_chat_template + if isinstance(resolved_chat_template, str) + else load_chat_template(chat_template, is_literal=True) + ) + + detected_format = ( + "string" + if jinja_text is None + else _detect_content_format(jinja_text, default="string") + ) + + return detected_format + + +@lru_cache +def _log_chat_template_content_format( + chat_template: str | None, # For caching purposes + given_format: ChatTemplateContentFormatOption, + detected_format: ChatTemplateContentFormatOption, +): + logger.info( + "Detected the chat template content format to be '%s'. 
" + "You can set `--chat-template-content-format` to override this.", + detected_format, + ) + + if given_format != "auto" and given_format != detected_format: + logger.warning( + "You specified `--chat-template-content-format %s` " + "which is different from the detected format '%s'. " + "If our automatic detection is incorrect, please consider " + "opening a GitHub issue so that we can improve it: " + "https://github.com/vllm-project/vllm/issues/new/choose", + given_format, + detected_format, + ) + + +def resolve_chat_template_content_format( + chat_template: str | None, + tools: list[dict[str, Any]] | None, + given_format: ChatTemplateContentFormatOption, + tokenizer: HfTokenizer, + *, + model_config: "ModelConfig", +) -> ChatTemplateContentFormat: + if given_format != "auto": + return given_format + + detected_format = _resolve_chat_template_content_format( + chat_template, + tools, + tokenizer, + model_config=model_config, + ) + + _log_chat_template_content_format( + chat_template, + given_format=given_format, + detected_format=detected_format, + ) + + return detected_format + + +# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412 +# only preserve the parse function used to resolve chat template kwargs +class AssistantTracker(jinja2.ext.Extension): + tags = {"generation"} + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node: + lineno = next(parser.stream).lineno + body = parser.parse_statements(("name:endgeneration",), drop_needle=True) + call = self.call_method("_generation_support") + call_block = jinja2.nodes.CallBlock(call, [], [], body) + return call_block.set_lineno(lineno) + + +def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]: + env = jinja2.sandbox.ImmutableSandboxedEnvironment( + trim_blocks=True, + lstrip_blocks=True, + extensions=[AssistantTracker, jinja2.ext.loopcontrols], + ) + parsed_content = env.parse(chat_template) + template_vars = 
jinja2.meta.find_undeclared_variables(parsed_content) + return template_vars + + +_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs) + + +@lru_cache +def _get_hf_base_chat_template_params() -> frozenset[str]: + from transformers import PreTrainedTokenizer + + # Get standard parameters from HuggingFace's base tokenizer class. + # This dynamically extracts parameters from PreTrainedTokenizer's + # apply_chat_template method, ensuring compatibility with tokenizers + # that use **kwargs to receive standard parameters. + + # Read signature from HF's base class - the single source of truth + base_sig = inspect.signature(PreTrainedTokenizer.apply_chat_template) + + # Exclude VAR_KEYWORD (**kwargs) and VAR_POSITIONAL (*args) placeholders + return frozenset( + p.name + for p in base_sig.parameters.values() + if p.kind + not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL) + ) + + +def resolve_chat_template_kwargs( + tokenizer: HfTokenizer, + chat_template: str, + chat_template_kwargs: dict[str, Any], + raise_on_unexpected: bool = True, +) -> dict[str, Any]: + # We exclude chat_template from kwargs here, because + # chat template has been already resolved at this stage + unexpected_vars = {"chat_template", "tokenize"} + if raise_on_unexpected and ( + unexpected_in_kwargs := unexpected_vars & chat_template_kwargs.keys() + ): + raise ValueError( + "Found unexpected chat template kwargs from request: " + f"{unexpected_in_kwargs}" + ) + + fn_kw = { + k + for k in chat_template_kwargs + if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False) + } + template_vars = _cached_resolve_chat_template_kwargs(chat_template) + + # Allow standard HF parameters even if tokenizer uses **kwargs to receive them + hf_base_params = _get_hf_base_chat_template_params() + + accept_vars = (fn_kw | template_vars | hf_base_params) - unexpected_vars + return {k: v for k, v in chat_template_kwargs.items() if k in accept_vars} + + +@overload 
+def safe_apply_chat_template( + model_config: "ModelConfig", + tokenizer: HfTokenizer, + conversation: list[ConversationMessage], + *, + tools: list[dict[str, Any]] | None = ..., + chat_template: str | None = ..., + tokenize: Literal[True] = ..., + **kwargs, +) -> list[int]: ... +@overload +def safe_apply_chat_template( + model_config: "ModelConfig", + tokenizer: HfTokenizer, + conversation: list[ConversationMessage], + *, + tools: list[dict[str, Any]] | None = ..., + chat_template: str | None = ..., + tokenize: Literal[False] = ..., + **kwargs, +) -> str: ... +def safe_apply_chat_template( + model_config: "ModelConfig", + tokenizer: HfTokenizer, + conversation: list[ConversationMessage], + *, + tools: list[dict[str, Any]] | None = None, + chat_template: str | None = None, + tokenize: bool = True, + **kwargs, +) -> str | list[int]: + chat_template = resolve_chat_template( + tokenizer, + chat_template=chat_template, + tools=tools, + model_config=model_config, + ) + if chat_template is None: + raise ChatTemplateResolutionError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one." + ) + + resolved_kwargs = resolve_chat_template_kwargs( + tokenizer=tokenizer, + chat_template=chat_template, + chat_template_kwargs=kwargs, + ) + + try: + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + tools=tools, # type: ignore[arg-type] + chat_template=chat_template, + tokenize=tokenize, + **resolved_kwargs, + ) + # External library exceptions can sometimes occur despite the framework's + # internal exception management capabilities. + except Exception as e: + # Log and report any library-related exceptions for further + # investigation. 
+ logger.exception( + "An error occurred in `transformers` while applying chat template" + ) + raise ValueError(str(e)) from e + + +def rebuild_mm_uuids_from_mm_data( + mm_uuids: MultiModalUUIDDict, + mm_data: MultiModalDataDict, +) -> MultiModalUUIDDict: + """Rebuild mm_uuids after vision_chunk processing. + + When videos are split into chunks, the original UUIDs need to be updated + to reflect the new UUIDs generated for each chunk. + + Args: + mm_uuids: Original UUIDs dictionary + mm_data: Processed multimodal data with vision_chunk items + + Returns: + Updated UUIDs dictionary with chunk UUIDs + """ + vision_chunks = mm_data.get("vision_chunk") + if vision_chunks is None: + return mm_uuids + + assert all(isinstance(item, dict) for item in vision_chunks), ( + "Expected all vision_chunk items to be dicts" + ) + vision_chunks = cast(list[dict[str, Any]], vision_chunks) + vision_chunk_uuids = [ + uuid_val for item in vision_chunks if (uuid_val := item.get("uuid")) is not None + ] + + if vision_chunk_uuids: + mm_uuids = dict(mm_uuids) + mm_uuids["vision_chunk"] = vision_chunk_uuids + + return mm_uuids + + +def build_video_prompts_from_mm_data( + mm_data: MultiModalDataDict, +) -> list[str]: + """Build video prompts from vision_chunk data. + + Collects prompts from video chunks and groups them by video_idx. + + Args: + mm_data: Processed multimodal data with vision_chunk items + + Returns: + List of video prompts, one per video. 
+ """ + vision_chunks = mm_data.get("vision_chunk") + if vision_chunks is None: + return [] + + # Group chunks by video_idx + video_prompts_dict: dict[int, list[str]] = defaultdict(list) + + for item in vision_chunks: + # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo) + assert isinstance(item, dict) + if item.get("type") == "video_chunk": + video_idx = item.get("video_idx", 0) + prompt = item.get("prompt", "") + video_prompts_dict[video_idx].append(prompt) + + # Build prompts in video order + video_prompts = [ + "".join(video_prompts_dict[video_idx]) + for video_idx in sorted(video_prompts_dict.keys()) + ] + + return video_prompts + + +def replace_vision_chunk_video_placeholder( + prompt_raw: str | list[int], + mm_data: MultiModalDataDict, + video_placeholder: str | None, +) -> str | list[int]: + # get video placeholder, replace it with runtime video-chunk prompts + if video_placeholder and isinstance(prompt_raw, str): + video_prompts = build_video_prompts_from_mm_data(mm_data) + + # replace in order + prompt_raw_parts = prompt_raw.split(video_placeholder) + if len(prompt_raw_parts) == len(video_prompts) + 1: + prompt_raw = "".join( + itertools.chain.from_iterable(zip(prompt_raw_parts, video_prompts)) + ) + prompt_raw += prompt_raw_parts[-1] + else: + logger.warning( + "Number of video placeholders (%d) does not match " + "number of videos (%d) in the request.", + len(prompt_raw_parts) - 1, + len(video_prompts), + ) + return prompt_raw + + +class HfRenderer(BaseRenderer[HfTokenizer]): + def __init__( + self, + config: VllmConfig, + tokenizer: HfTokenizer | None, + ) -> None: + super().__init__(config, tokenizer) + + self.use_unified_vision_chunk = getattr( + config.model_config.hf_config, "use_unified_vision_chunk", False + ) + + self._apply_chat_template_async = make_async( + safe_apply_chat_template, executor=self._executor + ) + + def render_messages( + self, + messages: list[ChatCompletionMessageParam], + params: ChatParams, + ) -> 
tuple[list[ConversationMessage], DictPrompt]: + model_config = self.model_config + tokenizer = self.get_tokenizer() + + conversation, mm_data, mm_uuids = parse_chat_messages( + messages, + model_config, + content_format=resolve_chat_template_content_format( + chat_template=params.chat_template, + tools=params.chat_template_kwargs.get("tools"), + given_format=params.chat_template_content_format, + tokenizer=tokenizer, + model_config=model_config, + ), + media_io_kwargs=params.media_io_kwargs, + mm_processor_kwargs=params.mm_processor_kwargs, + ) + + prompt_raw = safe_apply_chat_template( + model_config, + tokenizer, + conversation, + **params.get_apply_chat_template_kwargs(), + ) + + # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 + # model which uses unified vision chunks for both images and videos. + if ( + self.use_unified_vision_chunk + and mm_uuids is not None + and mm_data is not None + ): + mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data) + + # get video placeholder, replace it with runtime video-chunk prompts + video_placeholder = getattr( + model_config.hf_config, "video_placeholder", None + ) + prompt_raw = cast( + list[int], + replace_vision_chunk_video_placeholder( + prompt_raw, + mm_data, + video_placeholder, + ), + ) + + prompt = parse_dec_only_prompt(prompt_raw) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt + + async def render_messages_async( + self, + messages: list[ChatCompletionMessageParam], + params: ChatParams, + ) -> tuple[list[ConversationMessage], DictPrompt]: + model_config = self.model_config + tokenizer = self.get_tokenizer() + + conversation, mm_data, mm_uuids = await parse_chat_messages_async( + messages, + model_config, + content_format=resolve_chat_template_content_format( + chat_template=params.chat_template, + tools=params.chat_template_kwargs.get("tools"), + 
given_format=params.chat_template_content_format, + tokenizer=tokenizer, + model_config=model_config, + ), + media_io_kwargs=params.media_io_kwargs, + mm_processor_kwargs=params.mm_processor_kwargs, + ) + + prompt_raw = await self._apply_chat_template_async( + model_config, + tokenizer, + conversation, + **params.get_apply_chat_template_kwargs(), + ) + + # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5 + # model which uses unified vision chunks for both images and videos. + if ( + self.use_unified_vision_chunk + and mm_uuids is not None + and mm_data is not None + ): + # get video placeholder, replace it with runtime video-chunk prompts + video_placeholder = getattr( + model_config.hf_config, "video_placeholder", None + ) + prompt_raw = cast( + list[int], + replace_vision_chunk_video_placeholder( + prompt_raw, + mm_data, + video_placeholder, + ), + ) + + prompt = parse_dec_only_prompt(prompt_raw) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + if mm_uuids is not None: + prompt["multi_modal_uuids"] = mm_uuids + + return conversation, prompt