diff --git a/tests/unit/test_degeneration_1_chat_template.py b/tests/unit/test_degeneration_1_chat_template.py index 3dccca8d..dd96c43c 100644 --- a/tests/unit/test_degeneration_1_chat_template.py +++ b/tests/unit/test_degeneration_1_chat_template.py @@ -1,221 +1,168 @@ #!/usr/bin/env python3 -"""DEGENERATION TEST 1 — Chat-template token-ID diff. +"""DEGENERATION TEST 1 v2 — Chat-template token-ID diff using official encoding. -Hypothesis: the hand-rolled prompt is out-of-distribution for this reasoning model. -The hand-rolled construction is: - [bos, USER_TOKEN] + tokenizer.encode('\n\n' + PROMPT) + [ASSISTANT_TOKEN, THINK_START] +Uses the official DeepSeek V4 encoding reference from encoding/encoding_dsv4.py +to build the canonical prompt, then diffs against our hand-rolled construction. -This test diffs it against what apply_chat_template produces. -If they differ, the hand-rolled path is wrong and likely causes degenerate output. +Official format (from DeepSeek-V4-Pro/encoding/README.md): + Thinking mode: {system}<|User|>{msg}<|Assistant|><think>{reasoning}</think>{response} + Chat mode: {system}<|User|>{msg}<|Assistant|></think>{response} + +Key differences from our hand-rolled: + 1. No \n\n between User token and content + 2. System prompt goes directly after BOS (no User token for system) """ -import os, sys, json +import os, sys CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4") PROMPT = os.environ.get("TEST_PROMPT", "The capital of France is") -# These must match single_shot_inference.py THINK_START, THINK_END = 128821, 128822 USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804 def main(): from transformers import AutoTokenizer print("=" * 70) - print("DEGENERATION TEST 1 — Chat-template token-ID diff") + print("DEGENERATION TEST 1 v2 — Official encoding diff") print("=" * 70) tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) - - # 1. 
Build hand-rolled prompt (exactly as single_shot_inference.py does) bos = tokenizer.bos_token_id or 0 + + # === 1. Hand-rolled (current single_shot_inference.py) === input_ids = [bos, USER_TOKEN] input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) input_ids.append(ASSISTANT_TOKEN) input_ids.append(THINK_START) - print(f"\n1. HAND-ROLLED prompt:") - print(f" Token IDs ({len(input_ids)}): {input_ids}") - print(f" Decoded: {repr(tokenizer.decode(input_ids))}") - print(f" Token-by-token:") + print(f"\n1. HAND-ROLLED ({len(input_ids)} tokens):") for i, tid in enumerate(input_ids): - tok_str = repr(tokenizer.decode([tid])) - print(f" [{i:3d}] id={tid:>7d} {tok_str}") + print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}") + print(f" Full: {repr(tokenizer.decode(input_ids))}") - # 2. Build canonical template via apply_chat_template - print(f"\n2. CANONICAL apply_chat_template:") + # === 2. Official encoding (thinking mode, no system prompt) === + # Format: <|User|>{msg}<|Assistant|><think> + # NO \n\n between User token and message + canonical_thinking = [bos, USER_TOKEN] + canonical_thinking += tokenizer.encode(PROMPT, add_special_tokens=False) + canonical_thinking.append(ASSISTANT_TOKEN) + canonical_thinking.append(THINK_START) - # First, print the raw chat template source - print(f"\n --- Raw chat_template ---") - if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template: - # Print first 2000 chars of template - tmpl = tokenizer.chat_template - print(f" Length: {len(tmpl)} chars") - print(f" {tmpl[:2000]}") - if len(tmpl) > 2000: - print(f" ... ({len(tmpl) - 2000} more chars)") - else: - print(" WARNING: No chat_template attribute found on tokenizer!") + print(f"\n2. 
OFFICIAL (thinking, no \\n\\n) ({len(canonical_thinking)} tokens):") + for i, tid in enumerate(canonical_thinking): + print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}") + print(f" Full: {repr(tokenizer.decode(canonical_thinking))}") - # Check tokenizer_config.json for templates - import os - config_path = os.path.join(CHECKPOINT_DIR, "tokenizer_config.json") - if os.path.exists(config_path): - with open(config_path) as f: - tcfg = json.load(f) - # Check for chat_template field(s) - if "chat_template" in tcfg: - ct = tcfg["chat_template"] - if isinstance(ct, str): - print(f"\n --- chat_template in tokenizer_config.json (string, {len(ct)} chars) ---") - print(f" {ct[:2000]}") - elif isinstance(ct, list): - print(f"\n --- chat_template in tokenizer_config.json (list, {len(ct)} entries) ---") - for entry in ct: - name = entry.get("name", "default") - tmpl = entry.get("template", "") - print(f"\n Template '{name}' ({len(tmpl)} chars):") - print(f" {tmpl[:1500]}") - # Check for thinking-related config - for key in tcfg: - if 'think' in key.lower(): - print(f"\n THINKING-RELATED CONFIG: {key} = {tcfg[key]}") + # === 3. Official encoding (chat mode — THINK_END closes thinking) === + canonical_chat = [bos, USER_TOKEN] + canonical_chat += tokenizer.encode(PROMPT, add_special_tokens=False) + canonical_chat.append(ASSISTANT_TOKEN) + canonical_chat.append(THINK_END) - # Try apply_chat_template with various options - messages = [{"role": "user", "content": PROMPT}] + print(f"\n3. 
OFFICIAL (chat mode, THINK_END) ({len(canonical_chat)} tokens):") + for i, tid in enumerate(canonical_chat): + print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}") + print(f" Full: {repr(tokenizer.decode(canonical_chat))}") - # Basic call - try: - ref_ids = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True) - ref_str = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False) - print(f"\n apply_chat_template (default):") - print(f" Token IDs ({len(ref_ids)}): {ref_ids}") - print(f" Decoded: {repr(ref_str)}") - print(f" Token-by-token:") - for i, tid in enumerate(ref_ids): - tok_str = repr(tokenizer.decode([tid])) - print(f" [{i:3d}] id={tid:>7d} {tok_str}") - except Exception as e: - print(f" apply_chat_template (default) FAILED: {e}") - ref_ids = None + # === 4. Official encoding with system prompt === + # Format: {system}<|User|>{msg}<|Assistant|><think> + system_prompt = "You are a helpful assistant." + canonical_sys = tokenizer.encode(system_prompt, add_special_tokens=False) + canonical_sys_thinking = [bos] + canonical_sys + [USER_TOKEN] + canonical_sys_thinking += tokenizer.encode(PROMPT, add_special_tokens=False) + canonical_sys_thinking.append(ASSISTANT_TOKEN) + canonical_sys_thinking.append(THINK_START) - # Try with enable_thinking=True if supported - try: - ref_ids_think = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, - enable_thinking=True) - ref_str_think = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False, - enable_thinking=True) - print(f"\n apply_chat_template (enable_thinking=True):") - print(f" Token IDs ({len(ref_ids_think)}): {ref_ids_think}") - print(f" Decoded: {repr(ref_str_think)}") - print(f" Token-by-token:") - for i, tid in enumerate(ref_ids_think): - tok_str = repr(tokenizer.decode([tid])) - print(f" [{i:3d}] id={tid:>7d} {tok_str}") - except Exception as e: - print(f" apply_chat_template 
(enable_thinking=True) FAILED: {e}") - ref_ids_think = None + print(f"\n4. OFFICIAL (thinking + system prompt) ({len(canonical_sys_thinking)} tokens):") + for i, tid in enumerate(canonical_sys_thinking): + print(f" [{i:3d}] id={tid:>7d} {repr(tokenizer.decode([tid]))}") + print(f" Full: {repr(tokenizer.decode(canonical_sys_thinking))}") - # Try with thinking=True (alternate kwarg name) - try: - ref_ids_think2 = tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=True, - thinking=True) - print(f"\n apply_chat_template (thinking=True):") - print(f" Token IDs ({len(ref_ids_think2)}): {ref_ids_think2}") - except Exception as e: - print(f" apply_chat_template (thinking=True) FAILED: {e}") - ref_ids_think2 = None - - # 3. Diff + # === 5. Diff === print(f"\n{'='*70}") - print(f"3. DIFF ANALYSIS") + print("DIFF: hand-rolled vs official (thinking, no \\n\\n)") print(f"{'='*70}") - - best_ref = ref_ids_think if ref_ids_think is not None else ref_ids - if best_ref is None: - print(" CRITICAL: apply_chat_template produced no output — cannot diff!") - print(" The tokenizer may not support chat templates.") - return - - # Compare hand-rolled vs canonical - if input_ids == best_ref: - print(" IDENTICAL — hand-rolled matches apply_chat_template") - print(" This means the prompt is correct (per the tokenizer).") - print(" If the model still degenerates, the issue is elsewhere.") + if input_ids == canonical_thinking: + print(" IDENTICAL") else: - print(" DIFFERENT — hand-rolled does NOT match apply_chat_template!") - print(f" Hand-rolled: {len(input_ids)} tokens") - print(f" Canonical: {len(best_ref)} tokens") - - # Find first difference - min_len = min(len(input_ids), len(best_ref)) - first_diff = None + print(f" DIFFERENT: hand_rolled={len(input_ids)} tokens, canonical={len(canonical_thinking)} tokens") + min_len = min(len(input_ids), len(canonical_thinking)) for i in range(min_len): - if input_ids[i] != best_ref[i]: - first_diff = i + if input_ids[i] 
!= canonical_thinking[i]: + print(f" First diff at position {i}:") + print(f" hand_rolled[{i}] = {input_ids[i]} ({repr(tokenizer.decode([input_ids[i]]))})") + print(f" canonical[{i}] = {canonical_thinking[i]} ({repr(tokenizer.decode([canonical_thinking[i]]))})") + # Show context + for j in range(max(0,i-2), min(len(input_ids), i+3)): + hr = input_ids[j] if j < len(input_ids) else "—" + cn = canonical_thinking[j] if j < len(canonical_thinking) else "—" + mark = " <<<" if j == i else "" + print(f" [{j}] hand={hr} canon={cn}{mark}") break - if first_diff is None and len(input_ids) != len(best_ref): - first_diff = min_len - - if first_diff is not None: - print(f" First difference at position {first_diff}:") - # Show context around the diff - start = max(0, first_diff - 2) - end = min(len(input_ids), len(best_ref), first_diff + 5) - print(f" {'pos':>5s} {'hand_rolled':>12s} {'canonical':>12s} {'hr_decoded':>20s} {'can_decoded':>20s}") - for i in range(start, end): - hr_id = input_ids[i] if i < len(input_ids) else "—" - can_id = best_ref[i] if i < len(best_ref) else "—" - hr_dec = repr(tokenizer.decode([input_ids[i]])) if i < len(input_ids) else "—" - can_dec = repr(tokenizer.decode([best_ref[i]])) if i < len(best_ref) else "—" - marker = " <<<" if i == first_diff else "" - print(f" {i:5d} {str(hr_id):>12s} {str(can_id):>12s} {hr_dec:>20s} {can_dec:>20s}{marker}") - - # Also check if the canonical template ends with THINK_START - if best_ref[-1] == THINK_START: - print(f"\n Canonical template ENDS with THINK_START ({THINK_START}) ✓") - elif best_ref[-1] == ASSISTANT_TOKEN: - print(f"\n WARNING: Canonical template ends with ASSISTANT_TOKEN ({ASSISTANT_TOKEN}), NOT THINK_START!") - print(f" The model may expect THINK_START priming after ASSISTANT_TOKEN for reasoning.") else: - last_tok = tokenizer.decode([best_ref[-1]]) - print(f"\n Canonical template ends with token {best_ref[-1]} ({repr(last_tok)})") - # Check if THINK_START is anywhere in the canonical output - if 
THINK_START in best_ref: - idx = best_ref.index(THINK_START) - print(f" THINK_START found at position {idx} in canonical output") - else: - print(f" THINK_START ({THINK_START}) NOT FOUND in canonical output at all!") + if len(input_ids) != len(canonical_thinking): + print(f" Same prefix but different lengths: {len(input_ids)} vs {len(canonical_thinking)}") + longer = input_ids if len(input_ids) > len(canonical_thinking) else canonical_thinking + shorter_len = min(len(input_ids), len(canonical_thinking)) + label = "hand_rolled" if len(input_ids) > len(canonical_thinking) else "canonical" + for j in range(shorter_len, len(longer)): + print(f" Extra in {label}: [{j}] = {longer[j]} ({repr(tokenizer.decode([longer[j]]))})") - # Check if canonical has different newline placement - # Count newline tokens - nl_token_ids = set() - for tok in ['\n', '\n\n', '\n\n\n']: - tid = tokenizer.encode(tok, add_special_tokens=False) - nl_token_ids.update(tid) - hr_nls = sum(1 for t in input_ids if t in nl_token_ids) - can_nls = sum(1 for t in best_ref if t in nl_token_ids) - print(f"\n Newline tokens: hand_rolled={hr_nls}, canonical={can_nls}") + # === 6. The key question: does the \n\n matter? 
=== + # Check what token 271 decodes to (it's our \n\n) + print(f"\n{'='*70}") + print("ANALYSIS") + print(f"{'='*70}") + # The only difference should be the \n\n token (id 271) in hand-rolled + # Check if tokenizer encodes PROMPT differently with/without \n\n prefix + enc_with_prefix = tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) + enc_no_prefix = tokenizer.encode(PROMPT, add_special_tokens=False) + print(f" encode('\\n\\n' + PROMPT) = {enc_with_prefix} ({len(enc_with_prefix)} tokens)") + print(f" encode(PROMPT) = {enc_no_prefix} ({len(enc_no_prefix)} tokens)") + if len(enc_with_prefix) > len(enc_no_prefix): + diff_tokens = enc_with_prefix[:len(enc_with_prefix) - len(enc_no_prefix)] + print(f" Extra tokens from \\n\\n: {diff_tokens}") + for t in diff_tokens: + print(f" id={t}: {repr(tokenizer.decode([t]))}") + # Check if the remaining tokens match + if enc_with_prefix[len(diff_tokens):] == enc_no_prefix: + print(f" Remaining tokens MATCH — \\n\\n only adds prefix tokens") + else: + print(f" WARNING: remaining tokens DIFFER — \\n\\n changes tokenization!") + print(f" with prefix tail: {enc_with_prefix[len(diff_tokens):]}") + print(f" without prefix: {enc_no_prefix}") - # 4. Also dump ALL special tokens for reference - print(f"\n4. SPECIAL TOKENS:") + # === 7. What does SGLang use? === + # From the SGLang docs: --reasoning-parser deepseek-v4 and SGLANG_DEFAULT_THINKING=1 + # This should use the same encoding. 
Let's check the raw tokenizer.json for special tokens + print(f"\n{'='*70}") + print("SPECIAL TOKEN INVENTORY") + print(f"{'='*70}") + if hasattr(tokenizer, 'added_tokens_decoder'): + for tid_str, tok in sorted(tokenizer.added_tokens_decoder.items(), key=lambda x: int(x[0])): + tid = int(tid_str) + s = str(tok) + if tid >= 128000 or any(x in s.lower() for x in ['think', 'user', 'assistant', 'end', 'sentence', 'dsml']): + print(f" id={tid:>7d}: {repr(s)}") if hasattr(tokenizer, 'special_tokens_map'): for k, v in tokenizer.special_tokens_map.items(): - tid = tokenizer.convert_tokens_to_ids(v) if isinstance(v, str) else [tokenizer.convert_tokens_to_ids(t) for t in v] - print(f" {k}: {v} -> id={tid}") - if hasattr(tokenizer, 'added_tokens_decoder'): - print(f"\n Added tokens containing 'think' or 'user' or 'assistant':") - for tid, tok in tokenizer.added_tokens_decoder.items(): - s = str(tok) - if any(x in s.lower() for x in ['think', 'user', 'assistant']): - print(f" id={tid}: {s}") + tid = tokenizer.convert_tokens_to_ids(v) if isinstance(v, str) else '—' + print(f" special: {k} = {repr(v)} (id={tid})") + # === 8. Verdict === print(f"\n{'='*70}") - print("TEST 1 COMPLETE — report the diff results above") + print("VERDICT") print(f"{'='*70}") + if input_ids == canonical_thinking: + print(" Hand-rolled matches official thinking-mode encoding.") + print(" Prompt is CORRECT per the official spec.") + print(" Degeneration is NOT caused by prompt format → look at Test 2.") + else: + print(" Hand-rolled DIFFERS from official encoding!") + print(" This is likely contributing to degenerate output.") + print(" FIX: Use canonical_thinking encoding in single_shot_inference.py.") + print(f" Also try: canonical_chat (THINK_END after Assistant) for non-reasoning mode.") + print(f" Also try: canonical_sys_thinking (with system prompt).") if __name__ == "__main__": main()