Update PyTorch reference with official DSV4 encoding + batched prefill

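Scope is limited to template/tokenizer/parser changes — this commit contains
no SwiGLU or compressor fixes. It switches prompt construction to the official
encoding module (chat mode), adds batched prefill, adds a stop-token set, and
uses the official parser for structured output.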
This commit is contained in:
2026-06-03 12:19:38 +00:00
parent bdd9ab9669
commit cfea22cd6f

View File

@@ -68,8 +68,19 @@ SEED = _args.seed
VERBOSE = _args.verbose
GROWTH_DIAG = VERBOSE >= 1
THINK_START, THINK_END = 128821, 128822
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
# Special-token strings come from the official encoding module; their integer IDs
# are resolved in main() once the tokenizer is loaded (they stay None until then).
from encoding.deepseek_v4_encoding import (
thinking_start_token as _THINK_START_STR,
thinking_end_token as _THINK_END_STR,
USER_SP_TOKEN as _USER_STR,
ASSISTANT_SP_TOKEN as _ASSISTANT_STR,
eos_token as _EOS_STR,
bos_token as _BOS_STR,
)
THINK_START = None
THINK_END = None
USER_TOKEN = None
ASSISTANT_TOKEN = None
# =====================================================================
# NVFP4 dequantization — two-level scale
@@ -748,20 +759,40 @@ def main():
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
# Resolve special-token IDs by looking up the official encoding's token strings in the tokenizer
global THINK_START, THINK_END, USER_TOKEN, ASSISTANT_TOKEN
THINK_START = tokenizer.convert_tokens_to_ids(_THINK_START_STR)
THINK_END = tokenizer.convert_tokens_to_ids(_THINK_END_STR)
USER_TOKEN = tokenizer.convert_tokens_to_ids(_USER_STR)
ASSISTANT_TOKEN = tokenizer.convert_tokens_to_ids(_ASSISTANT_STR)
bos = tokenizer.bos_token_id or 0
input_ids = [bos, USER_TOKEN]
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
input_ids.append(ASSISTANT_TOKEN)
generated = input_ids.copy()
print(f"Input: {len(generated)} tokens")
eos_id = tokenizer.eos_token_id
print(f" Special tokens: THINK_START={THINK_START} THINK_END={THINK_END} USER={USER_TOKEN} ASST={ASSISTANT_TOKEN} BOS={bos} EOS={eos_id}")
# Prefill
print(f"Prefilling {len(generated)} tokens...")
for pi, tid_val in enumerate(generated):
# Official DeepSeek V4 encoding
from encoding.deepseek_v4_encoding import encode_messages
messages = [{"role": "user", "content": PROMPT}]
encoded_str = encode_messages(messages, thinking_mode='chat')
generated = tokenizer.encode(encoded_str, add_special_tokens=False)
if generated[0] != bos:
generated = [bos] + generated
print(f"Input: {len(generated)} tokens (chat mode, official encoding)")
# Batched prefill — process tokens in chunks of up to 128
PREFILL_CHUNK = 128
n_prefill = len(generated)
print(f"Prefilling {n_prefill} tokens, chunk_size={PREFILL_CHUNK}...")
prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0')
all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0')
for cs in range(0, n_prefill, PREFILL_CHUNK):
ce = min(cs + PREFILL_CHUNK, n_prefill)
chunk_len = ce - cs
t1 = time.time()
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
pos = torch.tensor([pi], dtype=torch.long, device='cuda:0')
X = mHCBlock.init_state(embed(tid))
chunk_ids = prefill_ids[cs:ce]
chunk_pos = all_positions[cs:ce]
chunk_embed = embed(chunk_ids)
X = mHCBlock.init_state(chunk_embed)
for li in range(n_layers):
gpu = li % NUM_GPUS
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
@@ -769,10 +800,10 @@ def main():
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
attn_mhcs.get(li), ffn_mhcs.get(li),
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], pos, tid,
kv_caches[li], chunk_pos, chunk_ids,
compressors.get(li), indexers.get(li))
X = X.to('cuda:0'); torch.cuda.set_device(0)
if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
print(f" Chunk tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True)
print(f" Prefill done ({time.time()-t0:.1f}s)")
# Decode
@@ -808,12 +839,30 @@ def main():
f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] "
f"nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True)
if has_nan: break
if next_id == tokenizer.eos_token_id: break
STOP_IDS = set()
if eos_id is not None: STOP_IDS.add(eos_id)
STOP_IDS.add(USER_TOKEN)
if next_id in STOP_IDS:
print(f" STOP ({next_id}) at step {step}", flush=True)
break
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Output: '{out}'")
# Parse with official DSV4 parser
out_raw = tokenizer.decode(all_tokens, skip_special_tokens=False)
try:
from encoding.deepseek_v4_encoding import parse_message_from_completion_text
assistant_start = out_raw.find(_ASSISTANT_STR)
assistant_text = out_raw[assistant_start + len(_ASSISTANT_STR):] if assistant_start >= 0 else out_raw
parsed = parse_message_from_completion_text(assistant_text, thinking_mode='chat')
content = parsed.get('content', '')
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Content: {content}")
except Exception as e:
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
print(f"\n{'='*70}")
print(f"Input: '{PROMPT}'")
print(f"Output (raw): '{out}'")
print(f"Parse error: {e}")
print(f"Total: {time.time()-t0:.1f}s")
print(f"{'='*70}")