Add hf.py patch to force string content format for GLM models
- Tool response content was being dropped because vLLM incorrectly detected the 'openai' content format for GLM templates
- Added _is_glm_model() detection to force the 'string' content format
- Updated Dockerfile to include the hf.py patch
- Added debug tests for tool visibility
This commit is contained in:
@@ -1,5 +1,10 @@
|
|||||||
ARG BASE_IMAGE=vllm/vllm-openai:glm51-cu130
|
ARG BASE_IMAGE=vllm/vllm-openai:glm51-cu130
|
||||||
FROM ${BASE_IMAGE}
|
FROM ${BASE_IMAGE}
|
||||||
|
|
||||||
|
# Patch tool parser for GLM regex fix
|
||||||
COPY glm4_moe_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/glm4_moe_tool_parser.py
|
COPY glm4_moe_tool_parser.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/glm4_moe_tool_parser.py
|
||||||
COPY utils.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/utils.py
|
COPY utils.py /usr/local/lib/python3.12/dist-packages/vllm/tool_parsers/utils.py
|
||||||
|
|
||||||
|
# Patch hf renderer to force string content format for GLM models
|
||||||
|
# This fixes the issue where tool response content is dropped
|
||||||
|
COPY vllm_patches/hf.py /usr/local/lib/python3.12/dist-packages/vllm/renderers/hf.py
|
||||||
|
|||||||
15
README.md
15
README.md
@@ -8,7 +8,11 @@ Patches vLLM's GLM-4/GLM-5.1 tool parser to fix multiple issues with tool call h
|
|||||||
|
|
||||||
**Symptom:** When the model makes a tool call and receives a response, it would act as if the response was empty ("The function returned no output") even though valid content was provided.
|
**Symptom:** When the model makes a tool call and receives a response, it would act as if the response was empty ("The function returned no output") even though valid content was provided.
|
||||||
|
|
||||||
**Root Cause:** The `func_detail_regex` required a newline between the function name and first argument tag, but GLM-5.1's chat template does NOT include that newline. The regex silently failed to match, tool call extraction failed, and somewhere in that failure path the tool response content got lost.
|
**Root Cause:** Two bugs working together:
|
||||||
|
|
||||||
|
1. **Tool parser regex mismatch** (`glm4_moe_tool_parser.py`): The `func_detail_regex` required a newline between the function name and first argument tag, but GLM-5.1's chat template doesn't include that newline. The regex silently failed to match.
|
||||||
|
|
||||||
|
2. **Content format detection wrong** (`vllm/renderers/hf.py`): vLLM detected "openai" content format because the GLM template has `{% for tr in m.content %}` for tool responses. But the template then checks `m.content is string` which is False for OpenAI format arrays, causing content to be dropped.
|
||||||
|
|
||||||
**Model output format (no newline after name):**
|
**Model output format (no newline after name):**
|
||||||
```
|
```
|
||||||
@@ -25,10 +29,8 @@ r"\[TOOL_CALL_START\]([^\n]*)\n(.*)\[TOOL_CALL_END\]" # Requires \n after name
|
|||||||
r"\[TOOL_CALL_START\]\s*([\w.\-]+)\s*((?:\[ARG_KEY\].*)?)\s*\[TOOL_CALL_END\]"
|
r"\[TOOL_CALL_START\]\s*([\w.\-]+)\s*((?:\[ARG_KEY\].*)?)\s*\[TOOL_CALL_END\]"
|
||||||
```
|
```
|
||||||
|
|
||||||
The fix:
|
**Content format fix:**
|
||||||
- Uses `\s*` instead of mandatory `\n`
|
Added `_is_glm_model()` detection to force "string" content format for GLM models, bypassing the incorrect auto-detection.
|
||||||
- Makes the arguments group optional for zero-argument calls
|
|
||||||
- Accepts word chars, dots, and hyphens in function names
|
|
||||||
|
|
||||||
### Issue 2: Zero-Argument Tool Calls Crash
|
### Issue 2: Zero-Argument Tool Calls Crash
|
||||||
|
|
||||||
@@ -44,8 +46,9 @@ Both paths now use the same robust extraction helpers for consistency.
|
|||||||
|
|
||||||
| File | Description |
|
| File | Description |
|
||||||
|------|-------------|
|
|------|-------------|
|
||||||
| `glm4_moe_tool_parser.py` | Fixed tool parser |
|
| `glm4_moe_tool_parser.py` | Fixed tool parser (regex fix) |
|
||||||
| `utils.py` | Utility functions for partial JSON/tag handling |
|
| `utils.py` | Utility functions for partial JSON/tag handling |
|
||||||
|
| `vllm_patches/hf.py` | Patched renderer (content format fix) |
|
||||||
| `Dockerfile` | Overlays patched files onto base image |
|
| `Dockerfile` | Overlays patched files onto base image |
|
||||||
| `Jenkinsfile` | CI/CD pipeline for building and pushing |
|
| `Jenkinsfile` | CI/CD pipeline for building and pushing |
|
||||||
| `tests/` | Test suite for tool call validation |
|
| `tests/` | Test suite for tool call validation |
|
||||||
|
|||||||
221
tests/test_tool_debug.py
Normal file
221
tests/test_tool_debug.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Debug test to see what prompt the model actually receives.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
import os

API_BASE = "https://api.vultrinference.com/v1"
# SECURITY FIX: never commit API keys to source control. Read the key from the
# environment instead; an empty default makes unauthenticated failures obvious.
API_KEY = os.environ.get("VULTR_API_KEY", "")
MODEL = "zai-org/GLM-5.1-FP8"
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_echo():
    """Request a completion with echo/logprobs enabled and dump the response.

    Sends a full tool-call round trip (user -> assistant tool_call -> tool
    result) and asks for prompt logprobs so we can inspect what prompt the
    model actually received.
    """
    tool_call = {
        "id": "call_123",
        "type": "function",
        "function": {"name": "test_func", "arguments": "{}"},
    }
    messages = [
        {"role": "user", "content": "Call the test function"},
        {"role": "assistant", "tool_calls": [tool_call]},
        {"role": "tool", "tool_call_id": "call_123", "content": "VALUE_42"},
    ]
    tools = [
        {
            "type": "function",
            "function": {
                "name": "test_func",
                "description": "A test function",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    # Try to get prompt logprobs which might show us the prompt.
    payload = {
        "model": MODEL,
        "messages": messages,
        "tools": tools,
        "stream": False,
        "max_tokens": 100,
        "logprobs": True,
        "top_logprobs": 1,
        "echo": True,  # Return prompt tokens
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()

    print("Full response:")
    print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_only_message():
    """Round-trip a tool response WITHOUT sending the `tools` parameter.

    This is the configuration that worked in the previous test; the model
    should be able to see the tool output ("The answer is 42").
    """
    messages = [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "calc", "arguments": "{}"},
                }
            ],
            "content": None,
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "The answer is 42",
        },
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    # Deliberately NO "tools" key in the payload - this worked before.
    payload = {
        "model": MODEL,
        "messages": messages,
        "stream": False,
        "max_tokens": 100,
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()
    if "choices" not in result:
        print(f"\nNo tools param - Error: {result}")
        return
    content = result["choices"][0]["message"]["content"]
    print(f"\nNo tools param - Response: {content}")
    print(f"Contains 42: {'42' in content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_tools_param():
    """Round-trip a tool response WITH the `tools` parameter.

    This is the configuration that fails: the model acts as if the tool
    response were empty even though "The answer is 42" was provided.
    """
    messages = [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "calc", "arguments": "{}"},
                }
            ],
            "content": None,
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "The answer is 42",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "calc",
                "description": "Calculator",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,  # WITH tools param
                "stream": False,
                "max_tokens": 100,
            },
        )

    result = response.json()
    # FIX: guard against API error payloads (no "choices" key) instead of
    # crashing with KeyError; this matches the sibling test functions.
    if "choices" in result:
        content = result["choices"][0]["message"]["content"]
        print(f"\nWith tools param - Response: {content}")
        print(f"Contains 42: {'42' in content}")
    else:
        print(f"\nWith tools param - Error: {result}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_without_assistant_tool_calls():
    """Check whether the assistant tool_calls message is the culprit.

    Sends only user -> tool-response (no assistant message with tool_calls)
    to see if the tool content survives on its own.
    """
    messages = [
        {"role": "user", "content": "The calculator returned this result"},
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "VALUE_IS_42",
        },
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL,
        "messages": messages,
        "stream": False,
        "max_tokens": 100,
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()
    if "choices" not in result:
        print(f"\nError: {result}")
        return
    content = result["choices"][0]["message"]["content"]
    print(f"\nNo assistant tool_calls - Response: {content}")
    print(f"Contains 42: {'42' in content}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Debugging tool response visibility")
    print(banner)

    # NOTE: test_with_echo() is intentionally not run by default.
    test_tool_only_message()
    test_with_tools_param()
    test_without_assistant_tool_calls()
|
||||||
200
tests/test_tool_visibility.py
Normal file
200
tests/test_tool_visibility.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Minimal test - is the tool response content being passed to the model?
|
||||||
|
"""
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
|
||||||
|
import os

API_BASE = "https://api.vultrinference.com/v1"
# SECURITY FIX: never commit API keys to source control. Read the key from the
# environment instead; an empty default makes unauthenticated failures obvious.
API_KEY = os.environ.get("VULTR_API_KEY", "")
MODEL = "zai-org/GLM-5.1-FP8"
|
||||||
|
|
||||||
|
|
||||||
|
def test_direct_prompt():
    """Full tool-call round trip with the `tools` parameter supplied.

    GLM-5.1 expects tool responses in <observations> tags, e.g.
    <observations>{"result": "42"}</observations>. This checks whether the
    model can see the tool content (UNIQUE_MARKER_42) at all.
    """
    messages = [
        {"role": "user", "content": "What did the function return?"},
        {
            "role": "assistant",
            "content": "I'll call the function.",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_value", "arguments": "{}"},
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "UNIQUE_MARKER_42",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_value",
                "description": "Get a value",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL,
        "messages": messages,
        "tools": tools,
        "stream": False,
        "max_tokens": 100,
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()

    if "choices" not in result:
        print(f"Error: {result}")
        return
    content = result["choices"][0]["message"]["content"]
    print(f"Model response: {content}")
    print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_fake_tool_response_in_user_message():
    """Deliver the tool result as a user message instead of role="tool".

    This bypasses the role="tool" handling entirely, isolating whether the
    tool-message path is what drops the content.
    """
    messages = [
        {"role": "user", "content": "What did the function return?"},
        {
            "role": "assistant",
            "content": "I called the function.",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_value", "arguments": "{}"},
                }
            ],
        },
        # Instead of role="tool", use a plain user message.
        {"role": "user", "content": "The function returned: UNIQUE_MARKER_42"},
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_value",
                "description": "Get a value",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL,
        "messages": messages,
        "tools": tools,
        "stream": False,
        "max_tokens": 100,
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()

    if "choices" not in result:
        print(f"Error: {result}")
        return
    content = result["choices"][0]["message"]["content"]
    print(f"\nUser message hack - Model response: {content}")
    print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_response_as_observation_format():
    """Wrap the tool result in the GLM-native <observations> tags.

    GLM expects: <observations>content</observations> — check whether
    pre-formatting the content that way makes it visible to the model.
    """
    # Try putting the observations tag directly in the tool content.
    messages = [
        {"role": "user", "content": "What did the function return?"},
        {
            "role": "assistant",
            "content": "I called the function.",
            "tool_calls": [
                {
                    "id": "call_123",
                    "type": "function",
                    "function": {"name": "get_value", "arguments": "{}"},
                }
            ],
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "<observations>UNIQUE_MARKER_42</observations>",
        },
    ]

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_value",
                "description": "Get a value",
                "parameters": {"type": "object", "properties": {}},
            },
        }
    ]

    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": MODEL,
        "messages": messages,
        "tools": tools,
        "stream": False,
        "max_tokens": 100,
    }

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions", headers=headers, json=payload
        )

    result = response.json()

    if "choices" not in result:
        print(f"Error: {result}")
        return
    content = result["choices"][0]["message"]["content"]
    print(f"\nWith <observations> tags - Model response: {content}")
    print(f"Contains UNIQUE_MARKER_42: {'UNIQUE_MARKER_42' in content}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    print("Testing tool response visibility")
    print("=" * 60)

    # Run the three visibility probes in order of increasing workaround.
    test_direct_prompt()
    test_fake_tool_response_in_user_message()
    test_tool_response_as_observation_format()
|
||||||
771
vllm_patches/hf.py
Normal file
771
vllm_patches/hf.py
Normal file
@@ -0,0 +1,771 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import inspect
|
||||||
|
import itertools
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from collections.abc import Set
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Literal, cast, overload
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
import jinja2.ext
|
||||||
|
import jinja2.meta
|
||||||
|
import jinja2.nodes
|
||||||
|
import jinja2.parser
|
||||||
|
import jinja2.sandbox
|
||||||
|
|
||||||
|
from vllm.config import ModelConfig, VllmConfig
|
||||||
|
from vllm.entrypoints.chat_utils import (
|
||||||
|
ChatCompletionMessageParam,
|
||||||
|
ChatTemplateContentFormat,
|
||||||
|
ChatTemplateContentFormatOption,
|
||||||
|
ChatTemplateResolutionError,
|
||||||
|
ConversationMessage,
|
||||||
|
load_chat_template,
|
||||||
|
parse_chat_messages,
|
||||||
|
parse_chat_messages_async,
|
||||||
|
)
|
||||||
|
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.tokenizers.hf import HfTokenizer
|
||||||
|
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
|
||||||
|
from vllm.transformers_utils.processor import cached_get_processor
|
||||||
|
from vllm.utils.async_utils import make_async
|
||||||
|
from vllm.utils.func_utils import supports_kw
|
||||||
|
|
||||||
|
from .base import BaseRenderer
|
||||||
|
from .inputs import DictPrompt
|
||||||
|
from .inputs.preprocess import parse_dec_only_prompt
|
||||||
|
from .params import ChatParams
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_PROCESSOR_CHAT_TEMPLATES = dict[tuple[str, bool], str | None]()
|
||||||
|
"""
|
||||||
|
Used in `_try_get_processor_chat_template` to avoid calling
|
||||||
|
`cached_get_processor` again if the processor fails to be loaded.
|
||||||
|
|
||||||
|
This is needed because `lru_cache` does not cache when an exception happens.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _try_get_processor_chat_template(
    tokenizer: HfTokenizer,
    *,
    trust_remote_code: bool,
) -> str | None:
    """Return the AutoProcessor chat template for *tokenizer*, if any.

    Results — including failures, recorded as ``None`` — are memoized in
    ``_PROCESSOR_CHAT_TEMPLATES`` so a failing processor load is attempted
    only once per (name_or_path, trust_remote_code) pair.
    """
    cache_key = (tokenizer.name_or_path, trust_remote_code)
    if cache_key in _PROCESSOR_CHAT_TEMPLATES:
        return _PROCESSOR_CHAT_TEMPLATES[cache_key]

    from transformers import (
        PreTrainedTokenizer,
        PreTrainedTokenizerFast,
        ProcessorMixin,
    )

    try:
        processor = cached_get_processor(
            tokenizer.name_or_path,
            processor_cls=(
                PreTrainedTokenizer,
                PreTrainedTokenizerFast,
                ProcessorMixin,
            ),
            trust_remote_code=trust_remote_code,
        )
        if isinstance(processor, ProcessorMixin):
            chat_template = getattr(processor, "chat_template", None)
            if chat_template is not None:
                _PROCESSOR_CHAT_TEMPLATES[cache_key] = chat_template
                return chat_template
    except Exception:
        logger.debug(
            "Failed to load AutoProcessor chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # Cache the negative result so we do not retry a failing load.
    _PROCESSOR_CHAT_TEMPLATES[cache_key] = None
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_chat_template(
    tokenizer: HfTokenizer,
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    *,
    model_config: "ModelConfig",
) -> str | None:
    """Resolve the effective chat template, trying sources in priority order:
    explicit template, AutoProcessor, AutoTokenizer, then bundled fallbacks.
    Returns None when no source yields a template.
    """
    # 1st priority: the explicitly given template. Template names such as
    # "tool_use" are resolved to actual Jinja content so that downstream
    # kwargs detection can parse template variables.
    if chat_template is not None:
        return tokenizer.get_chat_template(chat_template, tools=tools)

    # 2nd priority: the AutoProcessor template, unless tool calling is enabled.
    if tools is None:
        processor_template = _try_get_processor_chat_template(
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )
        if processor_template is not None:
            return processor_template

    # 3rd priority: the AutoTokenizer template.
    try:
        return tokenizer.get_chat_template(chat_template, tools=tools)
    except Exception:
        logger.debug(
            "Failed to load AutoTokenizer chat template for %s",
            tokenizer.name_or_path,
            exc_info=True,
        )

    # 4th priority: predefined fallbacks shipped with vLLM.
    fallback_path = get_chat_template_fallback_path(
        model_type=model_config.hf_config.model_type,
        tokenizer_name_or_path=tokenizer.name_or_path,
    )
    if fallback_path is None:
        logger.debug_once(
            "There is no chat template fallback for %s", tokenizer.name_or_path
        )
        return None

    logger.info_once(
        "Loading chat template fallback for %s as there isn't one "
        "defined on HF Hub.",
        tokenizer.name_or_path,
    )
    return load_chat_template(fallback_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
    """Return True if *node* is a load of the variable named *varname*."""
    return (
        isinstance(node, jinja2.nodes.Name)
        and node.ctx == "load"
        and node.name == varname
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
    """Return True if *node* reads *key* from variable *varname*, in either
    subscript (``var["key"]``) or attribute (``var.key``) form."""
    # Subscript form: var["key"] with a constant key.
    if isinstance(node, jinja2.nodes.Getitem):
        if not _is_var_access(node.node, varname):
            return False
        return isinstance(node.arg, jinja2.nodes.Const) and node.arg.value == key

    # Attribute form: var.key
    if isinstance(node, jinja2.nodes.Getattr):
        return node.attr == key and _is_var_access(node.node, varname)

    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_var_or_elems_access(
    node: jinja2.nodes.Node,
    varname: str,
    key: str | None = None,
) -> bool:
    """Return True if *node* accesses *varname* (or ``varname.key`` when *key*
    is given), possibly wrapped in filters, tests, or slice subscripts."""
    # Unwrap `expr | somefilter`.
    if isinstance(node, jinja2.nodes.Filter):
        if node.node is None:
            return False
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap `expr is sometest`.
    if isinstance(node, jinja2.nodes.Test):
        return _is_var_or_elems_access(node.node, varname, key)

    # Unwrap slicing, e.g. `expr[1:]`.
    if isinstance(node, jinja2.nodes.Getitem):
        if isinstance(node.arg, jinja2.nodes.Slice):
            return _is_var_or_elems_access(node.node, varname, key)

    if key:
        return _is_attr_access(node, varname, key)
    return _is_var_access(node, varname)
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
    """Yield (node, name) pairs for *varname* and every alias transitively
    assigned from it anywhere under *root*."""
    # The variable itself is implicitly defined at the root.
    yield root, varname

    # Iterative BFS over aliases created by `{% set alias = var %}`-style
    # assignments.
    pending = deque([varname])
    while pending:
        current = pending.popleft()

        for assign_ast in root.find_all(jinja2.nodes.Assign):
            lhs = assign_ast.target
            rhs = assign_ast.node

            if not _is_var_or_elems_access(rhs, current):
                continue

            assert isinstance(lhs, jinja2.nodes.Name)
            yield assign_ast, lhs.name

            # Avoid infinite looping for self-assignment.
            if lhs.name != current:
                pending.append(lhs.name)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: The proper way to handle this is to build a CFG so that we can handle
|
||||||
|
# the scope in which each variable is defined, but that is too complicated
|
||||||
|
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
    """Yield (for-loop node, loop var name) for loops over `messages` or any
    of its aliases."""
    messages_varnames = [
        name for _, name in _iter_nodes_assign_var_or_elems(root, "messages")
    ]

    # Search for {%- for message in messages -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        for candidate in messages_varnames:
            if not _is_var_or_elems_access(loop_ast.iter, candidate):
                continue
            target = loop_ast.target
            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
            break
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
    """Yield (for-loop node, loop var name) for loops over a message's
    `content` attribute — the signature of OpenAI-style content parts."""
    message_varnames = [
        name for _, name in _iter_nodes_assign_messages_item(root)
    ]

    # Search for {%- for content in message['content'] -%} loops
    for loop_ast in root.find_all(jinja2.nodes.For):
        for candidate in message_varnames:
            if not _is_var_or_elems_access(loop_ast.iter, candidate, "content"):
                continue
            target = loop_ast.target
            assert isinstance(target, jinja2.nodes.Name)
            yield loop_ast, target.name
            break
|
||||||
|
|
||||||
|
|
||||||
|
def _try_extract_ast(chat_template: str) -> jinja2.nodes.Template | None:
    """Compile and parse *chat_template* into a Jinja AST.

    Returns None (after logging) when compilation or parsing fails, so
    callers can fall back to a default content format.
    """
    import transformers.utils.chat_template_utils as hf_chat_utils

    try:
        compiled = hf_chat_utils._compile_jinja_template(chat_template)
        return compiled.environment.parse(chat_template)
    except Exception:
        logger.exception("Error when compiling Jinja template")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=32)
def _detect_content_format(
    chat_template: str,
    *,
    default: ChatTemplateContentFormat,
) -> ChatTemplateContentFormat:
    """Infer whether *chat_template* expects "string" or "openai" content by
    looking for a loop over message content parts; return *default* when the
    template cannot be analyzed."""
    jinja_ast = _try_extract_ast(chat_template)
    if jinja_ast is None:
        return default

    try:
        next(_iter_nodes_assign_content_item(jinja_ast))
    except StopIteration:
        # No loop over message content: the template consumes plain strings.
        return "string"
    except Exception:
        logger.exception("Error when parsing AST of Jinja template")
        return default
    # Found a `{% for ... in message.content %}` loop: OpenAI-style parts.
    return "openai"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_glm_model(tokenizer: HfTokenizer, model_config: "ModelConfig") -> bool:
|
||||||
|
"""Check if this is a GLM model that requires string content format.
|
||||||
|
|
||||||
|
GLM models (GLM-4, GLM-4.5, GLM-5.x) have a chat template that incorrectly
|
||||||
|
triggers "openai" content format detection because they iterate over
|
||||||
|
m.content for tool responses. However, the template expects string content
|
||||||
|
for tool messages (checking `m.content is string`).
|
||||||
|
|
||||||
|
This detection ensures we force "string" format for GLM models.
|
||||||
|
"""
|
||||||
|
# Check tokenizer name/path for GLM indicators
|
||||||
|
name_or_path = tokenizer.name_or_path.lower()
|
||||||
|
glm_indicators = ["glm-4", "glm-5", "glm4", "glm5", "zai-org/glm"]
|
||||||
|
if any(ind in name_or_path for ind in glm_indicators):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check model type in config
|
||||||
|
if hasattr(model_config, "hf_config") and hasattr(model_config.hf_config, "model_type"):
|
||||||
|
model_type = model_config.hf_config.model_type.lower()
|
||||||
|
if "glm" in model_type:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    """Determine the content format for the resolved chat template, with a
    GLM-specific override."""
    # GLM models require "string" content format for tool responses to work.
    # The template has `{% for tr in m.content %}` which triggers "openai"
    # detection, but then checks `m.content is string` which fails for arrays.
    if _is_glm_model(tokenizer, model_config):
        logger.debug(
            "Forcing 'string' content format for GLM model: %s",
            tokenizer.name_or_path,
        )
        return "string"

    resolved_chat_template = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )

    if isinstance(resolved_chat_template, str):
        jinja_text = resolved_chat_template
    else:
        jinja_text = load_chat_template(chat_template, is_literal=True)

    if jinja_text is None:
        return "string"
    return _detect_content_format(jinja_text, default="string")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def _log_chat_template_content_format(
    chat_template: str | None,  # For caching purposes
    given_format: ChatTemplateContentFormatOption,
    detected_format: ChatTemplateContentFormatOption,
):
    """Log the detected content format once per unique argument combination
    (the lru_cache suppresses repeat messages)."""
    logger.info(
        "Detected the chat template content format to be '%s'. "
        "You can set `--chat-template-content-format` to override this.",
        detected_format,
    )

    user_overrode = given_format != "auto"
    if user_overrode and given_format != detected_format:
        logger.warning(
            "You specified `--chat-template-content-format %s` "
            "which is different from the detected format '%s'. "
            "If our automatic detection is incorrect, please consider "
            "opening a GitHub issue so that we can improve it: "
            "https://github.com/vllm-project/vllm/issues/new/choose",
            given_format,
            detected_format,
        )
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_chat_template_content_format(
    chat_template: str | None,
    tools: list[dict[str, Any]] | None,
    given_format: ChatTemplateContentFormatOption,
    tokenizer: HfTokenizer,
    *,
    model_config: "ModelConfig",
) -> ChatTemplateContentFormat:
    """Return the content format used to parse incoming chat messages.

    An explicit ``given_format`` (anything other than ``"auto"``) wins
    outright; otherwise the format is auto-detected from the chat template
    and the detection result is logged once per unique combination.
    """
    if given_format != "auto":
        # Caller pinned the format; no detection or logging needed.
        return given_format

    resolved = _resolve_chat_template_content_format(
        chat_template, tools, tokenizer, model_config=model_config
    )

    _log_chat_template_content_format(
        chat_template, given_format=given_format, detected_format=resolved
    )

    return resolved
|
||||||
|
|
||||||
|
|
||||||
|
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
    """Jinja2 extension accepting ``{% generation %}...{% endgeneration %}``.

    Only the ``parse`` hook is kept from the upstream transformers
    implementation: it is enough for ``env.parse()``-based template scanning,
    which never renders (and therefore never calls) the wrapped body.
    """

    # Jinja tag names this extension handles.
    tags = {"generation"}

    def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.Node:
        # Consume the "generation" token and remember its line number for
        # error reporting.
        lineno = next(parser.stream).lineno
        # Parse up to (and including) the matching {% endgeneration %} tag.
        body = parser.parse_statements(("name:endgeneration",), drop_needle=True)
        # Wrap the body in a call block; the callee only matters at render
        # time, which this trimmed extension is never used for.
        call = self.call_method("_generation_support")
        call_block = jinja2.nodes.CallBlock(call, [], [], body)
        return call_block.set_lineno(lineno)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chat_template_kwargs(chat_template: str) -> Set[str]:
    """Collect the undeclared (free) variable names used by *chat_template*.

    The template is only parsed, never rendered, inside an immutable
    sandboxed environment that understands the non-standard
    ``{% generation %}`` tag and Jinja loop controls.
    """
    sandbox = jinja2.sandbox.ImmutableSandboxedEnvironment(
        trim_blocks=True,
        lstrip_blocks=True,
        extensions=[AssistantTracker, jinja2.ext.loopcontrols],
    )
    ast = sandbox.parse(chat_template)
    return jinja2.meta.find_undeclared_variables(ast)
|
||||||
|
|
||||||
|
|
||||||
|
# Memoized variant: chat templates repeat across requests, so cache the
# (relatively expensive) Jinja parse per unique template string.
_cached_resolve_chat_template_kwargs = lru_cache(_resolve_chat_template_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
def _get_hf_base_chat_template_params() -> frozenset[str]:
    """Names of the standard HF ``apply_chat_template`` parameters.

    Extracted dynamically from ``PreTrainedTokenizer.apply_chat_template``
    (the single source of truth) so that tokenizers receiving standard
    parameters through ``**kwargs`` are still handled, and so the set tracks
    the installed transformers version.
    """
    from transformers import PreTrainedTokenizer

    signature = inspect.signature(PreTrainedTokenizer.apply_chat_template)

    # Skip the *args / **kwargs placeholders; they are not real parameters.
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return frozenset(
        param.name
        for param in signature.parameters.values()
        if param.kind not in variadic
    )
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_chat_template_kwargs(
    tokenizer: HfTokenizer,
    chat_template: str,
    chat_template_kwargs: dict[str, Any],
    raise_on_unexpected: bool = True,
) -> dict[str, Any]:
    """Filter request kwargs down to those the chat template can accept.

    A kwarg is kept when the tokenizer's ``apply_chat_template`` names it
    explicitly, the template body references it, or it is a standard HF
    parameter. ``chat_template`` and ``tokenize`` are always rejected: both
    were already resolved by the caller at this stage.

    Raises:
        ValueError: if ``raise_on_unexpected`` and a reserved key is present.
    """
    reserved = {"chat_template", "tokenize"}
    if raise_on_unexpected:
        clashing = reserved & chat_template_kwargs.keys()
        if clashing:
            raise ValueError(
                "Found unexpected chat template kwargs from request: "
                f"{clashing}"
            )

    # Kwargs named explicitly in the tokenizer's apply_chat_template signature.
    explicit_params = {
        name
        for name in chat_template_kwargs
        if supports_kw(tokenizer.apply_chat_template, name, allow_var_kwargs=False)
    }
    # Free variables referenced by the template body itself.
    template_params = _cached_resolve_chat_template_kwargs(chat_template)
    # Standard HF parameters, even if this tokenizer takes them via **kwargs.
    standard_params = _get_hf_base_chat_template_params()

    allowed = (explicit_params | template_params | standard_params) - reserved
    return {k: v for k, v in chat_template_kwargs.items() if k in allowed}
|
||||||
|
|
||||||
|
|
||||||
|
# Overload: tokenize=True (the default) -> the rendered prompt is returned
# as a list of token IDs.
@overload
def safe_apply_chat_template(
    model_config: "ModelConfig",
    tokenizer: HfTokenizer,
    conversation: list[ConversationMessage],
    *,
    tools: list[dict[str, Any]] | None = ...,
    chat_template: str | None = ...,
    tokenize: Literal[True] = ...,
    **kwargs,
) -> list[int]: ...
|
||||||
|
# Overload: tokenize=False -> the rendered prompt is returned as a string.
@overload
def safe_apply_chat_template(
    model_config: "ModelConfig",
    tokenizer: HfTokenizer,
    conversation: list[ConversationMessage],
    *,
    tools: list[dict[str, Any]] | None = ...,
    chat_template: str | None = ...,
    tokenize: Literal[False] = ...,
    **kwargs,
) -> str: ...
|
||||||
|
def safe_apply_chat_template(
    model_config: "ModelConfig",
    tokenizer: HfTokenizer,
    conversation: list[ConversationMessage],
    *,
    tools: list[dict[str, Any]] | None = None,
    chat_template: str | None = None,
    tokenize: bool = True,
    **kwargs,
) -> str | list[int]:
    """Apply the tokenizer's chat template with vLLM-side safeguards.

    Resolves which template to use, filters ``kwargs`` down to what the
    template/tokenizer accepts, then delegates to
    ``tokenizer.apply_chat_template``. Any exception raised inside
    ``transformers`` is logged and re-raised as ``ValueError``.

    Returns:
        Token IDs when ``tokenize`` is True, otherwise the prompt string.

    Raises:
        ChatTemplateResolutionError: if no chat template can be resolved.
        ValueError: if ``transformers`` fails while applying the template.
    """
    chat_template = resolve_chat_template(
        tokenizer,
        chat_template=chat_template,
        tools=tools,
        model_config=model_config,
    )
    if chat_template is None:
        raise ChatTemplateResolutionError(
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."
        )

    template_kwargs = resolve_chat_template_kwargs(
        tokenizer=tokenizer,
        chat_template=chat_template,
        chat_template_kwargs=kwargs,
    )

    try:
        return tokenizer.apply_chat_template(
            conversation=conversation,  # type: ignore[arg-type]
            tools=tools,  # type: ignore[arg-type]
            chat_template=chat_template,
            tokenize=tokenize,
            **template_kwargs,
        )
    except Exception as e:
        # External library failures (bad template, bad inputs) are surfaced
        # to the caller as ValueError; log the traceback first for further
        # investigation.
        logger.exception(
            "An error occurred in `transformers` while applying chat template"
        )
        raise ValueError(str(e)) from e
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_mm_uuids_from_mm_data(
    mm_uuids: MultiModalUUIDDict,
    mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
    """Rebuild mm_uuids after vision_chunk processing.

    When videos are split into chunks, the original UUIDs need to be updated
    to reflect the new UUIDs generated for each chunk.

    Args:
        mm_uuids: Original UUIDs dictionary
        mm_data: Processed multimodal data with vision_chunk items

    Returns:
        Updated UUIDs dictionary with chunk UUIDs
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        # No chunking happened; the original mapping is still valid.
        return mm_uuids

    assert all(isinstance(chunk, dict) for chunk in chunks), (
        "Expected all vision_chunk items to be dicts"
    )
    chunks = cast(list[dict[str, Any]], chunks)

    chunk_uuids = []
    for chunk in chunks:
        uuid_val = chunk.get("uuid")
        if uuid_val is not None:
            chunk_uuids.append(uuid_val)

    if not chunk_uuids:
        return mm_uuids

    # Copy before mutating so the caller's original dict stays untouched.
    updated = dict(mm_uuids)
    updated["vision_chunk"] = chunk_uuids
    return updated
|
||||||
|
|
||||||
|
|
||||||
|
def build_video_prompts_from_mm_data(
    mm_data: MultiModalDataDict,
) -> list[str]:
    """Build video prompts from vision_chunk data.

    Collects prompts from video chunks and groups them by video_idx.

    Args:
        mm_data: Processed multimodal data with vision_chunk items

    Returns:
        List of video prompts, one per video.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return []

    # Group chunk prompts by the video they belong to.
    prompts_by_video: dict[int, list[str]] = {}
    for chunk in chunks:
        # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
        assert isinstance(chunk, dict)
        if chunk.get("type") != "video_chunk":
            continue
        idx = chunk.get("video_idx", 0)
        prompts_by_video.setdefault(idx, []).append(chunk.get("prompt", ""))

    # Concatenate each video's chunk prompts, in ascending video order.
    return ["".join(prompts_by_video[idx]) for idx in sorted(prompts_by_video)]
|
||||||
|
|
||||||
|
|
||||||
|
def replace_vision_chunk_video_placeholder(
    prompt_raw: str | list[int],
    mm_data: MultiModalDataDict,
    video_placeholder: str | None,
) -> str | list[int]:
    """Splice runtime video-chunk prompts into a text prompt.

    Each occurrence of ``video_placeholder`` in ``prompt_raw`` is replaced,
    in order, by the corresponding video's concatenated chunk prompts.
    Token-ID prompts and prompts without a placeholder are returned as-is.
    """
    if not video_placeholder or not isinstance(prompt_raw, str):
        return prompt_raw

    video_prompts = build_video_prompts_from_mm_data(mm_data)

    segments = prompt_raw.split(video_placeholder)
    if len(segments) != len(video_prompts) + 1:
        # Count mismatch: leave the prompt untouched rather than guess.
        logger.warning(
            "Number of video placeholders (%d) does not match "
            "number of videos (%d) in the request.",
            len(segments) - 1,
            len(video_prompts),
        )
        return prompt_raw

    # Interleave: seg0, vid0, seg1, vid1, ..., then the trailing segment.
    pieces: list[str] = []
    for segment, video_prompt in zip(segments, video_prompts):
        pieces.append(segment)
        pieces.append(video_prompt)
    pieces.append(segments[-1])
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
class HfRenderer(BaseRenderer[HfTokenizer]):
    """Renderer that converts chat messages into prompts via an HF tokenizer.

    Parses (possibly multimodal) chat messages, applies the tokenizer's chat
    template, and assembles the final decoder-only prompt dict. The sync and
    async entry points share their post-processing through
    ``_finalize_prompt`` so the two code paths cannot drift apart.
    """

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: HfTokenizer | None,
    ) -> None:
        super().__init__(config, tokenizer)

        # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
        # model which uses unified vision chunks for both images and videos.
        self.use_unified_vision_chunk = getattr(
            config.model_config.hf_config, "use_unified_vision_chunk", False
        )

        # Run the (potentially slow) chat-template application on the
        # renderer's executor so async callers do not block the event loop.
        self._apply_chat_template_async = make_async(
            safe_apply_chat_template, executor=self._executor
        )

    def _resolve_content_format(
        self,
        params: ChatParams,
        tokenizer: HfTokenizer,
    ) -> ChatTemplateContentFormat:
        """Resolve the message content format for this request's template."""
        return resolve_chat_template_content_format(
            chat_template=params.chat_template,
            tools=params.chat_template_kwargs.get("tools"),
            given_format=params.chat_template_content_format,
            tokenizer=tokenizer,
            model_config=self.model_config,
        )

    def _finalize_prompt(
        self,
        prompt_raw: str | list[int],
        mm_data: MultiModalDataDict | None,
        mm_uuids: MultiModalUUIDDict | None,
    ) -> DictPrompt:
        """Apply vision-chunk post-processing and build the prompt dict.

        NOTE(review): previously only the synchronous path rebuilt
        ``mm_uuids`` via ``rebuild_mm_uuids_from_mm_data``; the async path
        skipped it, dropping per-chunk UUIDs. Both paths now share this
        helper so the behavior is consistent.
        """
        # NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
        # model which uses unified vision chunks for both images and videos.
        if (
            self.use_unified_vision_chunk
            and mm_uuids is not None
            and mm_data is not None
        ):
            # Video chunking generates new per-chunk UUIDs; propagate them.
            mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)

            # get video placeholder, replace it with runtime video-chunk prompts
            video_placeholder = getattr(
                self.model_config.hf_config, "video_placeholder", None
            )
            prompt_raw = cast(
                list[int],
                replace_vision_chunk_video_placeholder(
                    prompt_raw,
                    mm_data,
                    video_placeholder,
                ),
            )

        prompt = parse_dec_only_prompt(prompt_raw)
        if mm_data is not None:
            prompt["multi_modal_data"] = mm_data
        if mm_uuids is not None:
            prompt["multi_modal_uuids"] = mm_uuids
        return prompt

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Synchronously render chat *messages* into a prompt dict."""
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = parse_chat_messages(
            messages,
            model_config,
            content_format=self._resolve_content_format(params, tokenizer),
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        prompt_raw = safe_apply_chat_template(
            model_config,
            tokenizer,
            conversation,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finalize_prompt(prompt_raw, mm_data, mm_uuids)
        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async variant of :meth:`render_messages`.

        Message parsing and chat-template application are awaited so heavy
        work runs off the event loop; post-processing is shared with the
        sync path via ``_finalize_prompt``.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        conversation, mm_data, mm_uuids = await parse_chat_messages_async(
            messages,
            model_config,
            content_format=self._resolve_content_format(params, tokenizer),
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )

        prompt_raw = await self._apply_chat_template_async(
            model_config,
            tokenizer,
            conversation,
            **params.get_apply_chat_template_kwargs(),
        )

        prompt = self._finalize_prompt(prompt_raw, mm_data, mm_uuids)
        return conversation, prompt
|
||||||
Reference in New Issue
Block a user