"""
Stage 1b: Chat Template Debugger via OpenAI API

Runs the same write_file prompt through vLLM's /v1/chat/completions endpoint
(so the chat template IS applied) and captures the raw SSE stream.

Compare with stage1_debug.py (raw generate, no template) to see if the
chat template itself triggers tool-call tokens or if the model still
code-dumps.
"""

import json
import os
import time

import httpx
from transformers import AutoTokenizer

# Configuration comes from the environment so the script can target any
# vLLM deployment without code edits.
API_BASE = os.environ.get("API_BASE", "http://localhost:8000/v1")
MODEL = os.environ.get("MODEL", "/workspace/models/SmolLM3-3B")
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "512"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.0"))

# SSE framing constants for the OpenAI-compatible streaming protocol.
SSE_PREFIX = "data: "
SSE_DONE = "data: [DONE]"

# OpenAI-style tool schema: a single write_file function the model may call.
tools = [
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file.",
            "parameters": {
                "type": "object",
                "properties": {
                    "filename": {
                        "type": "string",
                        "description": "Name of the file to write"
                    },
                    "content": {
                        "type": "string",
                        "description": "The content to write to the file"
                    }
                },
                "required": ["filename", "content"]
            }
        }
    }
]

messages = [
    {
        "role": "user",
        "content": 'Write "hello world" to /tmp/test.txt'
    }
]

# ── First, let's see what the chat template produces ─────────────────────────
print("[stage1b] Loading tokenizer to inspect chat template rendering...")
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

rendered = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    tokenize=False,
    add_generation_prompt=True,
)
print(f"[stage1b] Chat template rendered prompt ({len(rendered)} chars):")
print("═" * 60)
print(rendered)
print(f"{'═' * 60}\n")

# ── Streaming request ────────────────────────────────────────────────────────

print(f"[stage1b] Sending streaming request to {API_BASE}/chat/completions ...")
print(f"[stage1b] Model: {MODEL}")
print("[stage1b] Tools: write_file")
print(f"[stage1b] Temperature: {TEMPERATURE}, Max tokens: {MAX_TOKENS}\n")

chunk_count = 0        # SSE data chunks seen (including malformed ones)
tool_call_chunks = []  # raw argument fragments from tool_call deltas
tool_name = None       # name of the tool the model called, if any
content_chunks = 0     # number of deltas carrying plain text content
reasoning_chunks = 0   # number of deltas carrying a "reasoning" field
content_parts = []     # accumulated plain-text fragments, joined after the stream

start_time = time.time()

with httpx.Client(timeout=120.0) as client:
    with client.stream(
        "POST",
        f"{API_BASE}/chat/completions",
        headers={"Content-Type": "application/json"},
        json={
            "model": MODEL,
            "messages": messages,
            "tools": tools,
            "tool_choice": "auto",
            "stream": True,
            "max_tokens": MAX_TOKENS,
            "temperature": TEMPERATURE,
        },
    ) as response:
        print(f"[stage1b] Response status: {response.status_code}\n")

        for line in response.iter_lines():
            # Skip keep-alive blanks, the terminal sentinel, and non-data lines.
            if not line or line == SSE_DONE or not line.startswith(SSE_PREFIX):
                continue

            # Count the line first so the index printed on a decode error
            # matches the chunk's actual position in the stream.
            chunk_count += 1
            try:
                chunk = json.loads(line[len(SSE_PREFIX):])
            except json.JSONDecodeError as e:
                print(f"  [chunk {chunk_count}] JSON error: {e}")
                continue

            if not chunk.get("choices"):
                continue

            delta = chunk["choices"][0].get("delta", {})

            if delta.get("reasoning"):
                reasoning_chunks += 1

            for tc in delta.get("tool_calls") or []:
                fn = tc.get("function", {})
                if fn.get("name"):
                    tool_name = fn["name"]
                    print(f"  [chunk {chunk_count}] TOOL CALL NAME: {tool_name}")
                if fn.get("arguments"):
                    tool_call_chunks.append(fn["arguments"])

            content = delta.get("content")
            if content:
                content_chunks += 1
                content_parts.append(content)
                # Print first few content chunks for visibility
                if content_chunks <= 5:
                    print(f"  [chunk {chunk_count}] CONTENT: {json.dumps(content[:100])}")

end_time = time.time()

# Join fragments once instead of += per chunk (avoids quadratic rebuilds).
accumulated_args = "".join(tool_call_chunks)

# ── Summary ──────────────────────────────────────────────────────────────────

print(f"\n{'═' * 60}")
print("SUMMARY")
print("═" * 60)
print(f"Total SSE chunks: {chunk_count}")
print(f"  Reasoning chunks: {reasoning_chunks}")
print(f"  Tool call arg chunks: {len(tool_call_chunks)}")
print(f"  Content chunks: {content_chunks}")
print(f"Total time: {end_time - start_time:.3f}s")

if tool_name:
    print(f"\nTool called: {tool_name}")
    print(f"Total tool call args: {len(accumulated_args)} chars")
    try:
        args = json.loads(accumulated_args)
        print("✓ Parsed successfully:")
        for k, v in args.items():
            # Truncate long values so a big file body doesn't flood the log.
            shown = str(v) if len(str(v)) < 100 else str(v)[:100] + "..."
            print(f"  - {k}: {shown}")
    except json.JSONDecodeError as e:
        print(f"✗ Failed to parse: {e}")
        print(f"Raw (first 500): {accumulated_args[:500]}")
else:
    print("\n⚠ NO TOOL CALL emitted by the model")

if content_parts:
    full_content = "".join(content_parts)
    print(f"\nFull content output ({len(full_content)} chars):")
    print("─" * 60)
    print(full_content[:1000])
    if len(full_content) > 1000:
        print(f"... ({len(full_content) - 1000} more chars)")
    print("─" * 60)

# ── Verdict ──────────────────────────────────────────────────────────────────
print(f"\n{'═' * 60}")
if tool_name and tool_call_chunks:
    print("✓ MODEL EMITTED TOOL CALL via API (chat template applied)")
    print("  → The chat template triggers the tool-call path")
    print("  → Compare with stage1 raw output to see the difference")
else:
    print("✗ MODEL DID NOT EMIT TOOL CALL even with chat template")
    print("  → This is a model capability issue")
print("═" * 60)