#!/usr/bin/env python3
"""DEGENERATION TEST 1 — Chat-template token-ID diff.

Hypothesis: the hand-rolled prompt is out-of-distribution for this reasoning model.
The hand-rolled construction is:
  [bos, USER_TOKEN] + tokenizer.encode('\\n\\n' + PROMPT) + [ASSISTANT_TOKEN, THINK_START]

This test diffs it against what apply_chat_template produces.
If they differ, the hand-rolled path is wrong and likely causes degenerate output.

Configuration is read from the environment:
  CHECKPOINT_DIR  directory passed to AutoTokenizer.from_pretrained
  TEST_PROMPT     user message used to build both prompts
"""
# Provenance: added as tests/unit/test_degeneration_1_chat_template.py by the
# patch "Add degeneration test 1: chat-template token-ID diff".
import json
import os
import sys

CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
PROMPT = os.environ.get("TEST_PROMPT", "The capital of France is")

# These must match single_shot_inference.py
THINK_START, THINK_END = 128821, 128822
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804


def as_token_ids(encoded):
    """Normalise an apply_chat_template(tokenize=True) result to a flat list of ints.

    Accepts a plain list of IDs, a mapping that carries them under
    "input_ids", an object exposing ``tolist()``, or a batch containing a
    single sequence.
    """
    # NOTE(review): depending on the transformers version / ``return_dict``,
    # the tokenized result may be dict-like instead of a list — confirm against
    # the installed version. Without this, len() and == below would be wrong.
    if hasattr(encoded, "keys") and "input_ids" in encoded:
        encoded = encoded["input_ids"]
    if hasattr(encoded, "tolist"):
        encoded = encoded.tolist()
    ids = list(encoded)
    # Unwrap a batch of exactly one sequence: [[1, 2, 3]] -> [1, 2, 3].
    if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
        ids = list(ids[0])
    return [int(t) for t in ids]


def first_difference(left, right):
    """Return the first index at which two sequences differ, or None if equal.

    When one sequence is a strict prefix of the other, the index is the length
    of the shorter one (the first position that exists on only one side).
    """
    for i, (a, b) in enumerate(zip(left, right)):
        if a != b:
            return i
    if len(left) != len(right):
        return min(len(left), len(right))
    return None


def diff_window(left, right, first_diff, before=2, after=5):
    """Return the range of positions to display around ``first_diff``.

    The window is clamped to the LONGER sequence so that a pure length
    mismatch still shows the differing position; callers render a placeholder
    for the side that has no token there.
    """
    start = max(0, first_diff - before)
    end = min(max(len(left), len(right)), first_diff + after)
    return range(start, end)


def print_token_table(tokenizer, token_ids):
    """Print one line per token: position, ID and decoded text."""
    print(" Token-by-token:")
    for i, tid in enumerate(token_ids):
        tok_str = repr(tokenizer.decode([tid]))
        print(f" [{i:3d}] id={tid:>7d} {tok_str}")


def build_hand_rolled_ids(tokenizer):
    """Build the hand-rolled prompt (exactly as single_shot_inference.py does)."""
    bos = tokenizer.bos_token_id or 0
    input_ids = [bos, USER_TOKEN]
    input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
    input_ids.append(ASSISTANT_TOKEN)
    input_ids.append(THINK_START)
    return input_ids


def dump_chat_template(tokenizer):
    """Print the first 2000 characters of the tokenizer's chat template, if any."""
    print("\n --- Raw chat_template ---")
    if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
        tmpl = tokenizer.chat_template
        print(f" Length: {len(tmpl)} chars")
        print(f" {tmpl[:2000]}")
        if len(tmpl) > 2000:
            print(f" ... ({len(tmpl) - 2000} more chars)")
    else:
        print(" WARNING: No chat_template attribute found on tokenizer!")


def dump_tokenizer_config():
    """Print chat-template and thinking-related entries of tokenizer_config.json."""
    config_path = os.path.join(CHECKPOINT_DIR, "tokenizer_config.json")
    if not os.path.exists(config_path):
        return
    # JSON is UTF-8; do not depend on the locale's default encoding.
    with open(config_path, encoding="utf-8") as f:
        tcfg = json.load(f)
    # Check for chat_template field(s)
    if "chat_template" in tcfg:
        ct = tcfg["chat_template"]
        if isinstance(ct, str):
            print(f"\n --- chat_template in tokenizer_config.json (string, {len(ct)} chars) ---")
            print(f" {ct[:2000]}")
        elif isinstance(ct, list):
            print(f"\n --- chat_template in tokenizer_config.json (list, {len(ct)} entries) ---")
            for entry in ct:
                name = entry.get("name", "default")
                tmpl = entry.get("template", "")
                print(f"\n Template '{name}' ({len(tmpl)} chars):")
                print(f" {tmpl[:1500]}")
    # Check for thinking-related config
    for key in tcfg:
        if 'think' in key.lower():
            print(f"\n THINKING-RELATED CONFIG: {key} = {tcfg[key]}")


def render_variant(tokenizer, messages, label, *, detailed=True, **template_kwargs):
    """Run apply_chat_template for one option set and print the result.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template`` and ``decode``.
        messages: chat messages to render.
        label: text shown in the output to identify this variant.
        detailed: also print the rendered string and the per-token table.
        **template_kwargs: extra keyword arguments forwarded to the template.

    Returns:
        The token IDs as a flat list, or None if anything in the attempt raised.
    """
    # Deliberately broad: this is a best-effort probe and an unsupported
    # option must not abort the remaining diagnostics.
    try:
        ids = as_token_ids(tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=True,
            **template_kwargs))
        if detailed:
            text = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False,
                **template_kwargs)
        print(f"\n apply_chat_template ({label}):")
        print(f" Token IDs ({len(ids)}): {ids}")
        if detailed:
            print(f" Decoded: {repr(text)}")
            print_token_table(tokenizer, ids)
    except Exception as e:
        print(f" apply_chat_template ({label}) FAILED: {e}")
        return None
    return ids


def report_diff(tokenizer, input_ids, best_ref):
    """Print the comparison of the hand-rolled prompt against the canonical one."""
    # Compare hand-rolled vs canonical
    if input_ids == best_ref:
        print(" IDENTICAL — hand-rolled matches apply_chat_template")
        print(" This means the prompt is correct (per the tokenizer).")
        print(" If the model still degenerates, the issue is elsewhere.")
    else:
        print(" DIFFERENT — hand-rolled does NOT match apply_chat_template!")
        print(f" Hand-rolled: {len(input_ids)} tokens")
        print(f" Canonical: {len(best_ref)} tokens")

        first_diff = first_difference(input_ids, best_ref)
        if first_diff is not None:
            print(f" First difference at position {first_diff}:")
            # Show context around the diff
            print(f" {'pos':>5s} {'hand_rolled':>12s} {'canonical':>12s} {'hr_decoded':>20s} {'can_decoded':>20s}")
            for i in diff_window(input_ids, best_ref, first_diff):
                in_hr = i < len(input_ids)
                in_can = i < len(best_ref)
                hr_id = input_ids[i] if in_hr else "—"
                can_id = best_ref[i] if in_can else "—"
                hr_dec = repr(tokenizer.decode([input_ids[i]])) if in_hr else "—"
                can_dec = repr(tokenizer.decode([best_ref[i]])) if in_can else "—"
                marker = " <<<" if i == first_diff else ""
                print(f" {i:5d} {str(hr_id):>12s} {str(can_id):>12s} {hr_dec:>20s} {can_dec:>20s}{marker}")

    # Also check if the canonical template ends with THINK_START
    if not best_ref:
        # Guard: indexing [-1] on an empty list would raise IndexError.
        print("\n WARNING: Canonical template produced an empty token list.")
    elif best_ref[-1] == THINK_START:
        print(f"\n Canonical template ENDS with THINK_START ({THINK_START}) ✓")
    elif best_ref[-1] == ASSISTANT_TOKEN:
        print(f"\n WARNING: Canonical template ends with ASSISTANT_TOKEN ({ASSISTANT_TOKEN}), NOT THINK_START!")
        print(" The model may expect THINK_START priming after ASSISTANT_TOKEN for reasoning.")
    else:
        last_tok = tokenizer.decode([best_ref[-1]])
        print(f"\n Canonical template ends with token {best_ref[-1]} ({repr(last_tok)})")
        # Check if THINK_START is anywhere in the canonical output
        if THINK_START in best_ref:
            idx = best_ref.index(THINK_START)
            print(f" THINK_START found at position {idx} in canonical output")
        else:
            print(f" THINK_START ({THINK_START}) NOT FOUND in canonical output at all!")

    # Check if canonical has different newline placement
    # Count newline tokens
    nl_token_ids = set()
    for tok in ['\n', '\n\n', '\n\n\n']:
        nl_token_ids.update(tokenizer.encode(tok, add_special_tokens=False))
    hr_nls = sum(1 for t in input_ids if t in nl_token_ids)
    can_nls = sum(1 for t in best_ref if t in nl_token_ids)
    print(f"\n Newline tokens: hand_rolled={hr_nls}, canonical={can_nls}")


def dump_special_tokens(tokenizer):
    """Print the special-tokens map and added tokens naming think/user/assistant."""
    if hasattr(tokenizer, 'special_tokens_map'):
        for k, v in tokenizer.special_tokens_map.items():
            if isinstance(v, str):
                tid = tokenizer.convert_tokens_to_ids(v)
            else:
                tid = [tokenizer.convert_tokens_to_ids(t) for t in v]
            print(f" {k}: {v} -> id={tid}")
    if hasattr(tokenizer, 'added_tokens_decoder'):
        print("\n Added tokens containing 'think' or 'user' or 'assistant':")
        for tid, tok in tokenizer.added_tokens_decoder.items():
            s = str(tok)
            if any(x in s.lower() for x in ['think', 'user', 'assistant']):
                print(f" id={tid}: {s}")


def main():
    """Load the tokenizer and print the hand-rolled vs chat-template diff report.

    Returns early, before the special-token dump, when no apply_chat_template
    call produced output.
    """
    # Imported inside main so importing this module does not need transformers.
    from transformers import AutoTokenizer

    print("=" * 70)
    print("DEGENERATION TEST 1 — Chat-template token-ID diff")
    print("=" * 70)

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)

    # 1. Build hand-rolled prompt (exactly as single_shot_inference.py does)
    input_ids = build_hand_rolled_ids(tokenizer)
    print("\n1. HAND-ROLLED prompt:")
    print(f" Token IDs ({len(input_ids)}): {input_ids}")
    print(f" Decoded: {repr(tokenizer.decode(input_ids))}")
    print_token_table(tokenizer, input_ids)

    # 2. Build canonical template via apply_chat_template
    print("\n2. CANONICAL apply_chat_template:")
    dump_chat_template(tokenizer)
    dump_tokenizer_config()

    # Try apply_chat_template with various options
    messages = [{"role": "user", "content": PROMPT}]
    ref_ids = render_variant(tokenizer, messages, "default")
    ref_ids_think = render_variant(
        tokenizer, messages, "enable_thinking=True", enable_thinking=True)
    # Alternate kwarg name; only the IDs are printed for this variant.
    ref_ids_think2 = render_variant(
        tokenizer, messages, "thinking=True", detailed=False, thinking=True)

    # 3. Diff
    print(f"\n{'='*70}")
    print("3. DIFF ANALYSIS")
    print(f"{'='*70}")

    # Summarise every variant so the thinking=True result is not discarded.
    # NOTE(review): a template may accept an option it does not use, in which
    # case that variant equals the default one — compare the rows below.
    for label, ids in (("default", ref_ids),
                       ("enable_thinking=True", ref_ids_think),
                       ("thinking=True", ref_ids_think2)):
        if ids is None:
            status = "unavailable"
        elif ids == input_ids:
            status = f"{len(ids)} tokens, MATCHES hand-rolled"
        else:
            status = f"{len(ids)} tokens, differs from hand-rolled"
        print(f" Variant {label}: {status}")

    best_ref = ref_ids_think if ref_ids_think is not None else ref_ids
    if best_ref is None:
        print(" CRITICAL: apply_chat_template produced no output — cannot diff!")
        print(" The tokenizer may not support chat templates.")
        return

    report_diff(tokenizer, input_ids, best_ref)

    # 4. Also dump ALL special tokens for reference
    print("\n4. SPECIAL TOKENS:")
    dump_special_tokens(tokenizer)

    print(f"\n{'='*70}")
    print("TEST 1 COMPLETE — report the diff results above")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()