Add degeneration test 1: chat-template token-ID diff
This commit is contained in:
221
tests/unit/test_degeneration_1_chat_template.py
Normal file
221
tests/unit/test_degeneration_1_chat_template.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python3
|
||||
"""DEGENERATION TEST 1 — Chat-template token-ID diff.
|
||||
|
||||
Hypothesis: the hand-rolled prompt is out-of-distribution for this reasoning model.
|
||||
The hand-rolled construction is:
|
||||
[bos, USER_TOKEN] + tokenizer.encode('\n\n' + PROMPT) + [ASSISTANT_TOKEN, THINK_START]
|
||||
|
||||
This test diffs it against what apply_chat_template produces.
|
||||
If they differ, the hand-rolled path is wrong and likely causes degenerate output.
|
||||
"""
|
||||
import os, sys, json
|
||||
|
||||
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
PROMPT = os.environ.get("TEST_PROMPT", "The capital of France is")
|
||||
|
||||
# These must match single_shot_inference.py
|
||||
THINK_START, THINK_END = 128821, 128822
|
||||
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
|
||||
|
||||
def main():
|
||||
from transformers import AutoTokenizer
|
||||
print("=" * 70)
|
||||
print("DEGENERATION TEST 1 — Chat-template token-ID diff")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
|
||||
# 1. Build hand-rolled prompt (exactly as single_shot_inference.py does)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
input_ids.append(THINK_START)
|
||||
|
||||
print(f"\n1. HAND-ROLLED prompt:")
|
||||
print(f" Token IDs ({len(input_ids)}): {input_ids}")
|
||||
print(f" Decoded: {repr(tokenizer.decode(input_ids))}")
|
||||
print(f" Token-by-token:")
|
||||
for i, tid in enumerate(input_ids):
|
||||
tok_str = repr(tokenizer.decode([tid]))
|
||||
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
|
||||
|
||||
# 2. Build canonical template via apply_chat_template
|
||||
print(f"\n2. CANONICAL apply_chat_template:")
|
||||
|
||||
# First, print the raw chat template source
|
||||
print(f"\n --- Raw chat_template ---")
|
||||
if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
|
||||
# Print first 2000 chars of template
|
||||
tmpl = tokenizer.chat_template
|
||||
print(f" Length: {len(tmpl)} chars")
|
||||
print(f" {tmpl[:2000]}")
|
||||
if len(tmpl) > 2000:
|
||||
print(f" ... ({len(tmpl) - 2000} more chars)")
|
||||
else:
|
||||
print(" WARNING: No chat_template attribute found on tokenizer!")
|
||||
|
||||
# Check tokenizer_config.json for templates
|
||||
import os
|
||||
config_path = os.path.join(CHECKPOINT_DIR, "tokenizer_config.json")
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path) as f:
|
||||
tcfg = json.load(f)
|
||||
# Check for chat_template field(s)
|
||||
if "chat_template" in tcfg:
|
||||
ct = tcfg["chat_template"]
|
||||
if isinstance(ct, str):
|
||||
print(f"\n --- chat_template in tokenizer_config.json (string, {len(ct)} chars) ---")
|
||||
print(f" {ct[:2000]}")
|
||||
elif isinstance(ct, list):
|
||||
print(f"\n --- chat_template in tokenizer_config.json (list, {len(ct)} entries) ---")
|
||||
for entry in ct:
|
||||
name = entry.get("name", "default")
|
||||
tmpl = entry.get("template", "")
|
||||
print(f"\n Template '{name}' ({len(tmpl)} chars):")
|
||||
print(f" {tmpl[:1500]}")
|
||||
# Check for thinking-related config
|
||||
for key in tcfg:
|
||||
if 'think' in key.lower():
|
||||
print(f"\n THINKING-RELATED CONFIG: {key} = {tcfg[key]}")
|
||||
|
||||
# Try apply_chat_template with various options
|
||||
messages = [{"role": "user", "content": PROMPT}]
|
||||
|
||||
# Basic call
|
||||
try:
|
||||
ref_ids = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True)
|
||||
ref_str = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False)
|
||||
print(f"\n apply_chat_template (default):")
|
||||
print(f" Token IDs ({len(ref_ids)}): {ref_ids}")
|
||||
print(f" Decoded: {repr(ref_str)}")
|
||||
print(f" Token-by-token:")
|
||||
for i, tid in enumerate(ref_ids):
|
||||
tok_str = repr(tokenizer.decode([tid]))
|
||||
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
|
||||
except Exception as e:
|
||||
print(f" apply_chat_template (default) FAILED: {e}")
|
||||
ref_ids = None
|
||||
|
||||
# Try with enable_thinking=True if supported
|
||||
try:
|
||||
ref_ids_think = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True,
|
||||
enable_thinking=True)
|
||||
ref_str_think = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=False,
|
||||
enable_thinking=True)
|
||||
print(f"\n apply_chat_template (enable_thinking=True):")
|
||||
print(f" Token IDs ({len(ref_ids_think)}): {ref_ids_think}")
|
||||
print(f" Decoded: {repr(ref_str_think)}")
|
||||
print(f" Token-by-token:")
|
||||
for i, tid in enumerate(ref_ids_think):
|
||||
tok_str = repr(tokenizer.decode([tid]))
|
||||
print(f" [{i:3d}] id={tid:>7d} {tok_str}")
|
||||
except Exception as e:
|
||||
print(f" apply_chat_template (enable_thinking=True) FAILED: {e}")
|
||||
ref_ids_think = None
|
||||
|
||||
# Try with thinking=True (alternate kwarg name)
|
||||
try:
|
||||
ref_ids_think2 = tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True,
|
||||
thinking=True)
|
||||
print(f"\n apply_chat_template (thinking=True):")
|
||||
print(f" Token IDs ({len(ref_ids_think2)}): {ref_ids_think2}")
|
||||
except Exception as e:
|
||||
print(f" apply_chat_template (thinking=True) FAILED: {e}")
|
||||
ref_ids_think2 = None
|
||||
|
||||
# 3. Diff
|
||||
print(f"\n{'='*70}")
|
||||
print(f"3. DIFF ANALYSIS")
|
||||
print(f"{'='*70}")
|
||||
|
||||
best_ref = ref_ids_think if ref_ids_think is not None else ref_ids
|
||||
if best_ref is None:
|
||||
print(" CRITICAL: apply_chat_template produced no output — cannot diff!")
|
||||
print(" The tokenizer may not support chat templates.")
|
||||
return
|
||||
|
||||
# Compare hand-rolled vs canonical
|
||||
if input_ids == best_ref:
|
||||
print(" IDENTICAL — hand-rolled matches apply_chat_template")
|
||||
print(" This means the prompt is correct (per the tokenizer).")
|
||||
print(" If the model still degenerates, the issue is elsewhere.")
|
||||
else:
|
||||
print(" DIFFERENT — hand-rolled does NOT match apply_chat_template!")
|
||||
print(f" Hand-rolled: {len(input_ids)} tokens")
|
||||
print(f" Canonical: {len(best_ref)} tokens")
|
||||
|
||||
# Find first difference
|
||||
min_len = min(len(input_ids), len(best_ref))
|
||||
first_diff = None
|
||||
for i in range(min_len):
|
||||
if input_ids[i] != best_ref[i]:
|
||||
first_diff = i
|
||||
break
|
||||
if first_diff is None and len(input_ids) != len(best_ref):
|
||||
first_diff = min_len
|
||||
|
||||
if first_diff is not None:
|
||||
print(f" First difference at position {first_diff}:")
|
||||
# Show context around the diff
|
||||
start = max(0, first_diff - 2)
|
||||
end = min(len(input_ids), len(best_ref), first_diff + 5)
|
||||
print(f" {'pos':>5s} {'hand_rolled':>12s} {'canonical':>12s} {'hr_decoded':>20s} {'can_decoded':>20s}")
|
||||
for i in range(start, end):
|
||||
hr_id = input_ids[i] if i < len(input_ids) else "—"
|
||||
can_id = best_ref[i] if i < len(best_ref) else "—"
|
||||
hr_dec = repr(tokenizer.decode([input_ids[i]])) if i < len(input_ids) else "—"
|
||||
can_dec = repr(tokenizer.decode([best_ref[i]])) if i < len(best_ref) else "—"
|
||||
marker = " <<<" if i == first_diff else ""
|
||||
print(f" {i:5d} {str(hr_id):>12s} {str(can_id):>12s} {hr_dec:>20s} {can_dec:>20s}{marker}")
|
||||
|
||||
# Also check if the canonical template ends with THINK_START
|
||||
if best_ref[-1] == THINK_START:
|
||||
print(f"\n Canonical template ENDS with THINK_START ({THINK_START}) ✓")
|
||||
elif best_ref[-1] == ASSISTANT_TOKEN:
|
||||
print(f"\n WARNING: Canonical template ends with ASSISTANT_TOKEN ({ASSISTANT_TOKEN}), NOT THINK_START!")
|
||||
print(f" The model may expect THINK_START priming after ASSISTANT_TOKEN for reasoning.")
|
||||
else:
|
||||
last_tok = tokenizer.decode([best_ref[-1]])
|
||||
print(f"\n Canonical template ends with token {best_ref[-1]} ({repr(last_tok)})")
|
||||
# Check if THINK_START is anywhere in the canonical output
|
||||
if THINK_START in best_ref:
|
||||
idx = best_ref.index(THINK_START)
|
||||
print(f" THINK_START found at position {idx} in canonical output")
|
||||
else:
|
||||
print(f" THINK_START ({THINK_START}) NOT FOUND in canonical output at all!")
|
||||
|
||||
# Check if canonical has different newline placement
|
||||
# Count newline tokens
|
||||
nl_token_ids = set()
|
||||
for tok in ['\n', '\n\n', '\n\n\n']:
|
||||
tid = tokenizer.encode(tok, add_special_tokens=False)
|
||||
nl_token_ids.update(tid)
|
||||
hr_nls = sum(1 for t in input_ids if t in nl_token_ids)
|
||||
can_nls = sum(1 for t in best_ref if t in nl_token_ids)
|
||||
print(f"\n Newline tokens: hand_rolled={hr_nls}, canonical={can_nls}")
|
||||
|
||||
# 4. Also dump ALL special tokens for reference
|
||||
print(f"\n4. SPECIAL TOKENS:")
|
||||
if hasattr(tokenizer, 'special_tokens_map'):
|
||||
for k, v in tokenizer.special_tokens_map.items():
|
||||
tid = tokenizer.convert_tokens_to_ids(v) if isinstance(v, str) else [tokenizer.convert_tokens_to_ids(t) for t in v]
|
||||
print(f" {k}: {v} -> id={tid}")
|
||||
if hasattr(tokenizer, 'added_tokens_decoder'):
|
||||
print(f"\n Added tokens containing 'think' or 'user' or 'assistant':")
|
||||
for tid, tok in tokenizer.added_tokens_decoder.items():
|
||||
s = str(tok)
|
||||
if any(x in s.lower() for x in ['think', 'user', 'assistant']):
|
||||
print(f" id={tid}: {s}")
|
||||
|
||||
print(f"\n{'='*70}")
|
||||
print("TEST 1 COMPLETE — report the diff results above")
|
||||
print(f"{'='*70}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user