Update PyTorch reference with official DSV4 encoding + batched prefill
Only template/tokenizer/parser changes — no SwiGLU or compressor fixes. Official encoding module, chat mode, batched prefill, stop set, official parser for structured output.
This commit is contained in:
@@ -68,8 +68,19 @@ SEED = _args.seed
|
||||
VERBOSE = _args.verbose
|
||||
GROWTH_DIAG = VERBOSE >= 1
|
||||
|
||||
THINK_START, THINK_END = 128821, 128822
|
||||
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804
|
||||
# Derive special tokens from official encoding module (set after tokenizer init)
|
||||
from encoding.deepseek_v4_encoding import (
|
||||
thinking_start_token as _THINK_START_STR,
|
||||
thinking_end_token as _THINK_END_STR,
|
||||
USER_SP_TOKEN as _USER_STR,
|
||||
ASSISTANT_SP_TOKEN as _ASSISTANT_STR,
|
||||
eos_token as _EOS_STR,
|
||||
bos_token as _BOS_STR,
|
||||
)
|
||||
THINK_START = None
|
||||
THINK_END = None
|
||||
USER_TOKEN = None
|
||||
ASSISTANT_TOKEN = None
|
||||
|
||||
# =====================================================================
|
||||
# NVFP4 dequantization — two-level scale
|
||||
@@ -748,20 +759,40 @@ def main():
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
|
||||
# Set special token IDs from official encoding
|
||||
global THINK_START, THINK_END, USER_TOKEN, ASSISTANT_TOKEN
|
||||
THINK_START = tokenizer.convert_tokens_to_ids(_THINK_START_STR)
|
||||
THINK_END = tokenizer.convert_tokens_to_ids(_THINK_END_STR)
|
||||
USER_TOKEN = tokenizer.convert_tokens_to_ids(_USER_STR)
|
||||
ASSISTANT_TOKEN = tokenizer.convert_tokens_to_ids(_ASSISTANT_STR)
|
||||
bos = tokenizer.bos_token_id or 0
|
||||
input_ids = [bos, USER_TOKEN]
|
||||
input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
|
||||
input_ids.append(ASSISTANT_TOKEN)
|
||||
generated = input_ids.copy()
|
||||
print(f"Input: {len(generated)} tokens")
|
||||
eos_id = tokenizer.eos_token_id
|
||||
print(f" Special tokens: THINK_START={THINK_START} THINK_END={THINK_END} USER={USER_TOKEN} ASST={ASSISTANT_TOKEN} BOS={bos} EOS={eos_id}")
|
||||
|
||||
# Prefill
|
||||
print(f"Prefilling {len(generated)} tokens...")
|
||||
for pi, tid_val in enumerate(generated):
|
||||
# Official DeepSeek V4 encoding
|
||||
from encoding.deepseek_v4_encoding import encode_messages
|
||||
messages = [{"role": "user", "content": PROMPT}]
|
||||
encoded_str = encode_messages(messages, thinking_mode='chat')
|
||||
generated = tokenizer.encode(encoded_str, add_special_tokens=False)
|
||||
if generated[0] != bos:
|
||||
generated = [bos] + generated
|
||||
print(f"Input: {len(generated)} tokens (chat mode, official encoding)")
|
||||
|
||||
# Batched prefill — process tokens in chunks of up to 128
|
||||
PREFILL_CHUNK = 128
|
||||
n_prefill = len(generated)
|
||||
print(f"Prefilling {n_prefill} tokens, chunk_size={PREFILL_CHUNK}...")
|
||||
prefill_ids = torch.tensor(generated, dtype=torch.long, device='cuda:0')
|
||||
all_positions = torch.arange(n_prefill, dtype=torch.long, device='cuda:0')
|
||||
|
||||
for cs in range(0, n_prefill, PREFILL_CHUNK):
|
||||
ce = min(cs + PREFILL_CHUNK, n_prefill)
|
||||
chunk_len = ce - cs
|
||||
t1 = time.time()
|
||||
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
|
||||
pos = torch.tensor([pi], dtype=torch.long, device='cuda:0')
|
||||
X = mHCBlock.init_state(embed(tid))
|
||||
chunk_ids = prefill_ids[cs:ce]
|
||||
chunk_pos = all_positions[cs:ce]
|
||||
chunk_embed = embed(chunk_ids)
|
||||
X = mHCBlock.init_state(chunk_embed)
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
|
||||
@@ -769,10 +800,10 @@ def main():
|
||||
X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
|
||||
attn_mhcs.get(li), ffn_mhcs.get(li),
|
||||
attn_norms.get(li), ffn_norms.get(li),
|
||||
kv_caches[li], pos, tid,
|
||||
kv_caches[li], chunk_pos, chunk_ids,
|
||||
compressors.get(li), indexers.get(li))
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
|
||||
print(f" Chunk tokens {cs}-{ce-1} ({chunk_len} tok): {time.time()-t1:.2f}s", flush=True)
|
||||
print(f" Prefill done ({time.time()-t0:.1f}s)")
|
||||
|
||||
# Decode
|
||||
@@ -808,12 +839,30 @@ def main():
|
||||
f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] "
|
||||
f"nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True)
|
||||
if has_nan: break
|
||||
if next_id == tokenizer.eos_token_id: break
|
||||
STOP_IDS = set()
|
||||
if eos_id is not None: STOP_IDS.add(eos_id)
|
||||
STOP_IDS.add(USER_TOKEN)
|
||||
if next_id in STOP_IDS:
|
||||
print(f" STOP ({next_id}) at step {step}", flush=True)
|
||||
break
|
||||
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output: '{out}'")
|
||||
# Parse with official DSV4 parser
|
||||
out_raw = tokenizer.decode(all_tokens, skip_special_tokens=False)
|
||||
try:
|
||||
from encoding.deepseek_v4_encoding import parse_message_from_completion_text
|
||||
assistant_start = out_raw.find(_ASSISTANT_STR)
|
||||
assistant_text = out_raw[assistant_start + len(_ASSISTANT_STR):] if assistant_start >= 0 else out_raw
|
||||
parsed = parse_message_from_completion_text(assistant_text, thinking_mode='chat')
|
||||
content = parsed.get('content', '')
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Content: {content}")
|
||||
except Exception as e:
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
print(f"Output (raw): '{out}'")
|
||||
print(f"Parse error: {e}")
|
||||
print(f"Total: {time.time()-t0:.1f}s")
|
||||
print(f"{'='*70}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user