Files
vllm-glm/tests/test_streaming_tool_calls.py
2026-04-09 04:28:22 +00:00

387 lines
14 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Test suite for vLLM GLM-5.1 streaming tool calls.
Reproduces the issue where long string parameters in tool calls
are buffered entirely before being emitted during streaming.
"""
import os
import time
import json
import httpx
from datetime import datetime
# Configuration - will be set via environment or direct assignment
API_BASE = os.environ.get("VLLM_API_BASE", "http://localhost:8000/v1")  # OpenAI-compatible API root
API_KEY = os.environ.get("VLLM_API_KEY", "none")  # bearer token; "none" works for unauthenticated servers
MODEL = os.environ.get("VLLM_MODEL", "zai-org/GLM-5.1-FP8")  # model name as served by vLLM
def timestamp():
    """Return the current wall-clock time formatted as HH:MM:SS.mmm."""
    now = datetime.now()
    # %f gives microseconds (6 digits); trim the last three to keep milliseconds.
    return now.strftime("%H:%M:%S.%f")[:-3]
def test_streaming_tool_call_with_code():
    """
    Test streaming a tool call with a long string parameter.

    Prompts the model to generate code via a `write_file` tool call. If the
    server streams tool-call arguments incrementally, the `content` argument
    arrives in many small chunks; if it buffers, it arrives in one blob.

    Returns:
        dict: counts of SSE chunks and tool-call argument chunks, the
        accumulated argument length, and the total elapsed time in seconds.
    """
    # One tool with a long free-form string parameter ("content") so the
    # model has to emit a large argument value.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file. Use this to save code, text, or other content.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {
                            "type": "string",
                            "description": "Name of the file to write"
                        },
                        "content": {
                            "type": "string",
                            "description": "The content to write to the file"
                        }
                    },
                    "required": ["filename", "content"]
                }
            }
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "Write a Python implementation of a binary search tree with insert, search, and delete methods. Include docstrings and type hints. Save it to bst.py using the write_file tool."
        }
    ]
    print(f"\n{'='*60}")
    print(f"TEST: Streaming tool call with long string parameter")
    print(f"API: {API_BASE}")
    print(f"Model: {MODEL}")
    print(f"{'='*60}\n")
    # Track streaming events
    chunks_received = []       # every successfully parsed SSE chunk
    first_chunk_time = None    # wall time of the first chunk, for latency
    tool_call_chunks = []      # argument fragments, one per delta
    accumulated_content = ""   # concatenation of all argument fragments
    start_time = time.time()
    with httpx.Client(timeout=120.0) as client:
        with client.stream(
            "POST",
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": True,
                "max_tokens": 4096
            }
        ) as response:
            print(f"[{timestamp()}] Response status: {response.status_code}")
            # Server-sent events: each payload line is prefixed "data: ";
            # the stream terminates with a literal "data: [DONE]" line.
            for line in response.iter_lines():
                if not line or line == "data: [DONE]":
                    continue
                if not line.startswith("data: "):
                    continue
                chunk_data = line[6:]
                # Keep the try body minimal: only json.loads can raise here.
                try:
                    chunk = json.loads(chunk_data)
                except json.JSONDecodeError as e:
                    print(f"[{timestamp()}] JSON decode error: {e}")
                    continue
                if first_chunk_time is None:
                    first_chunk_time = time.time()
                    print(f"\n[{timestamp()}] FIRST CHUNK RECEIVED ({first_chunk_time - start_time:.3f}s)")
                chunks_received.append(chunk)
                # Extract delta content
                if chunk.get("choices"):
                    delta = chunk["choices"][0].get("delta", {})
                    # Check for tool calls in delta
                    if delta.get("tool_calls"):
                        for tc in delta["tool_calls"]:
                            tc_function = tc.get("function", {})
                            if tc_function.get("name"):
                                print(f"\n[{timestamp()}] Tool call name: {tc_function['name']}")
                            if tc_function.get("arguments"):
                                args_chunk = tc_function["arguments"]
                                tool_call_chunks.append(args_chunk)
                                accumulated_content += args_chunk
                                # Print progress every ~500 chars
                                if len(accumulated_content) % 500 < len(args_chunk):
                                    print(f"[{timestamp()}] Accumulated {len(accumulated_content)} chars...")
                    # Regular content
                    if delta.get("content"):
                        print(f"[{timestamp()}] Content chunk: {delta['content'][:50]}...")
    end_time = time.time()
    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"Total chunks received: {len(chunks_received)}")
    print(f"Total time: {end_time - start_time:.3f}s")
    if first_chunk_time:
        print(f"Time to first chunk: {first_chunk_time - start_time:.3f}s")
    if tool_call_chunks:
        print(f"Tool call chunks: {len(tool_call_chunks)}")
        print(f"Total tool call content: {len(accumulated_content)} chars")
        # Try to parse the accumulated arguments: a valid JSON document here
        # means no fragment was lost while streaming.
        print(f"\nAttempting to parse tool call arguments...")
        try:
            args = json.loads(accumulated_content)
            print(f"Successfully parsed!")
            print(f" - filename: {args.get('filename', 'N/A')}")
            print(f" - content length: {len(args.get('content', ''))} chars")
        except json.JSONDecodeError as e:
            print(f"Failed to parse: {e}")
            print(f"Raw accumulated content (first 500 chars):\n{accumulated_content[:500]}")
    # Verdict: >1 argument chunk implies incremental streaming; exactly one
    # chunk carrying a long payload implies server-side buffering.
    print(f"\n{'='*60}")
    if len(tool_call_chunks) > 1:
        print("✓ PASS: Tool call arguments arrived in multiple chunks")
        print(f" Chunks: {len(tool_call_chunks)}, indicating incremental streaming")
    elif len(tool_call_chunks) == 1 and len(accumulated_content) > 1000:
        print("✗ FAIL: Tool call arguments arrived in a single chunk")
        print(" This indicates buffering, not true streaming")
    else:
        print("? INCONCLUSIVE: Not enough data or no tool call occurred")
    print(f"{'='*60}\n")
    return {
        "chunks_received": len(chunks_received),
        "tool_call_chunks": len(tool_call_chunks),
        "accumulated_length": len(accumulated_content),
        "total_time": end_time - start_time
    }
def test_streaming_tool_call_with_json():
    """
    Test streaming a tool call that returns structured JSON data.
    """
    # Single tool whose only parameter is a nested object, forcing the model
    # to emit a structured JSON argument string.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "save_config",
                "description": "Save a configuration object",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "config": {
                            "type": "object",
                            "description": "Configuration object with many fields"
                        }
                    },
                    "required": ["config"]
                }
            }
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "Create a detailed configuration for a web server with the following sections: server (host, port, ssl), logging (level, format, outputs), cache (enabled, ttl, max_size), rate_limiting (enabled, requests_per_minute, burst), cors (enabled, origins, methods, headers), security (headers, csp, hsts). Use the save_config tool."
        }
    ]
    print(f"\n{'='*60}")
    print(f"TEST: Streaming tool call with nested JSON")
    print(f"{'='*60}\n")
    tool_call_chunks = []     # one entry per argument fragment received
    accumulated_content = ""  # concatenation of all argument fragments
    start_time = time.time()
    with httpx.Client(timeout=120.0) as client:
        with client.stream(
            "POST",
            f"{API_BASE}/chat/completions",
            headers={
                "Authorization": f"Bearer {API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "model": MODEL,
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
                "stream": True,
                "max_tokens": 2048
            }
        ) as response:
            # Server-sent events: each payload line is prefixed "data: ";
            # the stream terminates with a literal "data: [DONE]" line.
            for line in response.iter_lines():
                if not line or line == "data: [DONE]":
                    continue
                if line.startswith("data: "):
                    try:
                        chunk = json.loads(line[6:])
                        if chunk.get("choices"):
                            delta = chunk["choices"][0].get("delta", {})
                            if delta.get("tool_calls"):
                                for tc in delta["tool_calls"]:
                                    if tc.get("function", {}).get("arguments"):
                                        args_chunk = tc["function"]["arguments"]
                                        tool_call_chunks.append(args_chunk)
                                        accumulated_content += args_chunk
                                        print(f"[{timestamp()}] Chunk {len(tool_call_chunks)}: +{len(args_chunk)} chars (total: {len(accumulated_content)})")
                    except json.JSONDecodeError:
                        # Unparseable lines (keepalives, partial data) are skipped.
                        pass
    end_time = time.time()
    print(f"\n{'='*60}")
    print(f"Total chunks: {len(tool_call_chunks)}, Total content: {len(accumulated_content)} chars")
    print(f"Time: {end_time - start_time:.3f}s")
    # Verdict heuristic: >1 argument chunk implies incremental streaming.
    if len(tool_call_chunks) > 1:
        print("✓ PASS: Arguments streamed in multiple chunks")
    elif len(tool_call_chunks) == 1:
        print("✗ FAIL: Arguments arrived in single chunk (buffered)")
    else:
        print("? No tool call occurred")
    print(f"{'='*60}\n")
def test_non_streaming_tool_call():
    """
    Baseline test: issue the same kind of tool-call request without
    streaming, so streamed runs have a reference point for timing and
    argument correctness.
    """
    tools = [
        {
            "type": "function",
            "function": {
                "name": "write_file",
                "description": "Write content to a file",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "filename": {"type": "string"},
                        "content": {"type": "string"}
                    },
                    "required": ["filename", "content"]
                }
            }
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "Write a simple Python hello world and save it using the write_file tool."
        }
    ]
    print(f"\n{'='*60}")
    print(f"TEST: Non-streaming tool call (baseline)")
    print(f"{'='*60}\n")
    # Assemble the request up front so the timed section is just the call.
    auth_headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    request_body = {
        "model": MODEL,
        "messages": messages,
        "tools": tools,
        "tool_choice": "auto",
        "stream": False,
        "max_tokens": 1024
    }
    t0 = time.time()
    with httpx.Client(timeout=120.0) as http:
        resp = http.post(
            f"{API_BASE}/chat/completions",
            headers=auth_headers,
            json=request_body
        )
        payload = resp.json()
    t1 = time.time()
    print(f"Status: {resp.status_code}")
    print(f"Time: {t1 - t0:.3f}s")
    choices = payload.get("choices")
    if choices:
        msg = choices[0].get("message", {})
        calls = msg.get("tool_calls")
        if calls:
            for call in calls:
                print(f"Tool: {call['function']['name']}")
                parsed = json.loads(call["function"]["arguments"])
                print(f"Arguments parsed successfully")
                print(f" - filename: {parsed.get('filename')}")
                print(f" - content length: {len(parsed.get('content', ''))}")
        else:
            print("No tool call in response")
    print(f"{'='*60}\n")
def main():
    """Run the full streaming tool-call test suite against the configured API."""
    print("\n" + "="*60)
    print("vLLM GLM-5.1 Streaming Tool Call Tests")
    print("="*60)
    # Check API connectivity (best effort: tests still run if this fails).
    print(f"\nChecking API at {API_BASE}...")
    try:
        with httpx.Client(timeout=10.0) as client:
            # The health endpoint lives at the server root, not under /v1.
            response = client.get(f"{API_BASE.replace('/v1', '')}/health")
            print(f"Health check: {response.status_code}")
    except Exception as e:
        print(f"Warning: Could not reach API - {e}")
    # Run tests
    print("\nRunning tests...\n")
    # Test 1: Non-streaming baseline
    test_non_streaming_tool_call()
    # Test 2: Streaming with nested JSON
    test_streaming_tool_call_with_json()
    # Test 3: Main test - streaming with long code
    test_streaming_tool_call_with_code()
    print("\nAll tests complete.")
# Script entry point: run the whole suite when executed directly.
if __name__ == "__main__":
    main()