Files
nvfp4-megamoe-kernel/tests/compare_layer0.py

251 lines
11 KiB
Python

#!/usr/bin/env python3
"""Layer-0 reference for DeepSeek-V4 (NVFP4 checkpoint), one token at a time.

Intended for comparing our single-shot inference against a pure-PyTorch
reference that follows the HuggingFace layer structure. As written, the
``__main__`` section only runs the reference path for the token "The" and
prints summary statistics; no kernel output is compared in this file.

Known simplification in the reference: the FFN uses only the shared expert
(routed experts are skipped), so the result is not an exact match of the
full HF layer.

Configuration can be overridden with environment variables:
    DSV4_KERNEL_DIR      kernel checkout prepended to sys.path
    DSV4_CHECKPOINT_DIR  NVFP4 checkpoint directory (config, tokenizer, shards)
    DSV4_DEVICE          torch device used for the computation (default cuda:0)

Usage (on B200):
    python3 tests/compare_layer0.py
"""
import os
import sys
import json
import math
import torch
from pathlib import Path

# Make the kernel checkout importable (overridable via DSV4_KERNEL_DIR).
KERNEL_DIR = os.environ.get("DSV4_KERNEL_DIR", "/root/dsv4-nvfp4-workspace/kernel")
sys.path.insert(0, KERNEL_DIR)
CHECKPOINT_DIR = os.environ.get(
    "DSV4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)
DEVICE = os.environ.get("DSV4_DEVICE", "cuda:0")
def load_weights(layer=0):
    """Load the checkpoint tensors needed for a single-layer reference run.

    Returns a dict mapping tensor name -> CPU tensor, containing every tensor
    whose name starts with ``model.layers.{layer}.`` plus
    ``model.embed_tokens.weight``, ``model.norm.weight`` and
    ``lm_head.weight``, read from ``CHECKPOINT_DIR``.

    If ``model.safetensors.index.json`` exists and has a ``weight_map``, only
    the shards that contain wanted keys are opened. Otherwise the default
    95 shard names are scanned. Shard files that do not exist are skipped.
    Tensors are read individually via ``safe_open``, so unrelated tensors in
    a shard are never materialized.

    Args:
        layer: decoder layer index whose tensors are loaded (default 0).
    """
    from safetensors import safe_open

    cdir = Path(CHECKPOINT_DIR)
    prefix = f"model.layers.{layer}."
    extra_keys = {"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"}

    def wanted(key):
        return key.startswith(prefix) or key in extra_keys

    weight_map = {}
    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})

    if weight_map:
        # The index tells us exactly which shards hold the keys we need.
        shard_names = {shard for key, shard in weight_map.items() if wanted(key)}
    else:
        shard_names = {f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)}

    all_w = {}
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            continue
        with safe_open(str(shard_path), framework="pt", device="cpu") as shard:
            for key in shard.keys():
                if wanted(key):
                    all_w[key] = shard.get_tensor(key)
    return all_w
# =====================================================================
# FP4 dequant
# =====================================================================
# Magnitudes of the 8 non-negative FP4 (E2M1) codes, indexed by the low 3 bits
# of a nibble; bit 3 of the nibble is the sign.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
# Number of consecutive input features sharing one weight_scale entry.
NVFP4_BLOCK_SIZE = 16


def dequant_nvfp4(weight, weight_scale, weight_scale_2):
    """Dequantize a packed NVFP4 weight matrix to bf16.

    Args:
        weight: (out, in/2) integer tensor; each byte packs two 4-bit codes,
            low nibble -> column 2j, high nibble -> column 2j+1. uint8 and
            int8 storage decode identically.
        weight_scale: (out, in/16) block scales (any float dtype).
        weight_scale_2: second-level scale multiplied into weight_scale
            (presumably a per-tensor scalar — confirm against the checkpoint).

    Returns:
        (out, in) bf16 tensor on ``weight.device``:
        fp4(code) * weight_scale[block] * weight_scale_2.

    Raises:
        ValueError: if ``weight`` is not 2-D or ``weight_scale`` does not have
            one entry per 16-element block of each row.
    """
    if weight.dim() != 2:
        raise ValueError(f"expected a 2-D packed weight, got shape {tuple(weight.shape)}")
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    n_blocks, rem = divmod(in_features, NVFP4_BLOCK_SIZE)
    if rem or tuple(weight_scale.shape) != (out_dim, n_blocks):
        raise ValueError(
            f"weight_scale shape {tuple(weight_scale.shape)} does not match "
            f"({out_dim}, {in_features} // {NVFP4_BLOCK_SIZE}) for packed weight "
            f"shape {tuple(weight.shape)}"
        )

    device = weight.device
    # 16-entry table over the full nibble: codes 0-7 positive, 8-15 negated
    # (code 8 decodes to -0.0, same as sign * 0.0).
    lut = FP4_LUT.to(device=device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])
    # Mask after the shift so sign-extended int8 storage yields the same nibble.
    low = (weight & 0x0F).long()
    high = ((weight >> 4) & 0x0F).long()
    w_f = torch.stack([signed_lut[low], signed_lut[high]], dim=-1).reshape(out_dim, in_features)

    # Combine both scale levels in fp32 on the weight's device (the scales may
    # have been loaded elsewhere), then expand each block scale over its features.
    scale_f = weight_scale.to(device=device, dtype=torch.float32) * weight_scale_2.to(
        device=device, dtype=torch.float32
    )
    scale_expanded = scale_f.repeat_interleave(NVFP4_BLOCK_SIZE, dim=1)
    return (w_f * scale_expanded).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``F.linear(x, W)`` where W is the dequantized NVFP4 weight.

    The checkpoint tensors may live on a different device than ``x``; in this
    script they are loaded from safetensors without a GPU device argument.
    They are therefore moved to ``x.device`` first, which also runs the
    dequantization there. The dequantized weight is cast to ``x.dtype`` so
    F.linear sees matching dtypes (a no-op for the bf16 activations used here).
    """
    device = x.device
    w = dequant_nvfp4(weight.to(device), weight_scale.to(device), weight_scale_2.to(device))
    return torch.nn.functional.linear(x, w.to(x.dtype))
# =====================================================================
# Reference: pure PyTorch layer 0
# =====================================================================
def _rms_norm(x, weight=None, eps=1e-6):
    """Return ``x * rsqrt(mean(x**2, -1) + eps)`` in fp32, optionally times ``weight``.

    ``weight`` (if given) is moved to ``x``'s device and applied elementwise
    after normalization. Callers cast the fp32 result to the dtype they need.
    """
    x = x.float()
    out = x * x.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        out = out * weight.to(device=out.device, dtype=torch.float32)
    return out


def _nvfp4_proj(x, w, name):
    """Apply the NVFP4 linear ``name`` from ``w`` to ``x``.

    Reads ``{name}.weight``, ``{name}.weight_scale`` and
    ``{name}.weight_scale_2`` from ``w`` and moves them to ``x.device``.
    The weights are loaded on CPU, while activations live on the compute device.
    """
    return nvfp4_linear(
        x,
        w[f"{name}.weight"].to(x.device),
        w[f"{name}.weight_scale"].to(x.device),
        w[f"{name}.weight_scale_2"].to(x.device),
    )


def _hc_pre(X, fn, base, scale, n_hc, sinkhorn_iters=20, eps=1e-6):
    """mHC pre-step: derive mixing weights from the streams and collapse them.

    Args:
        X: (B, n_hc, H) residual streams.
        fn: (2*n_hc + n_hc**2, n_hc*H) projection producing pre/post/comb logits.
        base: (2*n_hc + n_hc**2,) bias added after scaling.
        scale: (3,) multipliers for the pre, post and comb logits.
        n_hc: number of streams.
        sinkhorn_iters: total number of column-normalization passes.
        eps: epsilon used in the norm, the pre weights and Sinkhorn.

    Returns:
        x_in: (B, H) bf16 sub-layer input, sum_i pre[:, i] * X[:, i].
        pre: (B, n_hc) weights sigmoid(...) + eps (A_l).
        post: (B, n_hc) weights 2 * sigmoid(...) (C_l).
        comb: (B, n_hc, n_hc) Sinkhorn-normalized stream-mixing matrix (B_l).
    """
    device = X.device
    batch = X.shape[0]
    splits = [n_hc, n_hc, n_hc * n_hc]
    # Unweighted RMSNorm over the flattened streams.
    flat = _rms_norm(X.reshape(batch, -1), eps=eps)
    # F.linear needs matching dtypes; compute the projection in fp32.
    # NOTE(review): confirm HF does not round `flat` to bf16 before this matmul.
    proj = torch.nn.functional.linear(flat, fn.to(device=device, dtype=torch.float32))
    pre_logits, post_logits, comb_logits = proj.split(splits, dim=-1)
    # base/scale may be CPU tensors; bring them to the compute device.
    pre_b, post_b, comb_b = base.to(device=device, dtype=torch.float32).split(splits)
    pre_s, post_s, comb_s = scale.to(device=device, dtype=torch.float32).unbind(0)
    pre = torch.sigmoid(pre_logits * pre_s + pre_b) + eps
    post = 2.0 * torch.sigmoid(post_logits * post_s + post_b)
    # Sinkhorn: softmax normalizes rows, then alternate column/row passes,
    # for `sinkhorn_iters` column normalizations in total.
    comb = torch.softmax((comb_logits * comb_s + comb_b).reshape(batch, n_hc, n_hc), dim=-1) + eps
    comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    for _ in range(sinkhorn_iters - 1):
        comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
        comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    x_in = (pre.unsqueeze(-1) * X.float()).sum(dim=1).to(torch.bfloat16)
    return x_in, pre, post, comb


def _hc_post(X, out, post, comb):
    """mHC post-step: return ``post * out + comb^T @ X`` as bf16 (B, n_hc, H)."""
    mixed = torch.bmm(comb.transpose(-1, -2), X.float())
    injected = post.unsqueeze(-1) * out.float().unsqueeze(1)
    return (injected + mixed).to(torch.bfloat16)


def reference_layer0(embedding, w, cfg):
    """Process one token through layer 0 using pure PyTorch (following HF's structure).

    Args:
        embedding: (1, hidden_size) token embedding (bf16 in this script's
            ``__main__``). Its device is used for all computation.
        w: checkpoint tensors keyed by name. They may live on CPU and are
            moved to ``embedding.device`` as needed.
        cfg: model config dict. Reads ``num_attention_heads`` and ``head_dim``,
            plus ``o_groups`` / ``o_group_dim`` if present.

    Returns:
        (1, n_hc, hidden_size) bf16 residual streams after layer 0.

    Simplifications: single-token attention reduces to V, and the FFN uses
    only the shared expert, so routed experts are not included.
    """
    li = 0
    pre = f"model.layers.{li}.self_attn"
    n_h = cfg["num_attention_heads"]  # 128
    hd = cfg["head_dim"]  # 512
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_group_dim", 1024)
    n_hc = 4  # number of mHC residual streams (hard-coded, not read from cfg)
    heads_per_group = n_h // o_groups
    device = embedding.device

    # Init mHC state: replicate the embedding into every stream.
    X = embedding.unsqueeze(1).expand(-1, n_hc, -1).clone()  # (1, n_hc, H)

    # ============ mHC (attention) ============
    hc = f"model.layers.{li}.attn_hc"
    x_in, A_l, C_l, B_l = _hc_pre(X, w[f"{hc}.fn"], w[f"{hc}.base"], w[f"{hc}.scale"], n_hc)
    print(f" A_l: {A_l[0].tolist()}")
    print(f" C_l: {C_l[0].tolist()}")
    print(f" B row sums: {B_l[0].sum(dim=-1).tolist()}")
    print(f" B col sums: {B_l[0].sum(dim=-2).tolist()}")

    # ============ RMSNorm ============
    x_normed = _rms_norm(x_in, w[f"model.layers.{li}.input_layernorm.weight"]).to(torch.bfloat16)

    # ============ Q projection ============
    c_Q = _nvfp4_proj(x_normed, w, f"{pre}.q_a_proj")
    c_Q = _rms_norm(c_Q, w[f"{pre}.q_a_norm.weight"]).bfloat16()  # weighted q_a_norm
    q = _nvfp4_proj(c_Q, w, f"{pre}.q_b_proj")
    q = _rms_norm(q).bfloat16()  # unweighted q_b_norm over the full q vector

    # ============ KV projection ============
    kv = _nvfp4_proj(x_normed, w, f"{pre}.kv_proj")
    kv = _rms_norm(kv, w[f"{pre}.kv_norm.weight"]).bfloat16()
    print(f" |c_Q|={c_Q.abs().max().item():.4f} |q|={q.abs().max().item():.4f} |kv|={kv.abs().max().item():.4f}")

    # ============ Attention (self-attention for single token) ============
    # With a single key the softmax weight is 1 and V = K (one shared KV head),
    # so every head outputs the normalized kv vector; q does not affect it.
    # RoPE at position 0 is the identity, so RoPE / inverse RoPE are skipped.
    # NOTE(review): if the model uses per-head attention sinks, the self
    # weight is not exactly 1 — confirm.
    attn_out = kv.reshape(1, 1, hd).expand(1, n_h, hd)  # (1, n_h, hd)

    # ============ Output projection ============
    attn_grouped = attn_out.reshape(1, o_groups, heads_per_group * hd)
    oa_3d = (
        w[f"{pre}.o_a_proj.weight"]
        .to(device=device, dtype=torch.bfloat16)
        .reshape(o_groups, o_rank, heads_per_group * hd)
    )
    # Per-group matmul: (G, 1, hpg*hd) @ (G, hpg*hd, R) -> (G, 1, R).
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = _nvfp4_proj(grouped_flat, w, f"{pre}.o_b_proj")
    print(f" |F_attn|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}")

    # ============ mHC post_block ============
    # X_mid = C_l * F_attn + B_l^T @ X
    X_mid = _hc_post(X, F_attn, C_l, B_l)
    print(f" |X_mid|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}")

    # ============ FFN (shared expert only for simplicity) ============
    hc = f"model.layers.{li}.ffn_hc"
    x_ffn, _, post_f, comb_f = _hc_pre(X_mid, w[f"{hc}.fn"], w[f"{hc}.base"], w[f"{hc}.scale"], n_hc)
    x_ffn_n = _rms_norm(
        x_ffn, w[f"model.layers.{li}.post_attention_layernorm.weight"]
    ).to(torch.bfloat16)

    # Shared expert (SwiGLU); routed experts are intentionally omitted.
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    gate = _nvfp4_proj(x_ffn_n, w, f"{se_pre}.gate_proj")
    up = _nvfp4_proj(x_ffn_n, w, f"{se_pre}.up_proj")
    hidden = (torch.nn.functional.silu(gate.float()) * up.float()).bfloat16()
    shared_out = _nvfp4_proj(hidden, w, f"{se_pre}.down_proj")

    # mHC post (FFN)
    X_next = _hc_post(X_mid, shared_out, post_f, comb_f)
    print(f" |X_next|={X_next.abs().max().item():.4f}")
    return X_next
if __name__ == "__main__":
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    print("Loading weights...")
    w = load_weights()

    # Embed "The"
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    # Take the last id (presumably to skip a prepended BOS token).
    tid = tokenizer.encode("The")[-1]
    # Look up the single row on the host and move only that row to the device,
    # instead of converting and copying the whole (vocab, hidden) table.
    embed = w["model.embed_tokens.weight"][tid].bfloat16().to(DEVICE).unsqueeze(0)  # (1, H)

    print(f"\nProcessing 'The' (id={tid}) through layer 0:")
    X_out = reference_layer0(embed, w, cfg)
    print(f"\nOutput: |X|={X_out.abs().max().item():.4f}")