SmolLM3-3B tool call fix: template bugs found and patched
NOTES.md (new file, 103 lines)
@@ -0,0 +1,103 @@
# SmolLM3-3B Tool Call Fix — Notes

## Problem

The SmolLM3-3B model's chat template has three bugs that break multi-turn tool calling in vLLM.

## Bugs Found

### Bug 1: Tool responses rendered as plain user messages

**Location:** `chat_template.jinja`, main loop, `message.role == "tool"` branch

**Original:**

```jinja2
{%- elif message.role == "tool" -%}
{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
```

Tool responses show up as `<|im_start|>user\n...<|im_end|>` — the model cannot distinguish a tool result from a new user turn. When it sees weather data in a user message, it re-invokes the tool instead of answering.
**Fix:** Use the model's dedicated `tool_response_start`/`tool_response_end` tokens (128013/128014) to wrap tool responses so the model can distinguish them from user messages.

### Bug 2: Assistant tool_calls not rendered in history

**Location:** `chat_template.jinja`, main loop, `message.role == "assistant"` branch

When an assistant message has `tool_calls`, the template only renders `content` (often empty/None) and drops the entire `tool_calls` array. The model never sees its own prior tool invocations.

**Fix:** Render tool calls using the model's native `tool_call_start`/`tool_call_end` tokens (128015/128016) with proper JSON format.

### Bug 3: Thinking mode inverted

**Location:** `chat_template.jinja`, main loop and generation prompt

When `reasoning_mode == "/think"`, the template does NOT wrap content in think tags. When `reasoning_mode == "/no_think"`, it DOES wrap content in `<think>...</think>` tags. Completely backwards.

**Fix:** `/think` mode wraps content in `<think>...</think>` tags; `/no_think` renders plain text.

## Special Tokens

The model has these tool-related tokens in its tokenizer (`added_tokens_decoder`):

| Token ID | Purpose |
|----------|---------|
| 128002 | Think start |
| 128003 | Think end |
| 128013 | Tool response start |
| 128014 | Tool response end |
| 128015 | Tool call start |
| 128016 | Tool call end |
## How the Fix Works

### Template Changes

1. **Tool responses** now render as:
```
<|im_start|>user
[tool_response_start]
{tool result content}
[tool_response_end]<|im_end|>
```
instead of a bare user message.

2. **Assistant tool calls** now render as:
```
<|im_start|>assistant
[tool_call_start]
{"name": "func_name", "arguments": {...}}
[tool_call_end]<|im_end|>
```
instead of being dropped entirely.

3. **Thinking mode** is now correctly mapped: `/think` → think tags, `/no_think` → plain text.

### Key Technical Details

- The template uses Jinja2's `~` operator instead of `+` for string concatenation. This avoids type errors when `message.content` is `None`: `~` coerces its operands to strings, while `+` raises a `TypeError` on `None`.
- The `tool_call_start`/`tool_call_end` tokens are Unicode private-use-area characters that can't be typed in a text editor, so the template must be generated programmatically using `gen_template.py`.
- The `tc.function.name` and `tc.function.arguments` dot notation works because Jinja2 falls back to subscript lookup (`dict["key"]`) when attribute lookup fails.
- The `{% generation %}` tag is vLLM-specific and marks the assistant output region. It must be preserved.
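The first and third bullets can be checked directly with the `jinja2` package. A minimal sketch (the `get_weather` name is illustrative, not taken from the template):

```python
from jinja2 import Template

# `~` coerces every operand to a string, so None renders as "None"
# instead of raising; `+` raises a TypeError on None.
print(Template("{{ a ~ b }}").render(a=None, b="!"))   # None!
try:
    Template("{{ a + b }}").render(a=None, b="!")
except TypeError:
    print("`+` raised TypeError on None")

# Dot notation falls back to subscript lookup on plain dicts, which is
# why tc.function.name works when tc is a dict of dicts.
tc = {"function": {"name": "get_weather", "arguments": "{}"}}
print(Template("{{ tc.function.name }}").render(tc=tc))  # get_weather
```

(The fixed template additionally guards with `message.content is string`, so the `~` coercion is a second line of defense rather than the only one.)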
## Files

- `model-files/chat_template.jinja` — the fixed template (generated; contains Unicode PUA characters)
- `model-files/gen_template.py` — script to regenerate the template inside the container, where the tokenizer is available
- `model-files/hermes_tool_parser.py` — vLLM Hermes tool parser (unchanged; works as-is for parsing the `<tool_call>...</tool_call>` format)

## Deploying

1. Run `gen_template.py` inside the vLLM container:
```bash
docker cp model-files/gen_template.py smol-vllm-1:/tmp/
docker exec smol-vllm-1 python3 /tmp/gen_template.py
```

2. Copy the generated template to the mounted volume:
```bash
docker cp smol-vllm-1:/root/chat_template.jinja /root/smol/chat_template.jinja
```

3. Restart the container:
```bash
cd /root/smol && docker compose restart
```
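Once the container is back up, replaying a full tool round-trip is the quickest smoke test, since it exercises all three fixed paths (assistant `tool_calls` in history, the tool response, and the final answer). A sketch of the request body; the `get_weather` tool and its arguments are made up for illustration, and the model name follows the default in `run_tests.sh`:

```python
import json

def build_roundtrip_payload(model: str = "HuggingFaceTB/SmolLM3-3B") -> dict:
    """Multi-turn history that the fixed template must render correctly."""
    call = {
        "id": "call_1",
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "arguments": json.dumps({"city": "Paris"}),
        },
    }
    return {
        "model": model,
        "messages": [
            {"role": "user", "content": "What's the weather in Paris?"},
            # Bug 2 path: prior tool_calls must be rendered, not dropped.
            {"role": "assistant", "content": None, "tool_calls": [call]},
            # Bug 1 path: must land inside tool_response tokens, not a user turn.
            {"role": "tool", "tool_call_id": "call_1", "content": '{"temp_c": 21}'},
        ],
        "stream": False,
    }
```

POST this to `$VLLM_API_BASE/chat/completions`; before the fix, the tool turn rendered as a plain user message and the model tended to call the tool again instead of answering.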
## Remaining Issues

- The model sometimes re-invokes tools in a loop instead of providing a final text answer. This is likely a training issue with `/no_think` mode — the model outputs reasoning as content text but still generates tool calls.
- The Hermes tool parser works for parsing `<tool_call>...</tool_call>` blocks, but the streaming parser may buffer long argument strings. This is a vLLM-level issue, not a template issue.
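The buffering interacts with the parser's hold-back logic, which refuses to stream any suffix that might be the start of a tag. That logic can be sketched standalone (this mirrors `_partial_tag_overlap` in `hermes_tool_parser.py`):

```python
def partial_tag_overlap(text: str, tag: str) -> int:
    """Longest prefix of `tag` that is a suffix of `text` (0 if none)."""
    max_check = min(len(tag) - 1, len(text))
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0

# Text ending in a possible tag prefix is held back until more tokens arrive.
chunk = "Here is the plan <tool_"
held = partial_tag_overlap(chunk, "<tool_call>")
safe_to_stream = chunk[: len(chunk) - held]
print(held, repr(safe_to_stream))  # 6 'Here is the plan '
```

The held-back characters are emitted later, once enough text has arrived to decide whether they were a real `<tool_call>` tag or plain content.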
model-files/chat_template.jinja (new file, 102 lines)
@@ -0,0 +1,102 @@
{# ───── defaults ───── #}
{%- if enable_thinking is not defined -%}
{%- set enable_thinking = true -%}
{%- endif -%}
{# defaults so requests without a system message don't hit undefined errors #}
{%- set system_message = "" -%}
{%- set custom_instructions = "" -%}

{# ───── reasoning mode ───── #}
{%- if enable_thinking -%}
{%- set reasoning_mode = "/think" -%}
{%- else -%}
{%- set reasoning_mode = "/no_think" -%}
{%- endif -%}

{# ───── header (system message) ───── #}
{{- "<|im_start|>system\n" -}}

{%- if messages[0].role == "system" -%}
{%- set system_message = messages[0].content -%}
{%- if "/no_think" in system_message -%}
{%- set reasoning_mode = "/no_think" -%}
{%- elif "/think" in system_message -%}
{%- set reasoning_mode = "/think" -%}
{%- endif -%}
{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
{%- endif -%}

{%- if "/system_override" in system_message -%}
{{- custom_instructions.replace("/system_override", "").rstrip() -}}
{%- else -%}
{{- "## Metadata\n\n" -}}
{{- "Knowledge Cutoff Date: June 2025\n" -}}
{%- set today = strftime_now("%d %B %Y") -%}
{{- "Today Date: " ~ today ~ "\n" -}}
{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}

{{- "## Custom Instructions\n\n" -}}
{%- if custom_instructions -%}
{{- custom_instructions + "\n\n" -}}
{%- elif reasoning_mode == "/think" -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
{%- else -%}
{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
{%- endif -%}

{%- if xml_tools or python_tools or tools -%}
{{- "### Tools\n\n" -}}
{%- if xml_tools or tools -%}
{%- if tools -%}
{%- set xml_tools = tools -%}
{%- endif -%}
{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
{%- for tool in xml_tools[:] -%}
{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\n" -%}
{%- endfor -%}
{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
{{- xml_tool_string -}}
{%- endif -%}
{%- if python_tools -%}
{%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\n<tools>\n") -%}
{%- for tool in python_tools[:] -%}
{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
{%- endfor -%}
{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions." -%}
{{- python_tool_string -}}
{%- endif -%}
{{- "\n\n" -}}
{%- endif -%}
{%- endif -%}
{{- "<|im_end|>\n" -}}

{# ───── main loop ───── #}
{%- for message in messages -%}
{%- if message.role == "user" -%}
{{ "<|im_start|>user\n" + message.content + "<|im_end|>\n" }}
{%- elif message.role == "assistant" -%}
{% generation %}
{%- if message.tool_calls -%}
{%- set ns = namespace(tc_text="") -%}
{%- for tc in message.tool_calls -%}
{%- set ns.tc_text = ns.tc_text ~ "<tool_call>\n{\"name\": \"" ~ tc.function.name ~ "\", \"arguments\": " ~ tc.function.arguments ~ "}\n</tool_call>" -%}
{%- endfor -%}
{{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\n" }}
{%- else -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n<think>\n" ~ (message.content if message.content is string else "") ~ "\n</think><|im_end|>\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\n" }}
{%- endif -%}
{%- endif -%}
{% endgeneration %}
{%- elif message.role == "tool" -%}
{{ "<|im_start|>user\n<tool_response>\n" ~ (message.content if message.content is string else "") ~ "\n</tool_response><|im_end|>\n" }}
{%- endif -%}
{%- endfor -%}

{# ───── generation prompt ───── #}
{%- if add_generation_prompt -%}
{%- if reasoning_mode == "/think" -%}
{{ "<|im_start|>assistant\n<think>\n" }}
{%- else -%}
{{ "<|im_start|>assistant\n" }}
{%- endif -%}
{%- endif -%}
model-files/gen_template.py (new file, 134 lines)
@@ -0,0 +1,134 @@
#!/usr/bin/env python3
"""Generate the PRODUCTION fixed chat_template.jinja for SmolLM3-3B.
Uses a hybrid approach: raw strings for Jinja2, concat for special tokens."""
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

THINK_S = tok.decode([128002])
THINK_E = tok.decode([128003])
RESP_S = tok.decode([128013])
RESP_E = tok.decode([128014])
TC_S = tok.decode([128015])
TC_E = tok.decode([128016])

# Build the template as a list of chunks
# Key: any line with special tokens gets its own append with string concat
T = []

# ─── System header (no special tokens) ───
T.append('{# ───── defaults ───── #}\n')
T.append('{%- if enable_thinking is not defined -%}\n')
T.append('{%- set enable_thinking = true -%}\n')
T.append('{%- endif -%}\n')
# Defaults so requests without a system message don't hit undefined errors.
T.append('{%- set system_message = "" -%}\n')
T.append('{%- set custom_instructions = "" -%}\n\n')
T.append('{# ───── reasoning mode ───── #}\n')
T.append('{%- if enable_thinking -%}\n')
T.append(' {%- set reasoning_mode = "/think" -%}\n')
T.append('{%- else -%}\n')
T.append(' {%- set reasoning_mode = "/no_think" -%}\n')
T.append('{%- endif -%}\n\n')
T.append('{# ───── header (system message) ───── #}\n')
T.append('{{- "<|im_start|>system\\n" -}}\n\n')
T.append('{%- if messages[0].role == "system" -%}\n')
T.append(' {%- set system_message = messages[0].content -%}\n')
T.append(' {%- if "/no_think" in system_message -%}\n')
T.append(' {%- set reasoning_mode = "/no_think" -%}\n')
T.append(' {%- elif "/think" in system_message -%}\n')
T.append(' {%- set reasoning_mode = "/think" -%}\n')
T.append(' {%- endif -%}\n')
T.append(' {%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}\n')
T.append('{%- endif -%}\n\n')
T.append('{%- if "/system_override" in system_message -%}\n')
T.append(' {{- custom_instructions.replace("/system_override", "").rstrip() -}}\n')
T.append('{%- else -%}\n')
T.append(' {{- "## Metadata\\n\\n" -}}\n')
T.append(' {{- "Knowledge Cutoff Date: June 2025\\n" -}}\n')
T.append(' {%- set today = strftime_now("%d %B %Y") -%}\n')
T.append(' {{- "Today Date: " ~ today ~ "\\n" -}}\n')
T.append(' {{- "Reasoning Mode: " + reasoning_mode + "\\n\\n" -}}\n\n')
T.append(' {{- "## Custom Instructions\\n\\n" -}}\n')
T.append(' {%- if custom_instructions -%}\n')
T.append(' {{- custom_instructions + "\\n\\n" -}}\n')
T.append(' {%- elif reasoning_mode == "/think" -%}\n')
T.append(' {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\\n\\n" -}}\n')
T.append(' {%- else -%}\n')
T.append(' {{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\\n\\n" -}}\n')
T.append(' {%- endif -%}\n\n')
T.append(' {%- if xml_tools or python_tools or tools -%}\n')
T.append(' {{- "### Tools\\n\\n" -}}\n')
T.append(' {%- if xml_tools or tools -%}\n')
T.append(' {%- if tools -%}\n')
T.append(' {%- set xml_tools = tools -%}\n')
T.append(' {%- endif -%}\n')
T.append(' {%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\\nYou are provided with function signatures within <tools></tools> XML tags:\\n\\n<tools>\\n") -%}\n')
T.append(' {%- for tool in xml_tools[:] -%}\n')
T.append(' {%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | tojson) ~ "\\n" -%}\n')
T.append(' {%- endfor -%}\n')

# Tool format instruction - has special tokens
T.append(' {%- set xml_tool_string = ns.xml_tool_string + "</tools>\\n\\nFor each function call, return a json object with function name and arguments within ' + TC_S + ' XML tags:\\n' + TC_S + '\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n' + TC_E + '" -%}\n')

T.append(' {{- xml_tool_string -}}\n')
T.append(' {%- endif -%}\n')
T.append(' {%- if python_tools -%}\n')
T.append(' {%- set ns = namespace(python_tool_string="You may call one or more functions as python tools.\\n<tools>\\n") -%}\n')
T.append(' {%- for tool in python_tools[:] -%}\n')
T.append(' {%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\\n" -%}\n')
T.append(' {%- endfor -%}\n')
T.append(' {%- set python_tool_string = ns.python_tool_string + "</tools>\\n\\nThe state persists between code executions." -%}\n')
T.append(' {{- python_tool_string -}}\n')
T.append(' {%- endif -%}\n')
T.append(' {{- "\\n\\n" -}}\n')
T.append(' {%- endif -%}\n')
T.append('{%- endif -%}\n')
T.append('{{- "<|im_end|>\\n" -}}\n\n')

# ─── Main loop ───
T.append('{# ───── main loop ───── #}\n')
T.append('{%- for message in messages -%}\n')
T.append(' {%- if message.role == "user" -%}\n')
T.append(' {{ "<|im_start|>user\\n" + message.content + "<|im_end|>\\n" }}\n')
T.append(' {%- elif message.role == "assistant" -%}\n')
T.append(' {% generation %}\n')
T.append(' {%- if message.tool_calls -%}\n')
T.append(' {%- set ns = namespace(tc_text="") -%}\n')
T.append(' {%- for tc in message.tool_calls -%}\n')

# FIX: Render tool calls with TC_S/TC_E tokens using ~ (Jinja2 concat)
T.append(' {%- set ns.tc_text = ns.tc_text ~ "' + TC_S + '\\n{\\"name\\": \\"" ~ tc.function.name ~ "\\", \\"arguments\\": " ~ tc.function.arguments ~ "}\\n' + TC_E + '" -%}\n')

T.append(' {%- endfor -%}\n')
T.append(' {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ ns.tc_text ~ "<|im_end|>\\n" }}\n')
T.append(' {%- else -%}\n')

# FIX: /think = use think tags, /no_think = plain text (was inverted in original)
T.append(' {%- if reasoning_mode == "/think" -%}\n')
T.append(' {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + THINK_E + '<|im_end|>\\n" }}\n')
T.append(' {%- else -%}\n')
T.append(' {{ "<|im_start|>assistant\\n" ~ (message.content if message.content is string else "") ~ "<|im_end|>\\n" }}\n')
T.append(' {%- endif -%}\n')
T.append(' {%- endif -%}\n')
T.append(' {% endgeneration %}\n')

# FIX: Tool role with RESP_S/RESP_E tokens
T.append(' {%- elif message.role == "tool" -%}\n')
T.append(' {{ "<|im_start|>user\\n' + RESP_S + '\\n" ~ (message.content if message.content is string else "") ~ "\\n' + RESP_E + '<|im_end|>\\n" }}\n')
T.append(' {%- endif -%}\n')
T.append('{%- endfor -%}\n\n')

# ─── Generation prompt ───
T.append('{# ───── generation prompt ───── #}\n')
T.append('{%- if add_generation_prompt -%}\n')
T.append(' {%- if reasoning_mode == "/think" -%}\n')
T.append(' {{ "<|im_start|>assistant\\n' + THINK_S + '\\n" }}\n')
T.append(' {%- else -%}\n')
T.append(' {{ "<|im_start|>assistant\\n" }}\n')
T.append(' {%- endif -%}\n')
T.append('{%- endif -%}\n')

template = ''.join(T)

with open('/root/chat_template.jinja', 'w', encoding='utf-8') as f:
    f.write(template)

print("Production template written to /root/chat_template.jinja")
print(f"Length: {len(template)} bytes")
model-files/hermes_tool_parser.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence

import regex as re

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)
from vllm.utils.mistral import is_mistral_tokenizer

logger = init_logger(__name__)


def _partial_tag_overlap(text: str, tag: str) -> int:
    """Length of the longest prefix of `tag` that matches a suffix of `text`.

    E.g. text ending in "<tool_" returns 6 when tag is "<tool_call>".
    Returns 0 if there is no overlap.
    """
    max_check = min(len(tag) - 1, len(text))
    for k in range(max_check, 0, -1):
        if text.endswith(tag[:k]):
            return k
    return 0


def _is_valid_json(text: str) -> bool:
    try:
        json.loads(text)
        return True
    except (json.JSONDecodeError, ValueError):
        return False


class Hermes2ProToolParser(ToolParser):
    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)

        if is_mistral_tokenizer(tokenizer):
            logger.error("Detected Mistral tokenizer when using a Hermes model")
            self.model_tokenizer = tokenizer.tokenizer

        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"

        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL
        )
        self.scratch_pad_regex = re.compile(
            r"<scratch_pad>(.*?)</scratch_pad>", re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Streaming state: what has been sent to the client.
        self._sent_content_idx: int = 0

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # do not skip special tokens because the tool_call tokens are
            # marked "special" in some models. Since they are skipped
            # prior to the call to the tool parser, it breaks tool calling.
            request.skip_special_tokens = False
        return request

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        # sanity check; avoid unnecessary processing
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        else:
            try:
                # there are two possible captures - between tags, or between a
                # tag and end-of-string so the result of
                # findall is an array of tuples where one is a function call and
                # the other is None
                function_call_tuples = self.tool_call_regex.findall(model_output)

                # load the JSON, and then use it to build the Function and
                # Tool Call
                raw_function_calls = [
                    json.loads(match[0] if match[0] else match[1])
                    for match in function_call_tuples
                ]
                tool_calls = [
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=function_call["name"],
                            # function call args are JSON but as a string
                            arguments=json.dumps(
                                function_call["arguments"], ensure_ascii=False
                            ),
                        ),
                    )
                    for function_call in raw_function_calls
                ]

                content = model_output[: model_output.find(self.tool_call_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=tool_calls,
                    content=content if content else None,
                )

            except Exception:
                logger.exception("Error in extracting tool call from response.")
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

    def _extract_content(self, current_text: str) -> str | None:
        """Return unsent non-tool-call text, or None.

        Holds back any suffix that could be a partial <tool_call> tag.
        """
        if self.tool_call_start_token not in current_text:
            overlap_length = _partial_tag_overlap(
                current_text, self.tool_call_start_token
            )
            sendable_idx = len(current_text) - overlap_length
        else:
            sendable_idx = current_text.index(self.tool_call_start_token)

        if sendable_idx > self._sent_content_idx:
            content = current_text[self._sent_content_idx : sendable_idx]
            self._sent_content_idx = sendable_idx
            return content
        return None

    def _extract_tool_call_jsons(self, text: str) -> list[tuple[str, bool]]:
        """Extract (json_text, is_complete) for each <tool_call> region."""
        results: list[tuple[str, bool]] = []
        pos = 0
        while True:
            start = text.find(self.tool_call_start_token, pos)
            if start == -1:
                break
            json_start = start + len(self.tool_call_start_token)
            json_end = text.find(self.tool_call_end_token, json_start)
            if json_end != -1:
                results.append((text[json_start:json_end].strip(), True))
                pos = json_end + len(self.tool_call_end_token)
            else:
                raw = text[json_start:]
                # Strip partial </tool_call> suffix if present.
                overlap = _partial_tag_overlap(raw, self.tool_call_end_token)
                if overlap:
                    raw = raw[:-overlap]
                tc_json = raw.strip()
                # Valid JSON without closing tag = complete body,
                # tag tokens just haven't arrived yet.
                is_complete = _is_valid_json(tc_json) if tc_json else False
                results.append((tc_json, is_complete))
                break
        return results

    @staticmethod
    def _extract_tool_name(tc_json: str) -> str | None:
        """Extract tool name, or None if the name isn't complete yet."""
        match = re.search(r'"name"\s*:\s*"([^"]+)"', tc_json)
        return match.group(1) if match else None

    @staticmethod
    def _extract_tool_args(tc_json: str, is_complete: bool) -> str | None:
        """Extract tool arguments from the tool call JSON.

        Given {"name": "f", "arguments": {"x": 1}}, returns '{"x": 1}'.
        When is_complete, strips the trailing '}' that closes the outer
        object (not the arguments). For partial JSON, returns as-is.
        """
        match = re.search(r'"arguments"\s*:\s*', tc_json)
        if not match:
            return None
        raw = tc_json[match.end() :]
        if is_complete:
            raw = raw.rstrip()
            if raw.endswith("}"):
                raw = raw[:-1].rstrip()
        return raw

    def _compute_args_diff(
        self, index: int, tc_json: str, is_complete: bool
    ) -> str | None:
        """Return new argument text not yet sent for tool `index`, or None."""
        args = self._extract_tool_args(tc_json, is_complete)
        if args is None or len(args) <= len(self.streamed_args_for_tool[index]):
            return None
        diff = args[len(self.streamed_args_for_tool[index]) :]
        self.streamed_args_for_tool[index] = args
        self.prev_tool_call_arr[index]["arguments"] = args
        return diff

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally stream tool call deltas from accumulated output.

        On each invocation, re-parses the full ``current_text`` to find
        ``<tool_call>`` regions, then diffs against previously sent state
        to emit only new content, tool names, or argument fragments.

        Returns a ``DeltaMessage`` containing either plain content (for
        text preceding any tool call) or one or more ``DeltaToolCall``
        entries, or ``None`` if there is nothing new to send yet.
        """
        try:
            # Extract any content before tool calls.
            content = self._extract_content(current_text)
            tool_call_jsons = self._extract_tool_call_jsons(current_text)
            tool_call_deltas: list[DeltaToolCall] = []

            for i, (tc_json, is_complete) in enumerate(tool_call_jsons):
                if i >= len(self.prev_tool_call_arr):
                    self.prev_tool_call_arr.append({})
                    self.streamed_args_for_tool.append("")

                # Stream back tool name.
                if "name" not in self.prev_tool_call_arr[i]:
                    name = self._extract_tool_name(tc_json)
                    if not name:
                        # Can't skip to tool i+1 if i isn't ready
                        break
                    self.prev_tool_call_arr[i]["name"] = name
                    tool_call_deltas.append(
                        DeltaToolCall(
                            index=i,
                            type="function",
                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(name=name).model_dump(
                                exclude_none=True
                            ),
                        )
                    )

                # Stream back new tool args by diffing against what was sent.
                args_diff = self._compute_args_diff(i, tc_json, is_complete)
                if args_diff:
                    tool_call_deltas.append(
                        DeltaToolCall(
                            index=i,
                            function=DeltaFunctionCall(arguments=args_diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    )

            if content or tool_call_deltas:
                return DeltaMessage(
                    content=content,
                    tool_calls=tool_call_deltas,
                )

            return None

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            return None
requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
httpx>=0.25.0
run_tests.sh (new file, 19 lines)
@@ -0,0 +1,19 @@
#!/bin/bash
# Run the streaming tool call tests

set -e

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Default values
export VLLM_API_BASE="${VLLM_API_BASE:-http://95.179.247.150/v1}"
export VLLM_API_KEY="${VLLM_API_KEY:-none}"
export VLLM_MODEL="${VLLM_MODEL:-HuggingFaceTB/SmolLM3-3B}"

echo "Configuration:"
echo "  API_BASE: $VLLM_API_BASE"
echo "  MODEL: $VLLM_MODEL"
echo ""

# Run the test
python3 "$SCRIPT_DIR/test_streaming_tool_calls.py"
test_streaming_tool_calls.py (new file, 386 lines)
@@ -0,0 +1,386 @@
#!/usr/bin/env python3
"""
Test suite for vLLM SmolLM3-3B streaming tool calls.

Reproduces the issue where long string parameters in tool calls
are buffered entirely before being emitted during streaming.
"""

import os
import time
import json
import httpx
from datetime import datetime


# Configuration - will be set via environment or direct assignment
API_BASE = os.environ.get("VLLM_API_BASE", "http://95.179.247.150/v1")
API_KEY = os.environ.get("VLLM_API_KEY", "none")
MODEL = os.environ.get("VLLM_MODEL", "HuggingFaceTB/SmolLM3-3B")


def timestamp():
    return datetime.now().strftime("%H:%M:%S.%f")[:-3]


def test_streaming_tool_call_with_code():
    """
    Test streaming a tool call with a long string parameter.

    This prompts the model to generate code via a tool call,
    which should stream incrementally if the patch works correctly.
    """

    tools = [
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file. Use this to save code, text, or other content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {
                            "type": "string",
                            "description": "Name of the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "The content to write to the file"
                        }
                    },
                    "required": ["filename", "content"]
                }
            }
        }
    ]

    messages = [
        {
            "role": "user",
            "content": "Write a Python implementation of a binary search tree with insert, search, and delete methods. Include docstrings and type hints. Save it to bst.py using the write_file tool."
        }
    ]

    print(f"\n{'='*60}")
    print("TEST: Streaming tool call with long string parameter")
    print(f"API: {API_BASE}")
    print(f"Model: {MODEL}")
    print(f"{'='*60}\n")

    # Track streaming events
    chunks_received = []
    first_chunk_time = None
    last_chunk_time = None
    tool_call_chunks = []
    accumulated_content = ""

    start_time = time.time()

    with httpx.Client(timeout=120.0) as client:
        with client.stream(
            "POST",
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": True,
                "max_tokens": 4096
            }
        ) as response:
            print(f"[{timestamp()}] Response status: {response.status_code}")

            for line in response.iter_lines():
                if not line or line == "data: [DONE]":
                    continue

                if line.startswith("data: "):
                    chunk_data = line[6:]
                    try:
                        chunk = json.loads(chunk_data)

                        if first_chunk_time is None:
                            first_chunk_time = time.time()
                            print(f"\n[{timestamp()}] FIRST CHUNK RECEIVED ({first_chunk_time - start_time:.3f}s)")

                        last_chunk_time = time.time()
                        chunks_received.append(chunk)

                        # Extract delta content
                        if chunk.get("choices"):
                            delta = chunk["choices"][0].get("delta", {})

                            # Check for tool calls in delta
                            if delta.get("tool_calls"):
                                for tc in delta["tool_calls"]:
                                    tc_index = tc.get("index", 0)
                                    tc_function = tc.get("function", {})

                                    if tc_function.get("name"):
                                        print(f"\n[{timestamp()}] Tool call name: {tc_function['name']}")

                                    if tc_function.get("arguments"):
                                        args_chunk = tc_function["arguments"]
                                        tool_call_chunks.append(args_chunk)
                                        accumulated_content += args_chunk

                                        # Print progress every ~500 chars
                                        if len(accumulated_content) % 500 < len(args_chunk):
                                            print(f"[{timestamp()}] Accumulated {len(accumulated_content)} chars...")

                            # Regular content
                            if delta.get("content"):
                                print(f"[{timestamp()}] Content chunk: {delta['content'][:50]}...")

                    except json.JSONDecodeError as e:
                        print(f"[{timestamp()}] JSON decode error: {e}")

    end_time = time.time()

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Total chunks received: {len(chunks_received)}")
    print(f"Total time: {end_time - start_time:.3f}s")
|
||||
|
||||
if first_chunk_time:
|
||||
print(f"Time to first chunk: {first_chunk_time - start_time:.3f}s")
|
||||
|
||||
if tool_call_chunks:
|
||||
print(f"Tool call chunks: {len(tool_call_chunks)}")
|
||||
print(f"Total tool call content: {len(accumulated_content)} chars")
|
||||
|
||||
# Try to parse the accumulated arguments
|
||||
print(f"\nAttempting to parse tool call arguments...")
|
||||
try:
|
||||
args = json.loads(accumulated_content)
|
||||
print(f"Successfully parsed!")
|
||||
print(f" - filename: {args.get('filename', 'N/A')}")
|
||||
print(f" - content length: {len(args.get('content', ''))} chars")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse: {e}")
|
||||
print(f"Raw accumulated content (first 500 chars):\n{accumulated_content[:500]}")
|
||||
|
||||
# Verdict
|
||||
print(f"\n{'='*60}")
|
||||
if len(tool_call_chunks) > 1:
|
||||
print("✓ PASS: Tool call arguments arrived in multiple chunks")
|
||||
print(f" Chunks: {len(tool_call_chunks)}, indicating incremental streaming")
|
||||
elif len(tool_call_chunks) == 1 and len(accumulated_content) > 1000:
|
||||
print("✗ FAIL: Tool call arguments arrived in a single chunk")
|
||||
print(" This indicates buffering, not true streaming")
|
||||
else:
|
||||
print("? INCONCLUSIVE: Not enough data or no tool call occurred")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return {
|
||||
"chunks_received": len(chunks_received),
|
||||
"tool_call_chunks": len(tool_call_chunks),
|
||||
"accumulated_length": len(accumulated_content),
|
||||
"total_time": end_time - start_time
|
||||
}
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_json():
|
||||
"""
|
||||
Test streaming a tool call that returns structured JSON data.
|
||||
"""
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "save_config",
|
||||
"description": "Save a configuration object",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"description": "Configuration object with many fields"
|
||||
}
|
||||
},
|
||||
"required": ["config"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Create a detailed configuration for a web server with the following sections: server (host, port, ssl), logging (level, format, outputs), cache (enabled, ttl, max_size), rate_limiting (enabled, requests_per_minute, burst), cors (enabled, origins, methods, headers), security (headers, csp, hsts). Use the save_config tool."
|
||||
}
|
||||
]
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TEST: Streaming tool call with nested JSON")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
tool_call_chunks = []
|
||||
accumulated_content = ""
|
||||
start_time = time.time()
|
||||
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
with client.stream(
|
||||
"POST",
|
||||
f"{API_BASE}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"stream": True,
|
||||
"max_tokens": 2048
|
||||
}
|
||||
) as response:
|
||||
for line in response.iter_lines():
|
||||
if not line or line == "data: [DONE]":
|
||||
continue
|
||||
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
chunk = json.loads(line[6:])
|
||||
if chunk.get("choices"):
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
if delta.get("tool_calls"):
|
||||
for tc in delta["tool_calls"]:
|
||||
if tc.get("function", {}).get("arguments"):
|
||||
args_chunk = tc["function"]["arguments"]
|
||||
tool_call_chunks.append(args_chunk)
|
||||
accumulated_content += args_chunk
|
||||
print(f"[{timestamp()}] Chunk {len(tool_call_chunks)}: +{len(args_chunk)} chars (total: {len(accumulated_content)})")
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Total chunks: {len(tool_call_chunks)}, Total content: {len(accumulated_content)} chars")
|
||||
print(f"Time: {end_time - start_time:.3f}s")
|
||||
|
||||
if len(tool_call_chunks) > 1:
|
||||
print("✓ PASS: Arguments streamed in multiple chunks")
|
||||
elif len(tool_call_chunks) == 1:
|
||||
print("✗ FAIL: Arguments arrived in single chunk (buffered)")
|
||||
else:
|
||||
print("? No tool call occurred")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def test_non_streaming_tool_call():
|
||||
"""
|
||||
Baseline test: non-streaming tool call for comparison.
|
||||
"""
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {"type": "string"},
|
||||
"content": {"type": "string"}
|
||||
},
|
||||
"required": ["filename", "content"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Write a simple Python hello world and save it using the write_file tool."
|
||||
}
|
||||
]
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TEST: Non-streaming tool call (baseline)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
with httpx.Client(timeout=120.0) as client:
|
||||
response = client.post(
|
||||
f"{API_BASE}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {API_KEY}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": MODEL,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"stream": False,
|
||||
"max_tokens": 1024
|
||||
}
|
||||
)
|
||||
|
||||
result = response.json()
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Status: {response.status_code}")
|
||||
print(f"Time: {end_time - start_time:.3f}s")
|
||||
|
||||
if result.get("choices"):
|
||||
message = result["choices"][0].get("message", {})
|
||||
if message.get("tool_calls"):
|
||||
for tc in message["tool_calls"]:
|
||||
print(f"Tool: {tc['function']['name']}")
|
||||
args = json.loads(tc["function"]["arguments"])
|
||||
print(f"Arguments parsed successfully")
|
||||
print(f" - filename: {args.get('filename')}")
|
||||
print(f" - content length: {len(args.get('content', ''))}")
|
||||
else:
|
||||
print("No tool call in response")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
def main():
|
||||
print("\n" + "="*60)
|
||||
print("vLLM GLM-5.1 Streaming Tool Call Tests")
|
||||
print("="*60)
|
||||
|
||||
# Check API connectivity
|
||||
print(f"\nChecking API at {API_BASE}...")
|
||||
try:
|
||||
with httpx.Client(timeout=10.0) as client:
|
||||
response = client.get(f"{API_BASE.replace('/v1', '')}/health")
|
||||
print(f"Health check: {response.status_code}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not reach API - {e}")
|
||||
|
||||
# Run tests
|
||||
print("\nRunning tests...\n")
|
||||
|
||||
# Test 1: Non-streaming baseline
|
||||
test_non_streaming_tool_call()
|
||||
|
||||
# Test 2: Streaming with nested JSON
|
||||
test_streaming_tool_call_with_json()
|
||||
|
||||
# Test 3: Main test - streaming with long code
|
||||
result = test_streaming_tool_call_with_code()
|
||||
|
||||
print("\nAll tests complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
234
test_tool_diagnosis.py
Normal file
@@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Focused test to diagnose the SmolLM3-3B tool response issue.

The issue: the model sees tool responses as blank.
"""

import httpx
import json

API_BASE = "http://95.179.247.150/v1"
API_KEY = "whatever"
MODEL = "HuggingFaceTB/SmolLM3-3B"


def test_simple_tool_response():
    """
    Minimal test: send a tool response and see if the model can use it.
    """
    # Simulate a conversation where a tool was called
    messages = [
        {"role": "user", "content": "Call the test function"},
        {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_123",
                "type": "function",
                "function": {"name": "test_func", "arguments": "{}"}
            }]
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "SUCCESS: The function returned value 42"
        }
    ]

    tools = [{
        "type": "function",
        "function": {
            "name": "test_func",
            "description": "A test function",
            "parameters": {"type": "object", "properties": {}}
        }
    }]

    print("=" * 60)
    print("Request messages:")
    print(json.dumps(messages, indent=2))
    print("=" * 60)

    with httpx.Client(timeout=60.0) as client:
        # Non-streaming to get full response
        response = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "stream": False,
                "max_tokens": 256
            }
        )

    result = response.json()

    print("\nFull response:")
    print(json.dumps(result, indent=2))

    if result.get("choices"):
        content = result["choices"][0].get("message", {}).get("content", "")
        print("\n" + "=" * 60)
        print("Model response content:")
        print(content)
        print("=" * 60)

        # Check if the tool result is referenced
        if "42" in content:
            print("\n✓ PASS: Model referenced the tool result (42)")
        else:
            print("\n✗ FAIL: Model did NOT reference the tool result (42)")

        # Check for signs the model didn't see the result
        if "don't have" in content.lower() or "cannot access" in content.lower():
            print("✗ Model indicates it cannot see the tool result")


def test_without_tools_param():
    """
    Test what happens if we don't pass tools in the follow-up request.
    Some APIs need tools to be passed on every request.
    """
    messages = [
        {"role": "user", "content": "Call the test function"},
        {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_123",
                "type": "function",
                "function": {"name": "test_func", "arguments": "{}"}
            }]
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "SUCCESS: The function returned value 42"
        }
    ]

    print("\n" + "=" * 60)
    print("Test WITHOUT tools param in follow-up")
    print("=" * 60)

    with httpx.Client(timeout=60.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                # No tools param
                "stream": False,
                "max_tokens": 256
            }
        )

    result = response.json()

    if result.get("choices"):
        content = result["choices"][0].get("message", {}).get("content", "")
        print("Model response:", content[:200])

        if "42" in content:
            print("✓ Model referenced the tool result")


def test_different_content_formats():
    """
    Test if the issue is with how content is formatted.
    """
    # Test 1: String content (standard)
    messages_string = [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_123",
                "type": "function",
                "function": {"name": "calc", "arguments": "{}"}
            }]
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": "The answer is 4"
        }
    ]

    # Test 2: Content as an array of parts (OpenAI format)
    messages_array = [
        {"role": "user", "content": "What is 2+2?"},
        {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_123",
                "type": "function",
                "function": {"name": "calc", "arguments": "{}"}
            }]
        },
        {
            "role": "tool",
            "tool_call_id": "call_123",
            "content": [{"type": "text", "text": "The answer is 4"}]
        }
    ]

    tools = [{
        "type": "function",
        "function": {
            "name": "calc",
            "description": "Calculator",
            "parameters": {"type": "object", "properties": {}}
        }
    }]

    print("\n" + "=" * 60)
    print("Test: String content vs Array content")
    print("=" * 60)

    with httpx.Client(timeout=60.0) as client:
        for name, msgs in [("String content", messages_string), ("Array content", messages_array)]:
            print(f"\n--- {name} ---")
            response = client.post(
                f"{API_BASE}/chat/completions",
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": MODEL,
                    "messages": msgs,
                    "tools": tools,
                    "stream": False,
                    "max_tokens": 128
                }
            )

            result = response.json()
            if result.get("choices"):
                content = result["choices"][0].get("message", {}).get("content", "")
                print(f"Response: {content[:150]}")
                if "4" in content:
                    print("✓ Referenced tool result")
                else:
                    print("✗ Did NOT reference tool result")


if __name__ == "__main__":
    print("SmolLM3-3B Tool Response Diagnosis")
    print("=" * 60)

    test_simple_tool_response()
    test_without_tools_param()
    test_different_content_formats()
445
test_tool_response.py
Normal file
@@ -0,0 +1,445 @@
#!/usr/bin/env python3
"""
Test for tool call response handling in SmolLM3-3B.

Tests the multi-turn flow:
1. Send a prompt that triggers a tool call
2. Send back the tool result
3. Verify the model can see and use the tool response

This reproduces the issue where tool responses appear blank to the model.
"""

import os
import json
import httpx
from datetime import datetime


API_BASE = os.environ.get("VLLM_API_BASE", "http://95.179.247.150/v1")
API_KEY = os.environ.get("VLLM_API_KEY", "none")
MODEL = os.environ.get("VLLM_MODEL", "HuggingFaceTB/SmolLM3-3B")


def timestamp():
    return datetime.now().strftime("%H:%M:%S.%f")[:-3]


def test_tool_call_response_flow(streaming: bool = True):
    """
    Test the full tool call -> response -> follow-up flow.

    This simulates:
    1. User asks for weather
    2. Model calls get_weather tool
    3. We send back the weather data
    4. Model should see and use that data
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "City and state, e.g. 'New York, NY'"
                        }
                    },
                    "required": ["location"]
                }
            }
        }
    ]

    # Initial request that should trigger a tool call
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in Tokyo right now?"
        }
    ]

    mode = "STREAMING" if streaming else "NON-STREAMING"
    print(f"\n{'='*60}")
    print(f"TEST: Tool call response flow ({mode})")
    print(f"API: {API_BASE}")
    print(f"Model: {MODEL}")
    print(f"{'='*60}\n")

    with httpx.Client(timeout=120.0) as client:
        # Step 1: Send initial request, expect tool call
        print(f"[{timestamp()}] Step 1: Sending initial request...")

        if streaming:
            tool_calls = []
            tool_call_id = None
            tool_call_name = None
            accumulated_args = ""

            with client.stream(
                "POST",
                f"{API_BASE}/chat/completions",
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": MODEL,
                    "messages": messages,
                    "tools": tools,
                    "tool_choice": "auto",
                    "stream": True,
                    "max_tokens": 512
                }
            ) as response:
                print(f"[{timestamp()}] Response status: {response.status_code}")

                for line in response.iter_lines():
                    if not line or line == "data: [DONE]":
                        continue

                    if line.startswith("data: "):
                        try:
                            chunk = json.loads(line[6:])
                            if chunk.get("choices"):
                                delta = chunk["choices"][0].get("delta", {})

                                if delta.get("tool_calls"):
                                    for tc in delta["tool_calls"]:
                                        if tc.get("id"):
                                            tool_call_id = tc["id"]

                                        if tc.get("function", {}).get("name"):
                                            tool_call_name = tc["function"]["name"]
                                            print(f"[{timestamp()}] Tool call: {tool_call_name}")

                                        if tc.get("function", {}).get("arguments"):
                                            accumulated_args += tc["function"]["arguments"]

                                if delta.get("content"):
                                    print(f"[{timestamp()}] Content: {delta['content'][:100]}")

                        except json.JSONDecodeError as e:
                            print(f"[{timestamp()}] JSON error: {e}")

            if tool_call_name:
                tool_calls.append({
                    "id": tool_call_id or "call_0",
                    "type": "function",
                    "function": {
                        "name": tool_call_name,
                        "arguments": accumulated_args
                    }
                })
        else:
            # Non-streaming
            response = client.post(
                f"{API_BASE}/chat/completions",
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": MODEL,
                    "messages": messages,
                    "tools": tools,
                    "tool_choice": "auto",
                    "stream": False,
                    "max_tokens": 512
                }
            )

            result = response.json()
            print(f"[{timestamp()}] Response status: {response.status_code}")

            tool_calls = []
            if result.get("choices"):
                message = result["choices"][0].get("message", {})
                if message.get("tool_calls"):
                    tool_calls = message["tool_calls"]
                    for tc in tool_calls:
                        print(f"[{timestamp()}] Tool call: {tc['function']['name']}")
                        print(f"[{timestamp()}] Args: {tc['function']['arguments']}")

        # Check if we got a tool call
        if not tool_calls:
            print(f"\n[{timestamp()}] No tool call received - model didn't call the tool")
            return {"success": False, "reason": "no_tool_call"}

        # Step 2: Parse tool call and prepare response
        tc = tool_calls[0]
        tc_id = tc.get("id", "call_0")
        tc_name = tc["function"]["name"]
        tc_args = json.loads(tc["function"]["arguments"])

        print(f"\n[{timestamp()}] Step 2: Tool call received")
        print(f"  Name: {tc_name}")
        print(f"  Args: {tc_args}")

        # Simulate tool execution
        tool_result = {
            "location": tc_args.get("location", "Unknown"),
            "temperature": "22°C",
            "condition": "Partly cloudy",
            "humidity": "65%",
            "wind": "15 km/h NE"
        }

        # Step 3: Send the tool response back
        messages.append({
            "role": "assistant",
            "tool_calls": tool_calls
        })
        messages.append({
            "role": "tool",
            "tool_call_id": tc_id,
            "content": json.dumps(tool_result)
        })

        print(f"\n[{timestamp()}] Step 3: Sending tool response...")
        print(f"  Tool call ID: {tc_id}")
        print(f"  Tool result: {json.dumps(tool_result, indent=2)}")

        # Step 4: Get the model's follow-up response
        if streaming:
            final_response = ""
            print(f"\n[{timestamp()}] Step 4: Receiving model's follow-up (streaming)...")

            with client.stream(
                "POST",
                f"{API_BASE}/chat/completions",
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": MODEL,
                    "messages": messages,
                    "tools": tools,
                    "stream": True,
                    "max_tokens": 512
                }
            ) as response:
                for line in response.iter_lines():
                    if not line or line == "data: [DONE]":
                        continue

                    if line.startswith("data: "):
                        try:
                            chunk = json.loads(line[6:])
                            if chunk.get("choices"):
                                delta = chunk["choices"][0].get("delta", {})
                                if delta.get("content"):
                                    content = delta["content"]
                                    final_response += content
                                    print(f"[{timestamp()}] Content: {content}", end="", flush=True)
                        except json.JSONDecodeError:
                            pass

            print()  # newline after streaming output
        else:
            print(f"\n[{timestamp()}] Step 4: Receiving model's follow-up (non-streaming)...")

            response = client.post(
                f"{API_BASE}/chat/completions",
                headers={
                    "Authorization": f"Bearer {API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": MODEL,
                    "messages": messages,
                    "tools": tools,
                    "stream": False,
                    "max_tokens": 512
                }
            )

            result = response.json()
            final_response = ""
            if result.get("choices"):
                final_response = result["choices"][0].get("message", {}).get("content", "")

        print(f"\n[{timestamp()}] Final response:\n{final_response}")

        # Check if the model used the tool data
        success = True
        issues = []

        # The response should mention the weather data
        if "22" not in final_response:
            issues.append("Temperature (22°C) not mentioned in response")
            success = False

        if "cloudy" not in final_response.lower():
            issues.append("Condition (Partly cloudy) not mentioned in response")
            success = False

        # Check for signs the model didn't see the data
        blank_indicators = [
            "i don't have",
            "i cannot access",
            "i'm unable to",
            "i am unable to",
            "don't have access",
            "don't have real-time",
            "cannot provide real-time"
        ]

        for indicator in blank_indicators:
            if indicator in final_response.lower():
                issues.append(f"Model seems unaware of tool result (found: '{indicator}')")
                success = False
                break

        print(f"\n{'='*60}")
        if success:
            print("✓ PASS: Model correctly used tool response data")
        else:
            print("✗ FAIL: Model did not use tool response correctly")
            for issue in issues:
                print(f"  - {issue}")
        print(f"{'='*60}\n")

        return {
            "success": success,
            "issues": issues,
            "final_response": final_response
        }


def test_tool_response_with_debug_info():
    """
    Test with detailed logging to capture exactly what the model sees.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_time",
                "description": "Get the current time",
                "parameters": {
                    "type": "object",
                    "properties": {},
                    "required": []
                }
            }
        }
    ]

    print(f"\n{'='*60}")
    print("TEST: Tool response with debug info (non-streaming)")
    print(f"{'='*60}\n")

    messages = [
        {"role": "user", "content": "What time is it?"}
    ]

    with httpx.Client(timeout=120.0) as client:
        # Get tool call
        print(f"[{timestamp()}] Sending initial request...")
        response = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": False,
                "max_tokens": 256
            }
        )

        result = response.json()

        if not result.get("choices") or not result["choices"][0].get("message", {}).get("tool_calls"):
            print("No tool call - skipping test")
            return

        tool_call = result["choices"][0]["message"]["tool_calls"][0]
        tc_id = tool_call["id"]

        print(f"[{timestamp()}] Tool call: {tool_call['function']['name']}")
        print(f"[{timestamp()}] Tool call ID: {tc_id}")

        # Add tool response
        messages.append({
            "role": "assistant",
            "tool_calls": [tool_call]
        })
        messages.append({
            "role": "tool",
            "tool_call_id": tc_id,
            "content": "The current time is 3:45 PM on Thursday, April 9, 2026."
        })

        # Debug: print the full messages array we're about to send
        print(f"\n[{timestamp()}] Sending follow-up with these messages:")
        print(json.dumps(messages, indent=2))

        # Get follow-up
        response2 = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "stream": False,
                "max_tokens": 256
            }
        )

        result2 = response2.json()
        print(f"\n[{timestamp()}] Full response:")
        print(json.dumps(result2, indent=2))

        if result2.get("choices"):
            content = result2["choices"][0].get("message", {}).get("content", "")

            print(f"\n[{timestamp()}] Model response content: {content}")

            # Check if the time is mentioned
            if "3:45" in content:
                print("\n✓ Model used the tool response (time mentioned)")
            else:
                print("\n✗ Model may not have seen the tool response (time not mentioned)")


def main():
    print("\n" + "="*60)
    print("SmolLM3-3B Tool Call Response Tests")
    print("="*60)

    # Test non-streaming first (simpler to debug)
    print("\n--- Test 1: Non-streaming tool response flow ---")
    test_tool_call_response_flow(streaming=False)

    # Test streaming
    print("\n--- Test 2: Streaming tool response flow ---")
    test_tool_call_response_flow(streaming=True)

    # Debug test
    print("\n--- Test 3: Debug info test ---")
    test_tool_response_with_debug_info()

    print("\nAll tests complete.")


if __name__ == "__main__":
    main()