From cfea22cd6fd3ac7e1143706a8aae884e0ea135df Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 12:19:38 +0000 Subject: [PATCH] Update PyTorch reference with official DSV4 encoding + batched prefill MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Only template/tokenizer/parser changes — no SwiGLU or compressor fixes. Official encoding module, chat mode, batched prefill, stop set, official parser for structured output. --- .../single_shot_PYTORCH_REFERENCE.py | 89 ++++++++++++++----- 1 file changed, 69 insertions(+), 20 deletions(-) diff --git a/dsv4/reference/single_shot_PYTORCH_REFERENCE.py b/dsv4/reference/single_shot_PYTORCH_REFERENCE.py index 64ccad80..c4e5b6bb 100644 --- a/dsv4/reference/single_shot_PYTORCH_REFERENCE.py +++ b/dsv4/reference/single_shot_PYTORCH_REFERENCE.py @@ -68,8 +68,19 @@ SEED = _args.seed VERBOSE = _args.verbose GROWTH_DIAG = VERBOSE >= 1 -THINK_START, THINK_END = 128821, 128822 -USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804 +# Derive special tokens from official encoding module (set after tokenizer init) +from encoding.deepseek_v4_encoding import ( + thinking_start_token as _THINK_START_STR, + thinking_end_token as _THINK_END_STR, + USER_SP_TOKEN as _USER_STR, + ASSISTANT_SP_TOKEN as _ASSISTANT_STR, + eos_token as _EOS_STR, + bos_token as _BOS_STR, +) +THINK_START = None +THINK_END = None +USER_TOKEN = None +ASSISTANT_TOKEN = None # ===================================================================== # NVFP4 dequantization — two-level scale @@ -748,20 +759,40 @@ def main(): from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) + # Set special token IDs from official encoding + global THINK_START, THINK_END, USER_TOKEN, ASSISTANT_TOKEN + THINK_START = tokenizer.convert_tokens_to_ids(_THINK_START_STR) + THINK_END = tokenizer.convert_tokens_to_ids(_THINK_END_STR) + USER_TOKEN = tokenizer.convert_tokens_to_ids(_USER_STR) + ASSISTANT_TOKEN = tokenizer.convert_tokens_to_ids(_ASSISTANT_STR) bos = tokenizer.bos_token_id or 0 - input_ids = [bos, USER_TOKEN] - input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) - input_ids.append(ASSISTANT_TOKEN) - generated = input_ids.copy() - print(f"Input: {len(generated)} tokens") + eos_id = tokenizer.eos_token_id + print(f" Special tokens: THINK_START={THINK_START} THINK_END={THINK_END} USER={USER_TOKEN} ASST={ASSISTANT_TOKEN} BOS={bos} EOS={eos_id}") - # Prefill - print(f"Prefilling {len(generated)} tokens...") - for pi, tid_val in enumerate(generated): + # Official DeepSeek V4 encoding + from encoding.deepseek_v4_encoding import encode_messages + messages = [{"role": "user", "content": PROMPT}] + encoded_str = encode_messages(messages, thinking_mode='chat') + generated = tokenizer.encode(encoded_str, add_special_tokens=False) + if generated[0] != bos: + generated = [bos] + generated + print(f"Input: {len(generated)} tokens (chat mode, official encoding)") + + # Batched prefill — process tokens in chunks of up to 128 + PREFILL_CHUNK = 128 + n_prefill = len(generated) + print(f"Prefilling {n_prefill} tokens, chunk_size={PREFILL_CHUNK}...") + prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0') + all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0') + + for cs in range(0, n_prefill, PREFILL_CHUNK): + ce = min(cs + PREFILL_CHUNK, n_prefill) + chunk_len = ce - cs t1 = time.time() - tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') - pos = torch.tensor([pi], dtype=torch.long, device='cuda:0') - X = mHCBlock.init_state(embed(tid)) + chunk_ids = prefill_ids[cs:ce] + chunk_pos = all_positions[cs:ce] + chunk_embed = embed(chunk_ids) + X = mHCBlock.init_state(chunk_embed) for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -769,10 +800,10 @@ def main(): X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], attn_mhcs.get(li), ffn_mhcs.get(li), attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], pos, tid, + kv_caches[li], chunk_pos, chunk_ids, compressors.get(li), indexers.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) - if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) + print(f" Chunk tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True) print(f" Prefill done ({time.time()-t0:.1f}s)") # Decode @@ -808,12 +839,30 @@ def main(): f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] " f"nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True) if has_nan: break - if next_id == tokenizer.eos_token_id: break + STOP_IDS = set() + if eos_id is not None: STOP_IDS.add(eos_id) + STOP_IDS.add(USER_TOKEN) + if next_id in STOP_IDS: + print(f" STOP ({next_id}) at step {step}", flush=True) + break - out = tokenizer.decode(all_tokens, skip_special_tokens=True) - print(f"\n{'='*70}") - print(f"Input: '{PROMPT}'") - print(f"Output: '{out}'") + # Parse with official DSV4 parser + out_raw = tokenizer.decode(all_tokens, skip_special_tokens=False) + try: + from encoding.deepseek_v4_encoding import parse_message_from_completion_text + assistant_start = out_raw.find(_ASSISTANT_STR) + assistant_text = out_raw[assistant_start + len(_ASSISTANT_STR):] if assistant_start >= 0 else out_raw + parsed = parse_message_from_completion_text(assistant_text, thinking_mode='chat') + content = parsed.get('content', '') + print(f"\n{'='*70}") + print(f"Input: '{PROMPT}'") + print(f"Content: {content}") + except Exception as e: + out = tokenizer.decode(all_tokens, skip_special_tokens=True) + print(f"\n{'='*70}") + print(f"Input: '{PROMPT}'") + print(f"Output (raw): '{out}'") + print(f"Parse error: {e}") print(f"Total: {time.time()-t0:.1f}s") print(f"{'='*70}")