#!/usr/bin/env python3
"""Pure-PyTorch reference for DeepSeek-V4 layer 0 on a single token.

Intended for comparison against our single-shot inference: one token is
embedded and pushed through layer 0 only (manifold hyper-connections,
attention, and the *shared* expert of the FFN -- routed experts are not
evaluated), and summary statistics of each stage are printed.

The math is meant to mirror the HuggingFace implementation; each stage is
commented with the formula it implements.

Usage (on B200):
    python3 tests/compare_layer0.py
"""
import json
import math
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

# Add kernel to path
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEVICE = "cuda:0"

# Checkpoint tensors this script needs: everything under layer 0 plus the
# global embedding / final-norm / LM-head weights.
_LAYER0_PREFIX = "model.layers.0."
_GLOBAL_KEYS = frozenset(
    {"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"}
)
# Shard count assumed when the checkpoint has no safetensors index.
_FALLBACK_NUM_SHARDS = 95

RMS_EPS = 1e-6  # epsilon used by every RMSNorm in this reference
HC_EPS = 1e-6  # epsilon used by the hyper-connection gates and Sinkhorn
SINKHORN_ITERS = 20  # 1 column-only pass + (SINKHORN_ITERS - 1) row/col rounds


def _is_wanted(key):
    """Return True for checkpoint keys this script loads."""
    return key.startswith(_LAYER0_PREFIX) or key in _GLOBAL_KEYS


def load_weights(checkpoint_dir=CHECKPOINT_DIR):
    """Load the layer-0 and global weights from a sharded safetensors checkpoint.

    If ``model.safetensors.index.json`` exists and has a ``weight_map``, only
    the shards it maps a wanted key to are opened; otherwise every
    ``model-XXXXX-of-00095.safetensors`` shard is scanned. Shard files that
    do not exist are skipped.

    Args:
        checkpoint_dir: directory holding the shards (defaults to CHECKPOINT_DIR).

    Returns:
        dict mapping tensor name -> tensor, as returned by ``load_file``
        (whose default placement is CPU).
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with index_path.open(encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})

    if weight_map:
        # Only the shards that actually hold a tensor we need.
        shard_names = {shard for key, shard in weight_map.items() if _is_wanted(key)}
    else:
        shard_names = {
            f"model-{i:05d}-of-{_FALLBACK_NUM_SHARDS:05d}.safetensors"
            for i in range(1, _FALLBACK_NUM_SHARDS + 1)
        }

    all_w = {}
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            continue
        data = load_file(str(shard_path))
        all_w.update({k: v for k, v in data.items() if _is_wanted(k)})
        del data  # release the unwanted tensors before reading the next shard
    return all_w


# =====================================================================
# FP4 dequant
# =====================================================================

# Magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
NVFP4_BLOCK_SIZE = 16  # input elements sharing one weight_scale entry


def _decode_fp4(nibbles, lut):
    """Decode 4-bit codes (bit 3 = sign, bits 0-2 = LUT index) to float32."""
    sign = torch.where((nibbles >> 3).bool(), -1.0, 1.0)
    return lut[(nibbles & 0x07).long()] * sign


def dequant_nvfp4(weight, weight_scale, weight_scale_2):
    """Dequantize an NVFP4-packed weight matrix to bfloat16.

    Args:
        weight: (out, in/2) packed codes, two per byte (uint8 assumed); the
            low nibble holds the even input column, the high nibble the odd one.
        weight_scale: (out, in/16) per-block scales (converted to float32).
        weight_scale_2: global scale, broadcast against ``weight_scale``.

    Returns:
        (out, in) bfloat16 tensor on ``weight``'s device.

    Raises:
        ValueError: if ``weight_scale`` does not have one column per 16 inputs.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    if weight_scale.shape[-1] * NVFP4_BLOCK_SIZE != in_features:
        raise ValueError(
            f"weight_scale has {weight_scale.shape[-1]} blocks per row, "
            f"expected {in_features // NVFP4_BLOCK_SIZE} for {in_features} inputs"
        )

    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    low_f = _decode_fp4(weight & 0x0F, lut)
    high_f = _decode_fp4(weight >> 4, lut)
    # Interleave low/high so column 2i comes from the low nibble of byte i.
    w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)

    scale_f = weight_scale.float() * weight_scale_2.float()
    scale_expanded = scale_f.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1)
    return (w_f * scale_expanded).bfloat16()


def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``x @ W.T`` with ``W`` dequantized from NVFP4.

    The quantized tensors are moved to ``x``'s device before dequantization
    (checkpoint tensors are typically on CPU while activations are on the
    GPU), and the dequantized weight is cast to ``x``'s dtype.
    """
    dev = x.device
    w = dequant_nvfp4(weight.to(dev), weight_scale.to(dev), weight_scale_2.to(dev))
    return F.linear(x, w.to(x.dtype))


# =====================================================================
# Reference building blocks
# =====================================================================

def _nvfp4_proj(x, w, prefix):
    """Apply the NVFP4 linear stored under ``{prefix}.weight{,_scale,_scale_2}``."""
    return nvfp4_linear(
        x,
        w[f"{prefix}.weight"],
        w[f"{prefix}.weight_scale"],
        w[f"{prefix}.weight_scale_2"],
    )


def _rms_norm(x, weight=None, eps=RMS_EPS):
    """RMS-normalize the last dim of ``x`` in float32, optionally scale by ``weight``.

    ``weight`` must already be on ``x``'s device. Returns bfloat16.
    """
    xf = x.float()
    out = xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        out = out * weight.float()
    return out.to(torch.bfloat16)


def _sinkhorn(logits, iters=SINKHORN_ITERS, eps=HC_EPS):
    """Softmax ``logits`` over the last dim, then alternately normalize columns/rows.

    One column normalization comes first, followed by ``iters - 1``
    (row, column) rounds.
    """
    m = torch.softmax(logits, dim=-1) + eps
    m = m / (m.sum(dim=-2, keepdim=True) + eps)
    for _ in range(iters - 1):
        m = m / (m.sum(dim=-1, keepdim=True) + eps)
        m = m / (m.sum(dim=-2, keepdim=True) + eps)
    return m


def _hc_pre(X, fn, base, scale, n_hc):
    """Hyper-connection pre-mix: collapse the ``n_hc`` residual streams into one input.

    Args:
        X: (batch, n_hc, H) residual streams.
        fn: (2*n_hc + n_hc**2, n_hc*H) mixing projection.
        base: (2*n_hc + n_hc**2,) biases for the [pre, post, comb] groups.
        scale: (3,) scalars for the [pre, post, comb] groups.
        n_hc: number of residual streams.

    Returns:
        (x_in, pre, post, comb):
        x_in -- (batch, H) bfloat16, ``sum_i pre[:, i] * X[:, i]``;
        pre  -- (batch, n_hc) gates ``sigmoid(.) + eps`` (A_l);
        post -- (batch, n_hc) output gates ``2 * sigmoid(.)`` (C_l);
        comb -- (batch, n_hc, n_hc) Sinkhorn-normalized stream mixing (B_l).

    Raises:
        ValueError: if ``fn`` does not have 2*n_hc + n_hc**2 rows.
    """
    device = X.device
    batch = X.shape[0]
    n_mix = 2 * n_hc + n_hc * n_hc
    if fn.shape[0] != n_mix:
        raise ValueError(f"hyper-connection fn has {fn.shape[0]} rows, expected {n_mix}")

    # Unweighted RMSNorm on the flattened residual, kept in float32. Both
    # F.linear operands are float32: a bf16 input with a float32 weight
    # raises a dtype-mismatch error.
    flat = X.reshape(batch, -1).float()
    flat = flat * flat.pow(2).mean(-1, keepdim=True).add(RMS_EPS).rsqrt()
    proj = F.linear(flat, fn.float().to(device))

    # base/scale come from the checkpoint (CPU); the 1-D bias must be moved.
    base = base.float().to(device)
    scale = scale.float().to(device)
    pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
    pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc])
    pre_s, post_s, comb_s = scale.unbind(0)

    pre = torch.sigmoid(pre_w * pre_s + pre_b) + HC_EPS
    post = 2.0 * torch.sigmoid(post_w * post_s + post_b)
    comb = _sinkhorn((comb_w * comb_s + comb_b).reshape(batch, n_hc, n_hc))

    x_in = (pre.unsqueeze(-1) * X.float()).sum(dim=1).to(torch.bfloat16)
    return x_in, pre, post, comb


def _hc_post(X, block_out, post, comb):
    """Hyper-connection post-mix: ``X_next = C_l * F(x) + B_l^T @ X``, as bfloat16."""
    carried = torch.bmm(comb.transpose(-1, -2), X.float())
    injected = post.unsqueeze(-1) * block_out.unsqueeze(1)
    return (injected.float() + carried).to(torch.bfloat16)


# =====================================================================
# Reference: pure PyTorch layer 0
# =====================================================================

def reference_layer0(embedding, w, cfg, device=None):
    """Process one token through layer 0 using pure PyTorch.

    Only the shared expert of the FFN is evaluated; the query projection is
    computed for its printed statistics only, because with a single token the
    attention output does not depend on it.

    Args:
        embedding: (1, hidden_size) token embedding.
        w: checkpoint tensors keyed by name (see ``load_weights``); may be on CPU.
        cfg: model config dict (``num_attention_heads``, ``head_dim``,
            ``hidden_size``, optional ``o_groups`` / ``o_group_dim``).
        device: device to compute on; defaults to ``embedding.device``.

    Returns:
        (1, n_hc, hidden_size) bfloat16 residual streams after layer 0.

    Raises:
        ValueError: if ``num_attention_heads`` is not divisible by ``o_groups``.
    """
    li = 0
    lp = f"model.layers.{li}"
    attn = f"{lp}.self_attn"
    device = embedding.device if device is None else torch.device(device)

    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    H = cfg["hidden_size"]
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_group_dim", 1024)
    n_hc = 4  # number of hyper-connection residual streams
    if n_h % o_groups:
        raise ValueError(f"num_attention_heads={n_h} not divisible by o_groups={o_groups}")
    heads_per_group = n_h // o_groups

    def param(name):
        """Fetch a checkpoint tensor on the compute device."""
        return w[name].to(device)

    # Init mHC state: every stream starts as the token embedding.
    X = embedding.to(device).unsqueeze(1).expand(-1, n_hc, -1).clone()  # (1, n_hc, H)

    # ============ mHC (attention) ============
    x_in, A_l, C_l, B_l = _hc_pre(
        X, w[f"{lp}.attn_hc.fn"], w[f"{lp}.attn_hc.base"], w[f"{lp}.attn_hc.scale"], n_hc
    )
    print(f" A_l: {A_l[0].tolist()}")
    print(f" C_l: {C_l[0].tolist()}")
    print(f" B row sums: {B_l[0].sum(dim=-1).tolist()}")
    print(f" B col sums: {B_l[0].sum(dim=-2).tolist()}")

    # ============ RMSNorm ============
    x_normed = _rms_norm(x_in, param(f"{lp}.input_layernorm.weight"))

    # ============ Q projection (diagnostic only) ============
    c_Q = _rms_norm(
        _nvfp4_proj(x_normed, w, f"{attn}.q_a_proj"), param(f"{attn}.q_a_norm.weight")
    )
    q = _rms_norm(_nvfp4_proj(c_Q, w, f"{attn}.q_b_proj"))  # unweighted q_b_norm

    # ============ KV projection ============
    kv = _rms_norm(
        _nvfp4_proj(x_normed, w, f"{attn}.kv_proj"), param(f"{attn}.kv_norm.weight")
    )
    print(f" |c_Q|={c_Q.abs().max().item():.4f} |q|={q.abs().max().item():.4f} |kv|={kv.abs().max().item():.4f}")

    # ============ Attention (self-attention for single token) ============
    # With one key the softmax weight is 1 on self, and V = K (single shared
    # KV head), so every head outputs the normalized kv vector. At position 0
    # RoPE (cos=1, sin=0) and its inverse are the identity.
    attn_out = kv.reshape(1, 1, hd).expand(1, n_h, hd)  # (1, n_h, hd)

    # ============ Output projection ============
    # Grouped low-rank o_a (one (o_rank, heads_per_group*hd) matrix per group),
    # then the NVFP4 o_b projection back to hidden size.
    attn_grouped = attn_out.reshape(1, o_groups, heads_per_group * hd)
    oa_3d = (
        w[f"{attn}.o_a_proj.weight"]
        .to(device=device, dtype=torch.bfloat16)
        .reshape(o_groups, o_rank, heads_per_group * hd)
    )
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = _nvfp4_proj(grouped_flat, w, f"{attn}.o_b_proj")
    print(f" |F_attn|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}")

    # ============ mHC post_block ============
    # X_next = C_l * F_attn + B_l.T @ X
    X_mid = _hc_post(X, F_attn, C_l, B_l)
    print(f" |X_mid|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:, 0, :].float().abs().mean().item():.6f}")

    # ============ FFN (shared expert only for simplicity) ============
    x_ffn, _, post_f, comb_f = _hc_pre(
        X_mid, w[f"{lp}.ffn_hc.fn"], w[f"{lp}.ffn_hc.base"], w[f"{lp}.ffn_hc.scale"], n_hc
    )
    x_ffn_n = _rms_norm(x_ffn, param(f"{lp}.post_attention_layernorm.weight"))

    se_pre = f"{lp}.mlp.shared_experts"
    gate = _nvfp4_proj(x_ffn_n, w, f"{se_pre}.gate_proj")
    up = _nvfp4_proj(x_ffn_n, w, f"{se_pre}.up_proj")
    hidden = (F.silu(gate.float()) * up.float()).bfloat16()
    shared_out = _nvfp4_proj(hidden, w, f"{se_pre}.down_proj")

    # mHC post (FFN)
    X_next = _hc_post(X_mid, shared_out, post_f, comb_f)
    print(f" |X_next|={X_next.abs().max().item():.4f}")
    return X_next


def main():
    """Run the layer-0 reference on the token "The" and print its statistics."""
    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    print("Loading weights...")
    w = load_weights()

    # Embed "The"
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    # Last id: presumably skips a BOS token prepended by encode().
    tid = tokenizer.encode("The")[-1]
    embed_w = w["model.embed_tokens.weight"].to(device=DEVICE, dtype=torch.bfloat16)
    embed = F.embedding(torch.tensor([tid], device=DEVICE), embed_w)
    print(f"\nProcessing 'The' (id={tid}) through layer 0:")
    X_out = reference_layer0(embed, w, cfg, device=DEVICE)
    print(f"\nOutput: |X|={X_out.abs().max().item():.4f}")


if __name__ == "__main__":
    main()