From 98fa4101670196da65a2de38f434ce6f491e99bf Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 20:11:37 +0000 Subject: [PATCH] Add HF reference test script --- hf_reference_test.py | 28 +++++ tests/compare_layer0.py | 250 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 278 insertions(+) create mode 100644 hf_reference_test.py create mode 100644 tests/compare_layer0.py diff --git a/hf_reference_test.py b/hf_reference_test.py new file mode 100644 index 00000000..5b1a6bfd --- /dev/null +++ b/hf_reference_test.py @@ -0,0 +1,28 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = '/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4' +print('Loading tokenizer...', flush=True) +tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) +print('Loading model...', flush=True) +model = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype=torch.bfloat16, + device_map='auto', trust_remote_code=True, low_cpu_mem_usage=True +) +model.eval() +print('Model loaded!', flush=True) + +msg = [{'role':'user','content':'The capital of France is'}] +ids = tokenizer.apply_chat_template(msg, add_generation_prompt=True, return_tensors='pt').cuda() +print(f'Input: {ids.shape} tokens: {repr(tokenizer.decode(ids[0]))}', flush=True) + +with torch.no_grad(): + logits = model(ids).logits[0, -1] + top10 = torch.topk(logits, 10) + print('HF Top-10:', flush=True) + for i, (tid, val) in enumerate(zip(top10.indices, top10.values)): + print(f' {i+1}. {repr(tokenizer.decode([tid.item()]))} (id={tid.item()}, logit={val.item():.3f})', flush=True) + + # Generate 10 tokens + out = model.generate(ids, max_new_tokens=10, do_sample=False) + print(f'Generated: {repr(tokenizer.decode(out[0]))}', flush=True) diff --git a/tests/compare_layer0.py b/tests/compare_layer0.py new file mode 100644 index 00000000..2f5ea16d --- /dev/null +++ b/tests/compare_layer0.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +"""Compare our single-shot inference with the HuggingFace reference for layer 0. + +This script processes a single token through just layer 0 and compares the output +with a pure PyTorch reference implementation that matches the HF model exactly. + +Usage (on B200): + python3 tests/compare_layer0.py +""" +import os, sys, json, math, torch +from pathlib import Path + +# Add kernel to path +sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel') + +CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" +DEVICE = "cuda:0" + +def load_weights(): + from safetensors.torch import load_file + cdir = Path(CHECKPOINT_DIR) + index_path = cdir / "model.safetensors.index.json" + weight_map = {} + if index_path.exists(): + with open(index_path) as f: + weight_map = json.load(f).get("weight_map", {}) + shard_names = set(weight_map.values()) if weight_map else { + f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96) + } + all_w = {} + for shard_name in sorted(shard_names): + if not (cdir / shard_name).exists(): + continue + data = load_file(str(cdir / shard_name)) + for k, v in data.items(): + if k.startswith("model.layers.0.") or k in ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]: + all_w[k] = v + return all_w + +# ===================================================================== +# FP4 dequant +# ===================================================================== +FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) + +def dequant_nvfp4(weight, weight_scale, weight_scale_2): + out_dim = weight.shape[0] + in_packed = weight.shape[1] + in_features = in_packed * 2 + low = (weight & 0x0F).to(torch.int8) + high = (weight >> 4).to(torch.int8) + low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long() + high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long() + lut = FP4_LUT.to(device=weight.device, dtype=torch.float32) + low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0) + high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0) + w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features) + scale_f = weight_scale.float() * weight_scale_2.float() + scale_expanded = scale_f.repeat_interleave(16, dim=1) + return (w_f * scale_expanded).bfloat16() + +def nvfp4_linear(x, weight, weight_scale, weight_scale_2): + w = dequant_nvfp4(weight, weight_scale, weight_scale_2) + return torch.nn.functional.linear(x, w) + +# ===================================================================== +# Reference: pure PyTorch layer 0 +# ===================================================================== +def reference_layer0(embedding, w, cfg): + """Process one token through layer 0 using pure PyTorch (matching HF).""" + li = 0 + pre = f"model.layers.{li}.self_attn" + n_h = cfg["num_attention_heads"] # 128 + hd = cfg["head_dim"] # 512 + rd = cfg.get("qk_rope_head_dim", 64) # 64 + H = cfg["hidden_size"] # 7168 + o_groups = cfg.get("o_groups", 16) + o_rank = cfg.get("o_group_dim", 1024) + n_hc = 4 + heads_per_group = n_h // o_groups + + # Init mHC state + X = embedding.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, 4, H) + + # ============ mHC (attention) ============ + # Match HF DeepseekV4HyperConnection.forward + fn = w[f"model.layers.{li}.attn_hc.fn"] # (24, 28672) + base = w[f"model.layers.{li}.attn_hc.base"] # (24,) + scale = w[f"model.layers.{li}.attn_hc.scale"] # (3,) + + # Unweighted RMSNorm on flattened residual + X_flat = X.reshape(1, n_hc * H).float() + rms_inv = X_flat.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + flat = X_flat * rms_inv + + # F.linear(flat, fn) and split [pre(4), post(4), comb(16)] + proj = torch.nn.functional.linear(flat.to(torch.bfloat16), fn.float().to(DEVICE)).float() + pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1) + + # Apply scale and bias + pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc]) + pre_scale, post_scale, comb_scale = scale.unbind(0) + + pre_vals = torch.sigmoid(pre_w * pre_scale + pre_b) + 1e-6 # A_l + post_vals = 2.0 * torch.sigmoid(post_w * post_scale + post_b) # C_l + + # Sinkhorn on comb + comb_logits = (comb_w * comb_scale + comb_b).reshape(1, n_hc, n_hc) + comb = torch.softmax(comb_logits, dim=-1) + 1e-6 + comb = comb / (comb.sum(dim=-2, keepdim=True) + 1e-6) + for _ in range(19): # 20 total + comb = comb / (comb.sum(dim=-1, keepdim=True) + 1e-6) + comb = comb / (comb.sum(dim=-2, keepdim=True) + 1e-6) + + # collapsed = (pre * streams).sum(dim=streams) + x_in = (pre_vals.unsqueeze(-1) * X.float()).sum(dim=1).to(torch.bfloat16) # (1, H) + B_l = comb # (1, 4, 4) + C_l = post_vals # (1, 4) + + print(f" A_l: {pre_vals[0].tolist()}") + print(f" C_l: {C_l[0].tolist()}") + print(f" B row sums: {B_l[0].sum(dim=-1).tolist()}") + print(f" B col sums: {B_l[0].sum(dim=-2).tolist()}") + + # ============ RMSNorm ============ + x_normed = x_in.float() + rms_inv = x_normed.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + norm_w = w[f"model.layers.{li}.input_layernorm.weight"].to(DEVICE).float() + x_normed = (x_normed * rms_inv * norm_w).to(torch.bfloat16) + + # ============ Q projection ============ + c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"]) + # q_a_norm (weighted RMSNorm) + q_norm_w = w[f"{pre}.q_a_norm.weight"].to(DEVICE).float() + c_Q_f = c_Q.float() + c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + c_Q = (c_Q_f * c_Q_rms * q_norm_w).bfloat16() + + q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"]) + # q_b_norm (unweighted RMSNorm) + q_f = q.float() + q_rms = q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + q = (q_f * q_rms).bfloat16() + + # ============ KV projection ============ + kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"]) + kv_norm_w = w[f"{pre}.kv_norm.weight"].to(DEVICE).float() + kv_f = kv.float() + kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + kv = (kv_f * kv_rms * kv_norm_w).bfloat16() + + print(f" |c_Q|={c_Q.abs().max().item():.4f} |q|={q.abs().max().item():.4f} |kv|={kv.abs().max().item():.4f}") + + # ============ Attention (self-attention for single token) ============ + q_heads = q.reshape(1, n_h, hd) # (1, n_h, hd) + kv_heads = kv.reshape(1, 1, hd) # (1, 1, hd) — 1 KV head + + # For single token, self-attention is trivially identity (weight=1 on self) + # V = K (DSV4 MQA), so attn_out = V = K for single token + attn_out = kv_heads.expand(1, n_h, hd) # (1, n_h, hd) — just V + + # Inverse RoPE would be applied here, but for single token with no RoPE (position 0, cos=1, sin=0), + # RoPE is identity and inverse RoPE is also identity. + + # ============ Output projection ============ + attn_flat = attn_out.reshape(1, n_h * hd) + attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd) + oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16().to(DEVICE) + oa_3d = oa_w.reshape(o_groups, o_rank, heads_per_group * hd) + attn_for_bmm = attn_grouped.permute(1, 0, 2) + grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) + grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank) + F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"]) + + print(f" |F_attn|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}") + + # ============ mHC post_block ============ + # X_next = C_l * F_attn + B_l.T @ X + BX = torch.bmm(B_l.transpose(-1, -2), X.float()) + CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1) + X_mid = (CF.float() + BX).to(torch.bfloat16) + + print(f" |X_mid|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}") + + # ============ FFN (shared expert only for simplicity) ============ + # FFN mHC + fn_ffn = w[f"model.layers.{li}.ffn_hc.fn"] + base_ffn = w[f"model.layers.{li}.ffn_hc.base"] + scale_ffn = w[f"model.layers.{li}.ffn_hc.scale"] + + X_flat_ffn = X_mid.reshape(1, n_hc * H).float() + rms_inv_ffn = X_flat_ffn.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + flat_ffn = X_flat_ffn * rms_inv_ffn + + proj_ffn = torch.nn.functional.linear(flat_ffn.to(torch.bfloat16), fn_ffn.float().to(DEVICE)).float() + pre_w_f, post_w_f, comb_w_f = proj_ffn.split([n_hc, n_hc, n_hc * n_hc], dim=-1) + + pre_b_f, post_b_f, comb_b_f = base_ffn.split([n_hc, n_hc, n_hc * n_hc]) + pre_s_f, post_s_f, comb_s_f = scale_ffn.unbind(0) + + pre_vals_f = torch.sigmoid(pre_w_f * pre_s_f + pre_b_f) + 1e-6 + post_vals_f = 2.0 * torch.sigmoid(post_w_f * post_s_f + post_b_f) + + comb_logits_f = (comb_w_f * comb_s_f + comb_b_f).reshape(1, n_hc, n_hc) + comb_f = torch.softmax(comb_logits_f, dim=-1) + 1e-6 + comb_f = comb_f / (comb_f.sum(dim=-2, keepdim=True) + 1e-6) + for _ in range(19): + comb_f = comb_f / (comb_f.sum(dim=-1, keepdim=True) + 1e-6) + comb_f = comb_f / (comb_f.sum(dim=-2, keepdim=True) + 1e-6) + + x_ffn = (pre_vals_f.unsqueeze(-1) * X_mid.float()).sum(dim=1).to(torch.bfloat16) + + # FFN RMSNorm + norm_w_ffn = w[f"model.layers.{li}.post_attention_layernorm.weight"].to(DEVICE).float() + x_ffn_n = x_ffn.float() + rms_ffn = x_ffn_n.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() + x_ffn_n = (x_ffin_n * rms_ffn * norm_w_ffn).to(torch.bfloat16) + + # Shared expert + se_pre = f"model.layers.{li}.mlp.shared_experts" + gate = nvfp4_linear(x_ffn_n, w[f"{se_pre}.gate_proj.weight"], w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"]) + up = nvfp4_linear(x_ffn_n, w[f"{se_pre}.up_proj.weight"], w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"]) + hidden = (torch.nn.functional.silu(gate.float()) * up.float()).bfloat16() + shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"], w[f"{se_pre}.down_proj.weight_scale"], w[f"{se_pre}.down_proj.weight_scale_2"]) + + # mHC post (FFN) + BX_f = torch.bmm(comb_f.transpose(-1, -2), X_mid.float()) + CF_f = post_vals_f.unsqueeze(-1) * shared_out.unsqueeze(1) + X_next = (CF_f.float() + BX_f).to(torch.bfloat16) + + print(f" |X_next|={X_next.abs().max().item():.4f}") + return X_next + + +if __name__ == "__main__": + with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: + cfg = json.load(f) + + print("Loading weights...") + w = load_weights() + + # Embed "The" + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) + tid = tokenizer.encode("The")[-1] + embed_w = w["model.embed_tokens.weight"].bfloat16().to(DEVICE) + embed = torch.nn.functional.embedding(torch.tensor([tid], device=DEVICE), embed_w) + + print(f"\nProcessing 'The' (id={tid}) through layer 0:") + X_out = reference_layer0(embed, w, cfg) + print(f"\nOutput: |X|={X_out.abs().max().item():.4f}")