Add stage1b - API debugger with chat template
This commit is contained in:
191
scripts/stage1b_api_debug.py
Normal file
191
scripts/stage1b_api_debug.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
Stage 1b: Chat Template Debugger via OpenAI API
|
||||
|
||||
Runs the same write_file prompt through vLLM's /v1/chat/completions endpoint
|
||||
(so the chat template IS applied) and captures the raw SSE stream.
|
||||
|
||||
Compare with stage1_debug.py (raw generate, no template) to see if the
|
||||
chat template itself triggers tool-call tokens or if the model still
|
||||
code-dumps.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import httpx
|
||||
import time
|
||||
|
||||
# ── Configuration (all overridable via environment variables) ────────────────
API_BASE = os.getenv("API_BASE", "http://localhost:8000/v1")
MODEL = os.getenv("MODEL", "/workspace/models/SmolLM3-3B")
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))
|
||||
|
||||
# JSON Schema for the arguments of the single tool we expose to the model.
_write_file_parameters = {
    "type": "object",
    "properties": {
        "filename": {
            "type": "string",
            "description": "Name of the file to write",
        },
        "content": {
            "type": "string",
            "description": "The content to write to the file",
        },
    },
    "required": ["filename", "content"],
}

# OpenAI-style tool list sent with the request and to the chat template.
tools = [
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file.",
            "parameters": _write_file_parameters,
        },
    }
]

# Single-turn conversation: one user request that should trigger write_file.
messages = [
    {"role": "user", "content": 'Write "hello world" to /tmp/test.txt'},
]
|
||||
|
||||
# ── First, let's see what the chat template produces ─────────────────────────
print("[stage1b] Loading tokenizer to inspect chat template rendering...")
# NOTE(review): imported lazily, after the banner print — presumably so the
# slow transformers import is attributable in the console output.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

# Render the exact prompt string vLLM would build from messages + tools.
rendered = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)

rule = "═" * 60
print(f"[stage1b] Chat template rendered prompt ({len(rendered)} chars):")
print(rule)
print(rendered)
print(rule + "\n")
|
||||
|
||||
# ── Streaming request ────────────────────────────────────────────────────────

print(f"[stage1b] Sending streaming request to {API_BASE}/chat/completions ...")
print(f"[stage1b] Model: {MODEL}")
# was a placeholder-less f-string (ruff F541); plain literal is equivalent
print("[stage1b] Tools: write_file")
print(f"[stage1b] Temperature: {TEMPERATURE}, Max tokens: {MAX_TOKENS}\n")

# Stream-accumulation state, updated chunk-by-chunk in the loop below.
chunk_count = 0        # total SSE data chunks successfully parsed
tool_call_chunks = []  # raw argument fragments from tool-call deltas
tool_name = None       # name of the tool the model invoked, if any
content_chunks = 0     # number of deltas carrying plain text content
reasoning_chunks = 0   # number of deltas carrying a "reasoning" field
accumulated_args = ""  # concatenation of all tool-call argument fragments
all_deltas = []        # every delta dict, kept for post-hoc content reassembly

start_time = time.time()
|
||||
|
||||
# POST the chat request with stream=True and consume the SSE response
# line-by-line, tallying reasoning / tool-call / content deltas.
with httpx.Client(timeout=120.0) as client:
    with client.stream(
        "POST",
        f"{API_BASE}/chat/completions",
        headers={
            "Content-Type": "application/json"
        },
        json={
            "model": MODEL,
            "messages": messages,
            "tools": tools,
            "tool_choice": "auto",
            "stream": True,
            "max_tokens": MAX_TOKENS,
            "temperature": TEMPERATURE,
        }
    ) as response:
        print(f"[stage1b] Response status: {response.status_code}\n")
        # NOTE(review): a non-2xx error body is not SSE, so the loop below
        # skips it silently — check the printed status if nothing streams.

        for line in response.iter_lines():
            # Skip keep-alive blanks and the terminal sentinel.
            if not line or line == "data: [DONE]":
                continue
            if not line.startswith("data: "):
                continue

            try:
                chunk = json.loads(line[6:])  # strip the "data: " prefix
                chunk_count += 1

                if not chunk.get("choices"):
                    continue

                delta = chunk["choices"][0].get("delta", {})
                all_deltas.append(delta)

                if delta.get("reasoning"):
                    reasoning_chunks += 1

                if delta.get("tool_calls"):
                    for tc in delta["tool_calls"]:
                        fn = tc.get("function", {})
                        # The name usually arrives once, in the first fragment.
                        if fn.get("name"):
                            tool_name = fn["name"]
                            print(f" [chunk {chunk_count}] TOOL CALL NAME: {tool_name}")
                        # Arguments arrive as JSON fragments to be concatenated.
                        if fn.get("arguments"):
                            tool_call_chunks.append(fn["arguments"])
                            accumulated_args += fn["arguments"]

                # .get() already yields a falsy value for a missing or empty
                # field, so the original double truthiness check was redundant.
                if delta.get("content"):
                    content_chunks += 1
                    # Print first few content chunks for visibility
                    if content_chunks <= 5:
                        print(f" [chunk {chunk_count}] CONTENT: {json.dumps(delta['content'][:100])}")

            except json.JSONDecodeError as e:
                print(f" [chunk {chunk_count}] JSON error: {e}")

end_time = time.time()
|
||||
|
||||
# ── Summary ──────────────────────────────────────────────────────────────────
# Post-stream report: chunk tallies, timing, parsed tool-call arguments (if
# any), and a reassembled view of any plain-content output.

rule = "═" * 60
print(f"\n{rule}")
print("SUMMARY")
print(rule)
print(f"Total SSE chunks: {chunk_count}")
print(f" Reasoning chunks: {reasoning_chunks}")
print(f" Tool call arg chunks: {len(tool_call_chunks)}")
print(f" Content chunks: {content_chunks}")
print(f"Total time: {end_time - start_time:.3f}s")

if tool_name:
    print(f"\nTool called: {tool_name}")
    print(f"Total tool call args: {len(accumulated_args)} chars")
    try:
        args = json.loads(accumulated_args)
        # was a placeholder-less f-string (ruff F541)
        print("✓ Parsed successfully:")
        for k, v in args.items():
            # Truncate long values so file content doesn't flood the report.
            text = str(v)
            shown = v if len(text) < 100 else f"{text[:100]}..."
            print(f" - {k}: {shown}")
    except json.JSONDecodeError as e:
        print(f"✗ Failed to parse: {e}")
        print(f"Raw (first 500): {accumulated_args[:500]}")
else:
    print("\n⚠ NO TOOL CALL emitted by the model")

if content_chunks > 0:
    # Reassemble streamed text with join — linear, vs. the quadratic
    # worst case of repeated string `+=` in a loop.
    full_content = "".join(d["content"] for d in all_deltas if d.get("content"))
    print(f"\nFull content output ({len(full_content)} chars):")
    divider = "─" * 60
    print(divider)
    print(full_content[:1000])
    if len(full_content) > 1000:
        print(f"... ({len(full_content) - 1000} more chars)")
    print(divider)
|
||||
|
||||
# ── Verdict ──────────────────────────────────────────────────────────────────
# One-line interpretation for comparison against stage1 (raw generate, no
# chat template): did the template route the model into the tool-call path?
print(f"\n{'═' * 60}")
# Truthiness check replaces `len(tool_call_chunks) > 0`; placeholder-less
# f-strings below replaced with plain literals (ruff F541).
if tool_name and tool_call_chunks:
    print("✓ MODEL EMITTED TOOL CALL via API (chat template applied)")
    print(" → The chat template triggers the tool-call path")
    print(" → Compare with stage1 raw output to see the difference")
else:
    print("✗ MODEL DID NOT EMIT TOOL CALL even with chat template")
    print(" → This is a model capability issue")
print("═" * 60)
|
||||
Reference in New Issue
Block a user