""" Stage 1: Chat Template Debugger Runs a raw prompt through vLLM's generate() API — no chat template, no serving layer, no middleware. Captures the exact tokens the model emits so you can determine whether tool-call failures are a model problem or a parser problem. """ import os import json from vllm import LLM, SamplingParams MODEL_PATH = os.environ.get("MODEL_PATH", "/workspace/models/SmolLM3-3B") PROMPT_FILE = os.environ.get("PROMPT_FILE", "/workspace/prompts/smol_tool_call.txt") MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "512")) TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.0")) # ── Load prompt ────────────────────────────────────────────────────────────── if os.path.exists(PROMPT_FILE): with open(PROMPT_FILE, "r") as f: prompt = f.read().strip() else: print(f"[stage1] Prompt file not found: {PROMPT_FILE}") print("[stage1] Using inline fallback prompt.") prompt = """You are a helpful assistant with access to tools. Available tools: - write_file: Write content to a file. Args: {"path": "string", "content": "string"} User: Write "hello world" to /tmp/test.txt Assistant:""" print(f"[stage1] Prompt ({len(prompt)} chars):\n{'─' * 60}") print(prompt) print(f"{'─' * 60}\n") # ── Run model ──────────────────────────────────────────────────────────────── print(f"[stage1] Loading model from {MODEL_PATH} ...") llm = LLM(model=MODEL_PATH, trust_remote_code=True) params = SamplingParams( temperature=TEMPERATURE, max_tokens=MAX_TOKENS, ) print(f"[stage1] Generating (temp={TEMPERATURE}, max_tokens={MAX_TOKENS}) ...\n") outputs = llm.generate([prompt], params) # ── Dump results ───────────────────────────────────────────────────────────── for output in outputs: generated = output.outputs[0] text = generated.text token_ids = list(generated.token_ids) print(f"{'═' * 60}") print(f"RAW TEXT:\n{text}") print(f"{'─' * 60}") print(f"TOKEN IDS ({len(token_ids)} tokens):") print(json.dumps(token_ids)) print(f"{'─' * 60}") print("PER-TOKEN DECODE:") tokenizer = llm.get_tokenizer() for i, tid in enumerate(token_ids): decoded = tokenizer.decode([tid]) print(f" [{i:4d}] id={tid:>8d} → {json.dumps(decoded)}") print(f"{'═' * 60}")