387 lines
14 KiB
Python
Executable File
387 lines
14 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Test suite for vLLM GLM-5.1 streaming tool calls.
|
|
|
|
Reproduces the issue where long string parameters in tool calls
|
|
are buffered entirely before being emitted during streaming.
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
import httpx
|
|
from datetime import datetime
|
|
|
|
|
|
# Configuration — overridable through environment variables, with local defaults.
API_BASE = os.getenv("VLLM_API_BASE", "http://localhost:8000/v1")
API_KEY = os.getenv("VLLM_API_KEY", "none")
MODEL = os.getenv("VLLM_MODEL", "zai-org/GLM-5.1-FP8")
|
def timestamp():
    """Return the current wall-clock time as ``HH:MM:SS.mmm`` (millisecond precision)."""
    now = datetime.now()
    return f"{now:%H:%M:%S}.{now.microsecond // 1000:03d}"
|
def test_streaming_tool_call_with_code():
    """
    Test streaming a tool call with a long string parameter.

    This prompts the model to generate code via a tool call,
    which should stream incrementally if the patch works correctly.

    Prints a PASS/FAIL/INCONCLUSIVE verdict based on whether the tool-call
    arguments arrived in multiple SSE chunks (incremental streaming) or a
    single large chunk (server-side buffering).

    Returns:
        dict: summary with ``chunks_received``, ``tool_call_chunks``,
        ``accumulated_length`` and ``total_time`` (seconds).
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file. Use this to save code, text, or other content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {
                            "type": "string",
                            "description": "Name of the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "The content to write to the file"
                        }
                    },
                    "required": ["filename", "content"]
                }
            }
        }
    ]

    messages = [
        {
            "role": "user",
            "content": "Write a Python implementation of a binary search tree with insert, search, and delete methods. Include docstrings and type hints. Save it to bst.py using the write_file tool."
        }
    ]

    print(f"\n{'='*60}")
    print(f"TEST: Streaming tool call with long string parameter")
    print(f"API: {API_BASE}")
    print(f"Model: {MODEL}")
    print(f"{'='*60}\n")

    # Track streaming events
    chunks_received = []
    first_chunk_time = None
    tool_call_chunks = []
    accumulated_content = ""

    start_time = time.time()

    with httpx.Client(timeout=120.0) as client:
        with client.stream(
            "POST",
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": True,
                "max_tokens": 4096
            }
        ) as response:
            print(f"[{timestamp()}] Response status: {response.status_code}")

            for line in response.iter_lines():
                # Skip SSE keep-alive blanks and the end-of-stream sentinel.
                if not line or line == "data: [DONE]":
                    continue

                if line.startswith("data: "):
                    chunk_data = line[6:]  # strip the "data: " SSE prefix
                    try:
                        chunk = json.loads(chunk_data)

                        if first_chunk_time is None:
                            first_chunk_time = time.time()
                            print(f"\n[{timestamp()}] FIRST CHUNK RECEIVED ({first_chunk_time - start_time:.3f}s)")

                        chunks_received.append(chunk)

                        # Extract delta content
                        if chunk.get("choices"):
                            delta = chunk["choices"][0].get("delta", {})

                            # Check for tool calls in delta
                            if delta.get("tool_calls"):
                                for tc in delta["tool_calls"]:
                                    tc_function = tc.get("function", {})

                                    if tc_function.get("name"):
                                        print(f"\n[{timestamp()}] Tool call name: {tc_function['name']}")

                                    if tc_function.get("arguments"):
                                        args_chunk = tc_function["arguments"]
                                        tool_call_chunks.append(args_chunk)
                                        accumulated_content += args_chunk

                                        # Print progress every ~500 chars
                                        if len(accumulated_content) % 500 < len(args_chunk):
                                            print(f"[{timestamp()}] Accumulated {len(accumulated_content)} chars...")

                            # Regular content
                            if delta.get("content"):
                                print(f"[{timestamp()}] Content chunk: {delta['content'][:50]}...")

                    except json.JSONDecodeError as e:
                        print(f"[{timestamp()}] JSON decode error: {e}")

    end_time = time.time()

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Total chunks received: {len(chunks_received)}")
    print(f"Total time: {end_time - start_time:.3f}s")

    if first_chunk_time:
        print(f"Time to first chunk: {first_chunk_time - start_time:.3f}s")

    if tool_call_chunks:
        print(f"Tool call chunks: {len(tool_call_chunks)}")
        print(f"Total tool call content: {len(accumulated_content)} chars")

        # Try to parse the accumulated arguments
        print(f"\nAttempting to parse tool call arguments...")
        try:
            args = json.loads(accumulated_content)
            print(f"Successfully parsed!")
            print(f" - filename: {args.get('filename', 'N/A')}")
            print(f" - content length: {len(args.get('content', ''))} chars")
        except json.JSONDecodeError as e:
            print(f"Failed to parse: {e}")
            print(f"Raw accumulated content (first 500 chars):\n{accumulated_content[:500]}")

    # Verdict: >1 argument chunk means the server streamed incrementally;
    # a single large chunk indicates full buffering before emission.
    print(f"\n{'='*60}")
    if len(tool_call_chunks) > 1:
        print("✓ PASS: Tool call arguments arrived in multiple chunks")
        print(f" Chunks: {len(tool_call_chunks)}, indicating incremental streaming")
    elif len(tool_call_chunks) == 1 and len(accumulated_content) > 1000:
        print("✗ FAIL: Tool call arguments arrived in a single chunk")
        print(" This indicates buffering, not true streaming")
    else:
        print("? INCONCLUSIVE: Not enough data or no tool call occurred")
    print(f"{'='*60}\n")

    return {
        "chunks_received": len(chunks_received),
        "tool_call_chunks": len(tool_call_chunks),
        "accumulated_length": len(accumulated_content),
        "total_time": end_time - start_time
    }
|
def test_streaming_tool_call_with_json():
    """
    Test streaming a tool call that returns structured JSON data.

    Prompts the model for a deeply nested configuration object via the
    ``save_config`` tool and reports whether the arguments streamed in
    multiple chunks.

    Returns:
        dict: summary with ``tool_call_chunks``, ``accumulated_length``
        and ``total_time`` (seconds).
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "save_config",
                "description": "Save a configuration object",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "config": {
                            "type": "object",
                            "description": "Configuration object with many fields"
                        }
                    },
                    "required": ["config"]
                }
            }
        }
    ]

    messages = [
        {
            "role": "user",
            "content": "Create a detailed configuration for a web server with the following sections: server (host, port, ssl), logging (level, format, outputs), cache (enabled, ttl, max_size), rate_limiting (enabled, requests_per_minute, burst), cors (enabled, origins, methods, headers), security (headers, csp, hsts). Use the save_config tool."
        }
    ]

    print(f"\n{'='*60}")
    print(f"TEST: Streaming tool call with nested JSON")
    print(f"{'='*60}\n")

    tool_call_chunks = []
    accumulated_content = ""
    start_time = time.time()

    with httpx.Client(timeout=120.0) as client:
        with client.stream(
            "POST",
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": True,
                "max_tokens": 2048
            }
        ) as response:
            for line in response.iter_lines():
                # Skip SSE keep-alive blanks and the end-of-stream sentinel.
                if not line or line == "data: [DONE]":
                    continue

                if line.startswith("data: "):
                    try:
                        chunk = json.loads(line[6:])
                        if chunk.get("choices"):
                            delta = chunk["choices"][0].get("delta", {})
                            if delta.get("tool_calls"):
                                for tc in delta["tool_calls"]:
                                    if tc.get("function", {}).get("arguments"):
                                        args_chunk = tc["function"]["arguments"]
                                        tool_call_chunks.append(args_chunk)
                                        accumulated_content += args_chunk
                                        print(f"[{timestamp()}] Chunk {len(tool_call_chunks)}: +{len(args_chunk)} chars (total: {len(accumulated_content)})")
                    except json.JSONDecodeError as e:
                        # Report malformed payloads instead of silently dropping
                        # them, consistent with the long-string streaming test.
                        print(f"[{timestamp()}] JSON decode error: {e}")

    end_time = time.time()

    print(f"\n{'='*60}")
    print(f"Total chunks: {len(tool_call_chunks)}, Total content: {len(accumulated_content)} chars")
    print(f"Time: {end_time - start_time:.3f}s")

    # Verdict: multiple argument chunks ⇒ incremental streaming worked.
    if len(tool_call_chunks) > 1:
        print("✓ PASS: Arguments streamed in multiple chunks")
    elif len(tool_call_chunks) == 1:
        print("✗ FAIL: Arguments arrived in single chunk (buffered)")
    else:
        print("? No tool call occurred")
    print(f"{'='*60}\n")

    return {
        "tool_call_chunks": len(tool_call_chunks),
        "accumulated_length": len(accumulated_content),
        "total_time": end_time - start_time
    }
|
def test_non_streaming_tool_call():
    """
    Baseline test: non-streaming tool call for comparison.

    Sends the same kind of ``write_file`` tool request with ``stream=False``
    and prints the parsed tool-call arguments, so streaming behavior can be
    compared against a single-shot response.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"}
                    },
                    "required": ["filename", "content"]
                }
            }
        }
    ]

    messages = [
        {
            "role": "user",
            "content": "Write a simple Python hello world and save it using the write_file tool."
        }
    ]

    print(f"\n{'='*60}")
    print(f"TEST: Non-streaming tool call (baseline)")
    print(f"{'='*60}\n")

    start_time = time.time()

    with httpx.Client(timeout=120.0) as client:
        response = client.post(
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": False,
                "max_tokens": 1024
            }
        )

    result = response.json()
    end_time = time.time()

    print(f"Status: {response.status_code}")
    print(f"Time: {end_time - start_time:.3f}s")

    if result.get("choices"):
        message = result["choices"][0].get("message", {})
        if message.get("tool_calls"):
            for tc in message["tool_calls"]:
                print(f"Tool: {tc['function']['name']}")
                # Guard against malformed arguments so a bad response does not
                # abort the whole test suite with an uncaught exception.
                try:
                    args = json.loads(tc["function"]["arguments"])
                except json.JSONDecodeError as e:
                    print(f"Failed to parse arguments: {e}")
                    continue
                print(f"Arguments parsed successfully")
                print(f" - filename: {args.get('filename')}")
                print(f" - content length: {len(args.get('content', ''))}")
        else:
            print("No tool call in response")

    print(f"{'='*60}\n")
|
def main():
    """Run the full streaming tool-call test suite against the configured API."""
    print("\n" + "="*60)
    print("vLLM GLM-5.1 Streaming Tool Call Tests")
    print("="*60)

    # Check API connectivity
    print(f"\nChecking API at {API_BASE}...")
    try:
        with httpx.Client(timeout=10.0) as client:
            response = client.get(f"{API_BASE.replace('/v1', '')}/health")
            print(f"Health check: {response.status_code}")
    except Exception as e:
        # Connectivity problems are reported but do not abort the run;
        # each individual test surfaces its own errors.
        print(f"Warning: Could not reach API - {e}")

    # Run tests
    print("\nRunning tests...\n")

    # Test 1: Non-streaming baseline
    test_non_streaming_tool_call()

    # Test 2: Streaming with nested JSON
    test_streaming_tool_call_with_json()

    # Test 3: Main test - streaming with long code
    test_streaming_tool_call_with_code()

    print("\nAll tests complete.")
|
# Script entry point: run the suite only when executed directly, not on import.
if __name__ == "__main__":
    main()