Files
chat-template-debugger/scripts/stage1_debug.py

71 lines
2.6 KiB
Python

"""
Stage 1: Chat Template Debugger
Runs a raw prompt through vLLM's generate() API — no chat template, no serving
layer, no middleware. Captures the exact tokens the model emits so you can
determine whether tool-call failures are a model problem or a parser problem.
"""
import os
import json
from vllm import LLM, SamplingParams
# ── Configuration ────────────────────────────────────────────────────────────
# Each setting can be overridden via an environment variable of the same name;
# the second argument to getenv is the fallback used when it is unset.
MODEL_PATH = os.getenv("MODEL_PATH", "/workspace/models/SmolLM3-3B")
PROMPT_FILE = os.getenv("PROMPT_FILE", "/workspace/prompts/smol_tool_call.txt")
# Numeric settings are parsed immediately, so a malformed value raises
# ValueError at startup rather than later inside vLLM.
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "512"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.0"))
# ── Load prompt ──────────────────────────────────────────────────────────────
# Prompt used when PROMPT_FILE does not exist. Defined at column 0 so the
# continuation lines of the string carry no accidental indentation.
_FALLBACK_PROMPT = """You are a helpful assistant with access to tools.
Available tools:
- write_file: Write content to a file. Args: {"path": "string", "content": "string"}
User: Write "hello world" to /tmp/test.txt
Assistant:"""

try:
    with open(PROMPT_FILE, "r", encoding="utf-8") as f:
        # NOTE(review): .strip() also drops any trailing newline the prompt
        # deliberately ends with (e.g. after an assistant header), so the model
        # may see a slightly different prompt than the file holds — confirm.
        prompt = f.read().strip()
except FileNotFoundError:
    # EAFP instead of an os.path.exists() pre-check: one filesystem access,
    # and no window between the check and the open.
    print(f"[stage1] Prompt file not found: {PROMPT_FILE}")
    print("[stage1] Using inline fallback prompt.")
    prompt = _FALLBACK_PROMPT

# Horizontal rule framing the dump; must be non-empty, otherwise the frame
# lines print as blank lines.
_RULE = "─" * 60
print(f"[stage1] Prompt ({len(prompt)} chars):\n{_RULE}")
print(prompt)
print(f"{_RULE}\n")
# ── Run model ────────────────────────────────────────────────────────────────
print(f"[stage1] Loading model from {MODEL_PATH} ...")
# NOTE: trust_remote_code=True lets the checkpoint execute its own Python code;
# only point MODEL_PATH at models you trust.
llm = LLM(model=MODEL_PATH, trust_remote_code=True)
# By default vLLM's detokenizer removes special tokens from `.text`
# (skip_special_tokens=True) and inserts spaces around them
# (spaces_between_special_tokens=True). Both hide exactly what this debugger
# must show — e.g. tool-call markers registered as special tokens — so they are
# disabled to make RAW TEXT reflect the emitted token IDs verbatim.
params = SamplingParams(
    temperature=TEMPERATURE,
    max_tokens=MAX_TOKENS,
    skip_special_tokens=False,
    spaces_between_special_tokens=False,
)
print(f"[stage1] Generating (temp={TEMPERATURE}, max_tokens={MAX_TOKENS}) ...\n")
# Raw string prompt through generate(): no chat template is applied here.
outputs = llm.generate([prompt], params)
# ── Dump results ─────────────────────────────────────────────────────────────
# Horizontal rule between sections; must be non-empty, otherwise the section
# breaks print as blank lines.
_RULE = "─" * 60
# The tokenizer does not depend on the output being printed, so fetch it once
# instead of once per output.
tokenizer = llm.get_tokenizer()
for output in outputs:
    # Only the first completion of each request is inspected.
    generated = output.outputs[0]
    text = generated.text
    token_ids = list(generated.token_ids)
    print(_RULE)
    print(f"RAW TEXT:\n{text}")
    print(_RULE)
    # A value of "length" means generation hit MAX_TOKENS: a truncated tool call
    # looks like a parser failure but is really a budget problem.
    print(f"FINISH REASON: {getattr(generated, 'finish_reason', None)}")
    print(_RULE)
    print(f"TOKEN IDS ({len(token_ids)} tokens):")
    print(json.dumps(token_ids))
    print(_RULE)
    print("PER-TOKEN DECODE:")
    for i, tid in enumerate(token_ids):
        # Each ID is decoded on its own; json.dumps makes whitespace and
        # control characters visible (e.g. "\n"). Characters split across
        # several byte-level tokens may decode to U+FFFD individually.
        decoded = tokenizer.decode([tid])
        print(f" [{i:4d}] id={tid:>8d}  →  {json.dumps(decoded)}")
    print(_RULE)