Add stage1b - API debugger with chat template

This commit is contained in:
2026-04-10 16:11:56 +00:00
parent f20d3bebc3
commit becee624c6

View File

@@ -0,0 +1,191 @@
"""
Stage 1b: Chat Template Debugger via OpenAI API
Runs the same write_file prompt through vLLM's /v1/chat/completions endpoint
(so the chat template IS applied) and captures the raw SSE stream.
Compare with stage1_debug.py (raw generate, no template) to see if the
chat template itself triggers tool-call tokens or if the model still
code-dumps.
"""
import os
import json
import httpx
import time
API_BASE = os.environ.get("API_BASE", "http://localhost:8000/v1")
MODEL = os.environ.get("MODEL", "/workspace/models/SmolLM3-3B")
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "512"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.0"))
# OpenAI-style tool schema for the single write_file tool under test.
tools = [
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file.",
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "Name of the file to write"
                    },
                    "content": {
                        "type": "string",
                        "description": "The content to write to the file"
                    }
                },
                "required": ["filename", "content"]
            }
        }
    }
]
# Single-turn user prompt that should elicit a write_file tool call.
messages = [
    {
        "role": "user",
        "content": 'Write "hello world" to /tmp/test.txt'
    }
]
# ── First, let's see what the chat template produces ─────────────────────────
print("[stage1b] Loading tokenizer to inspect chat template rendering...")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
# Render the exact prompt string the server will build (tool schemas injected,
# generation prompt appended) so it can be diffed against the raw stage1 run.
rendered = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)
print(f"[stage1b] Chat template rendered prompt ({len(rendered)} chars):")
# Fixed: was f"{'' * 60}" — multiplying the EMPTY string, which printed a
# blank line instead of the intended 60-char separator rule.
print("─" * 60)
print(rendered)
print("─" * 60 + "\n")
# ── Streaming request ────────────────────────────────────────────────────────
print(f"[stage1b] Sending streaming request to {API_BASE}/chat/completions ...")
print(f"[stage1b] Model: {MODEL}")
print("[stage1b] Tools: write_file")
print(f"[stage1b] Temperature: {TEMPERATURE}, Max tokens: {MAX_TOKENS}\n")
# Accumulators inspected by the summary/verdict sections below.
chunk_count = 0        # total successfully parsed SSE data chunks
tool_call_chunks = []  # raw argument fragments from tool_calls deltas
tool_name = None       # first tool name the model emits (if any)
content_chunks = 0     # deltas carrying plain assistant content
reasoning_chunks = 0   # deltas carrying a "reasoning" field
accumulated_args = ""  # tool-call argument fragments concatenated in order
all_deltas = []        # every delta, kept for post-hoc content reassembly
start_time = time.time()
with httpx.Client(timeout=120.0) as client:
    with client.stream(
        "POST",
        f"{API_BASE}/chat/completions",
        headers={
            "Content-Type": "application/json"
        },
        json={
            "model": MODEL,
            "messages": messages,
            "tools": tools,
            "tool_choice": "auto",
            "stream": True,
            "max_tokens": MAX_TOKENS,
            "temperature": TEMPERATURE,
        }
    ) as response:
        print(f"[stage1b] Response status: {response.status_code}\n")
        for line in response.iter_lines():
            # SSE frames look like "data: {...}"; skip keep-alives/blank
            # lines, the "[DONE]" terminator, and anything non-data.
            if not line or line == "data: [DONE]":
                continue
            if not line.startswith("data: "):
                continue
            # Keep the try body minimal: only the parse can raise
            # JSONDecodeError. (Note: the error message intentionally shows
            # the count of the last GOOD chunk, matching prior behavior.)
            try:
                chunk = json.loads(line[len("data: "):])
            except json.JSONDecodeError as e:
                print(f"  [chunk {chunk_count}] JSON error: {e}")
                continue
            chunk_count += 1
            if not chunk.get("choices"):
                continue
            delta = chunk["choices"][0].get("delta", {})
            all_deltas.append(delta)
            if delta.get("reasoning"):
                reasoning_chunks += 1
            # tool_calls may be absent or None; iterate whatever is present.
            for tc in delta.get("tool_calls") or []:
                fn = tc.get("function", {})
                if fn.get("name"):
                    tool_name = fn["name"]
                    print(f"  [chunk {chunk_count}] TOOL CALL NAME: {tool_name}")
                if fn.get("arguments"):
                    tool_call_chunks.append(fn["arguments"])
                    accumulated_args += fn["arguments"]
            # Single truthiness check (the original tested the same value twice).
            if delta.get("content"):
                content_chunks += 1
                # Print first few content chunks for visibility
                if content_chunks <= 5:
                    print(f"  [chunk {chunk_count}] CONTENT: {json.dumps(delta['content'][:100])}")
end_time = time.time()
# ── Summary ──────────────────────────────────────────────────────────────────
# Fixed: separators were f"{'' * 60}" (empty string times 60), which printed
# blank lines; use a real 60-char rule instead.
RULE = "─" * 60
print(f"\n{RULE}")
print("SUMMARY")
print(RULE)
print(f"Total SSE chunks: {chunk_count}")
print(f"  Reasoning chunks: {reasoning_chunks}")
print(f"  Tool call arg chunks: {len(tool_call_chunks)}")
print(f"  Content chunks: {content_chunks}")
print(f"Total time: {end_time - start_time:.3f}s")
if tool_name:
    print(f"\nTool called: {tool_name}")
    print(f"Total tool call args: {len(accumulated_args)} chars")
    # EAFP: only the parse sits in the try; the success path lives in `else`.
    try:
        args = json.loads(accumulated_args)
    except json.JSONDecodeError as e:
        print(f"✗ Failed to parse: {e}")
        print(f"Raw (first 500): {accumulated_args[:500]}")
    else:
        print("✓ Parsed successfully:")
        for k, v in args.items():
            text = str(v)
            # Truncate long values so file contents don't flood the output.
            shown = text if len(text) < 100 else f"{text[:100]}..."
            print(f"  - {k}: {shown}")
else:
    print("\n⚠ NO TOOL CALL emitted by the model")
    if content_chunks > 0:
        # Reassemble the full assistant message from the captured deltas
        # (join instead of quadratic += concatenation in a loop).
        full_content = "".join(d["content"] for d in all_deltas if d.get("content"))
        print(f"\nFull content output ({len(full_content)} chars):")
        print(RULE)
        print(full_content[:1000])
        if len(full_content) > 1000:
            print(f"... ({len(full_content) - 1000} more chars)")
        print(RULE)
# ── Verdict ──────────────────────────────────────────────────────────────────
# Fixed: was f"{'' * 60}" — an empty separator; print a visible rule.
print("\n" + "─" * 60)
if tool_name and len(tool_call_chunks) > 0:
    print("✓ MODEL EMITTED TOOL CALL via API (chat template applied)")
    print("  → The chat template triggers the tool-call path")
    print("  → Compare with stage1 raw output to see the difference")
else:
    print("✗ MODEL DID NOT EMIT TOOL CALL even with chat template")
    print("  → This is a model capability issue")
print("─" * 60)