diff --git a/single_shot_inference.py b/single_shot_inference.py index 19b0d442..f416d99f 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -24,6 +24,7 @@ def parse_args(): p.add_argument('--top-k', type=int, default=50, help='Top-k filtering (0=disabled)') p.add_argument('--top-p', type=float, default=0.95, help='Top-p (nucleus) filtering (1.0=disabled)') p.add_argument('--prompt', type=str, default=None) + p.add_argument('--chat-mode', action='store_true', help='Chat mode: close thinking block after Assistant token (no reasoning)') p.add_argument('--seed', type=int, default=42) p.add_argument('--verbose', type=int, default=1) p.add_argument('--prefill-only', action='store_true') @@ -1460,16 +1461,20 @@ def main(): if _args.prefill_tokens: generated = [int(x) for x in _args.prefill_tokens.split(',')] else: + # Official DeepSeek V4 encoding (from encoding/encoding_dsv4.py): + # <|User|>{msg}<|Assistant|><think> (thinking mode) + # <|User|>{msg}<|Assistant|></think> (chat mode — closes thinking immediately) + # NOTE: No \n\n between User token and content — the official spec has no separator. + # The \n\n was wrong and made the prompt out-of-distribution. input_ids = [bos, USER_TOKEN] - input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) + input_ids += tokenizer.encode(PROMPT, add_special_tokens=False) input_ids.append(ASSISTANT_TOKEN) - # DSV4 reasoning model: must prime with ◇ (think_start) after Assistant token. - # Without this, the model is out-of-distribution — it expects to be inside a - # thinking block but never received the think-start sentinel. - # Symptom: degenerate output from step 0 (e.g. "France" instead of "Paris", - # looping on newlines/repeated tokens). With ◇, the model generates thinking - # content, emits ◇ (think_end), then produces the actual answer. 
- input_ids.append(THINK_START) + if _args.chat_mode: + # Chat mode: close thinking block immediately → model generates content directly + input_ids.append(THINK_END) + else: + # Thinking mode: open thinking block → model reasons first, then answers + input_ids.append(THINK_START) generated = input_ids all_tokens = generated.copy() print(f"Input: {len(generated)} tokens")