Add degeneration test 1: chat-template token-ID diff

This commit is contained in:
2026-06-03 08:17:09 +00:00
parent c5a131c358
commit c77b83fffc

View File

@@ -0,0 +1,221 @@
#!/usr/bin/env python3
"""DEGENERATION TEST 1 — Chat-template token-ID diff.
Hypothesis: the hand-rolled prompt is out-of-distribution for this reasoning model.
The hand-rolled construction is:
[bos, USER_TOKEN] + tokenizer.encode('\n\n' + PROMPT) + [ASSISTANT_TOKEN, THINK_START]
This test diffs it against what apply_chat_template produces.
If they differ, the hand-rolled path is wrong and likely causes degenerate output.
"""
import os, sys, json
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
PROMPT = os.environ.get("TEST_PROMPT", "The capital of France is")
# These must match single_shot_inference.py
THINK_START, THINK_END = 128821, 128822
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
def main():
from transformers import AutoTokenizer
print("=" * 70)
print("DEGENERATION TEST 1 — Chat-template token-ID diff")
print("=" * 70)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
# 1. Build hand-rolled prompt (exactly as single_shot_inference.py does)
bos = tokenizer.bos_token_id or 0
input_ids = [bos, USER_TOKEN]
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
input_ids.append(ASSISTANT_TOKEN)
input_ids.append(THINK_START)
print(f"\n1. HAND-ROLLED prompt:")
print(f" Token IDs ({len(input_ids)}): {input_ids}")
print(f" Decoded: {repr(tokenizer.decode(input_ids))}")
print(f" Token-by-token:")
for i, tid in enumerate(input_ids):
tok_str = repr(tokenizer.decode([tid]))
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
# 2. Build canonical template via apply_chat_template
print(f"\n2. CANONICAL apply_chat_template:")
# First, print the raw chat template source
print(f"\n --- Raw chat_template ---")
if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
# Print first 2000 chars of template
tmpl = tokenizer.chat_template
print(f" Length: {len(tmpl)} chars")
print(f" {tmpl[:2000]}")
if len(tmpl) > 2000:
print(f" ... ({len(tmpl) - 2000} more chars)")
else:
print(" WARNING: No chat_template attribute found on tokenizer!")
# Check tokenizer_config.json for templates
import os
config_path = os.path.join(CHECKPOINT_DIR, "tokenizer_config.json")
if os.path.exists(config_path):
with open(config_path) as f:
tcfg = json.load(f)
# Check for chat_template field(s)
if "chat_template" in tcfg:
ct = tcfg["chat_template"]
if isinstance(ct, str):
print(f"\n --- chat_template in tokenizer_config.json (string, {len(ct)} chars) ---")
print(f" {ct[:2000]}")
elif isinstance(ct, list):
print(f"\n --- chat_template in tokenizer_config.json (list, {len(ct)} entries) ---")
for entry in ct:
name = entry.get("name", "default")
tmpl = entry.get("template", "")
print(f"\n Template '{name}' ({len(tmpl)} chars):")
print(f" {tmpl[:1500]}")
# Check for thinking-related config
for key in tcfg:
if 'think' in key.lower():
print(f"\n THINKING-RELATED CONFIG: {key} = {tcfg[key]}")
# Try apply_chat_template with various options
messages = [{"role": "user", "content": PROMPT}]
# Basic call
try:
ref_ids = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True)
ref_str = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False)
print(f"\n apply_chat_template (default):")
print(f" Token IDs ({len(ref_ids)}): {ref_ids}")
print(f" Decoded: {repr(ref_str)}")
print(f" Token-by-token:")
for i, tid in enumerate(ref_ids):
tok_str = repr(tokenizer.decode([tid]))
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
except Exception as e:
print(f" apply_chat_template (default) FAILED: {e}")
ref_ids = None
# Try with enable_thinking=True if supported
try:
ref_ids_think = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
enable_thinking=True)
ref_str_think = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=False,
enable_thinking=True)
print(f"\n apply_chat_template (enable_thinking=True):")
print(f" Token IDs ({len(ref_ids_think)}): {ref_ids_think}")
print(f" Decoded: {repr(ref_str_think)}")
print(f" Token-by-token:")
for i, tid in enumerate(ref_ids_think):
tok_str = repr(tokenizer.decode([tid]))
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
except Exception as e:
print(f" apply_chat_template (enable_thinking=True) FAILED: {e}")
ref_ids_think = None
# Try with thinking=True (alternate kwarg name)
try:
ref_ids_think2 = tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True,
thinking=True)
print(f"\n apply_chat_template (thinking=True):")
print(f" Token IDs ({len(ref_ids_think2)}): {ref_ids_think2}")
except Exception as e:
print(f" apply_chat_template (thinking=True) FAILED: {e}")
ref_ids_think2 = None
# 3. Diff
print(f"\n{'='*70}")
print(f"3. DIFF ANALYSIS")
print(f"{'='*70}")
best_ref = ref_ids_think if ref_ids_think is not None else ref_ids
if best_ref is None:
print(" CRITICAL: apply_chat_template produced no output — cannot diff!")
print(" The tokenizer may not support chat templates.")
return
# Compare hand-rolled vs canonical
if input_ids == best_ref:
print(" IDENTICAL — hand-rolled matches apply_chat_template")
print(" This means the prompt is correct (per the tokenizer).")
print(" If the model still degenerates, the issue is elsewhere.")
else:
print(" DIFFERENT — hand-rolled does NOT match apply_chat_template!")
print(f" Hand-rolled: {len(input_ids)} tokens")
print(f" Canonical: {len(best_ref)} tokens")
# Find first difference
min_len = min(len(input_ids), len(best_ref))
first_diff = None
for i in range(min_len):
if input_ids[i] != best_ref[i]:
first_diff = i
break
if first_diff is None and len(input_ids) != len(best_ref):
first_diff = min_len
if first_diff is not None:
print(f" First difference at position {first_diff}:")
# Show context around the diff
start = max(0, first_diff - 2)
end = min(len(input_ids), len(best_ref), first_diff + 5)
print(f" {'pos':>5s} {'hand_rolled':>12s} {'canonical':>12s} {'hr_decoded':>20s} {'can_decoded':>20s}")
for i in range(start, end):
hr_id = input_ids[i] if i < len(input_ids) else ""
can_id = best_ref[i] if i < len(best_ref) else ""
hr_dec = repr(tokenizer.decode([input_ids[i]])) if i < len(input_ids) else ""
can_dec = repr(tokenizer.decode([best_ref[i]])) if i < len(best_ref) else ""
marker = " <<<" if i == first_diff else ""
print(f" {i:5d} {str(hr_id):>12s} {str(can_id):>12s} {hr_dec:>20s} {can_dec:>20s}{marker}")
# Also check if the canonical template ends with THINK_START
if best_ref[-1] == THINK_START:
print(f"\n Canonical template ENDS with THINK_START ({THINK_START}) ✓")
elif best_ref[-1] == ASSISTANT_TOKEN:
print(f"\n WARNING: Canonical template ends with ASSISTANT_TOKEN ({ASSISTANT_TOKEN}), NOT THINK_START!")
print(f" The model may expect THINK_START priming after ASSISTANT_TOKEN for reasoning.")
else:
last_tok = tokenizer.decode([best_ref[-1]])
print(f"\n Canonical template ends with token {best_ref[-1]} ({repr(last_tok)})")
# Check if THINK_START is anywhere in the canonical output
if THINK_START in best_ref:
idx = best_ref.index(THINK_START)
print(f" THINK_START found at position {idx} in canonical output")
else:
print(f" THINK_START ({THINK_START}) NOT FOUND in canonical output at all!")
# Check if canonical has different newline placement
# Count newline tokens
nl_token_ids = set()
for tok in ['\n', '\n\n', '\n\n\n']:
tid = tokenizer.encode(tok, add_special_tokens=False)
nl_token_ids.update(tid)
hr_nls = sum(1 for t in input_ids if t in nl_token_ids)
can_nls = sum(1 for t in best_ref if t in nl_token_ids)
print(f"\n Newline tokens: hand_rolled={hr_nls}, canonical={can_nls}")
# 4. Also dump ALL special tokens for reference
print(f"\n4. SPECIAL TOKENS:")
if hasattr(tokenizer, 'special_tokens_map'):
for k, v in tokenizer.special_tokens_map.items():
tid = tokenizer.convert_tokens_to_ids(v) if isinstance(v, str) else [tokenizer.convert_tokens_to_ids(t) for t in v]
print(f" {k}: {v} -> id={tid}")
if hasattr(tokenizer, 'added_tokens_decoder'):
print(f"\n Added tokens containing 'think' or 'user' or 'assistant':")
for tid, tok in tokenizer.added_tokens_decoder.items():
s = str(tok)
if any(x in s.lower() for x in ['think', 'user', 'assistant']):
print(f" id={tid}: {s}")
print(f"\n{'='*70}")
print("TEST 1 COMPLETE — report the diff results above")
print(f"{'='*70}")
if __name__ == "__main__":
main()