Files
nvfp4-megamoe-kernel/tests/layer_compare.py

198 lines
8.0 KiB
Python

#!/usr/bin/env python3
"""Layer-by-layer comparison between our single_shot_inference and HF reference.
This test processes a single token through LAYER 0 using BOTH implementations
and compares the intermediate values to identify the exact point of divergence.
The "reference" implementation follows the HuggingFace DeepseekV4ForCausalLM
source code exactly, but using our NVFP4 dequantization for the weights.
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python tests/layer_compare.py
"""
import os, sys, json, math
import torch
import torch.nn.functional as F
from pathlib import Path
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
from single_shot_inference import (
dequant_nvfp4_weight, nvfp4_linear, RMSNorm,
apply_rope_partial, apply_inverse_rope, build_rope_cache,
SimpleKVCache, mHCBlock
)
# Checkpoint directory read by main(): config.json, model.safetensors.index.json,
# the safetensors shards named in that index, and the tokenizer files.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
def _load_tensors(shard_path, keys, device):
    """Read only the tensors named in ``keys`` from one safetensors shard.

    The shard is opened with ``safe_open`` so that tensors which are not
    requested are never materialised in host memory.

    Args:
        shard_path: Path to the ``.safetensors`` shard file.
        keys: Iterable of tensor names to fetch from that shard.
        device: Device the fetched tensors are moved to.

    Returns:
        Dict mapping each requested name to its tensor on ``device``.

    Raises:
        An error from safetensors if a requested name is not in the shard
        (i.e. the index file and the shard disagree).
    """
    from safetensors import safe_open

    with safe_open(str(shard_path), framework="pt", device="cpu") as shard:
        return {key: shard.get_tensor(key).to(device) for key in keys}


def _rms_normalize(x, weight=None, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension in float32, return bfloat16.

    Args:
        x: Input tensor.
        weight: Optional per-channel gain multiplied in after normalisation.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        The normalised (and optionally scaled) tensor cast to bfloat16.
    """
    x_f = x.float()
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    out = x_f * inv_rms
    if weight is not None:
        out = out * weight.float()
    return out.bfloat16()


def _nvfp4_proj(x, weights, name):
    """Apply ``nvfp4_linear`` using the weight/scale tensors stored under ``name``.

    Looks up ``{name}.weight``, ``{name}.weight_scale`` and
    ``{name}.weight_scale_2`` in ``weights`` and passes them, in that order,
    after ``x``.
    """
    return nvfp4_linear(
        x,
        weights[f"{name}.weight"],
        weights[f"{name}.weight_scale"],
        weights[f"{name}.weight_scale_2"],
    )


@torch.no_grad()
def main():
    """Run one token through the layer-0 attention sub-block and print stats.

    Reads the model config and weight index from ``CHECKPOINT_DIR``, loads the
    layer-0 tensors and the embedding matrix onto ``cuda:0``, embeds the last
    token id produced by tokenising "The", and then walks through: mHC
    pre-block, input RMSNorm, Q and KV projections with their norms, partial
    RoPE, single-token attention, inverse RoPE, the grouped output projection
    and the mHC post-block. After each stage the max-abs (and usually the
    mean-abs) of the intermediate tensor is printed.

    The whole pass runs under ``torch.no_grad()`` since nothing is trained.

    Returns:
        None. Results are only printed.
    """
    from transformers import AutoTokenizer

    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    n_hc = 4
    device = "cuda:0"

    # Load layer 0 weights: group the wanted keys by the shard the index
    # assigns them to, then read just those tensors from each shard.
    print("Loading layer 0 weights...")
    prefix = "model.layers.0."
    keys_by_shard = {}
    for key, shard in wm.items():
        if key.startswith(prefix):
            keys_by_shard.setdefault(shard, []).append(key)
    all_w = {}
    for shard, keys in keys_by_shard.items():
        all_w.update(_load_tensors(cdir / shard, keys, device))

    # Load embedding (only this one tensor is read from its shard).
    embed_key = "model.embed_tokens.weight"
    embed_w = _load_tensors(cdir / wm[embed_key], [embed_key], device)[embed_key].bfloat16()

    tok = AutoTokenizer.from_pretrained(str(cdir))

    # Process token "The": take the last id the tokenizer returns, at position 0.
    tid = torch.tensor([tok.encode("The")[-1]], dtype=torch.long, device=device)
    pos = torch.tensor([0], dtype=torch.long, device=device)

    # Build RoPE cache; scaling parameters come from the config's
    # "rope_parameters" entry when present, otherwise the defaults below.
    rope_params = cfg.get("rope_parameters", {})
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device, theta=rope_params.get("rope_theta", 10000.0),
        rope_type=rope_params.get("rope_type", "default"),
        rope_factor=rope_params.get("factor", 1.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # Embed
    emb = F.embedding(tid, embed_w)  # (1, H)
    print(f"Embedding: |emb|={emb.abs().max():.4f}")

    # Init mHC state
    X = mHCBlock.init_state(emb, n_hc)  # (1, 4, H)

    # Load mHC: fn/base/scale each pack the pre, post and comb parts in that
    # order; the first two parts are n_hc rows/entries each, comb is the rest.
    fn = all_w[f"{prefix}attn_hc.fn"]
    base = all_w[f"{prefix}attn_hc.base"]
    scale = all_w[f"{prefix}attn_hc.scale"]
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    attn_mhc.load_weights(
        W_pre=fn[0:n_hc].to(device, dtype=torch.float32),
        W_post=fn[n_hc:2 * n_hc].to(device, dtype=torch.float32),
        W_comb=fn[2 * n_hc:].to(device, dtype=torch.float32),
        S_pre=base[0:n_hc].reshape(1, n_hc).to(device, dtype=torch.bfloat16),
        S_post=base[n_hc:2 * n_hc].reshape(n_hc, 1).to(device, dtype=torch.bfloat16),
        S_comb=base[2 * n_hc:].reshape(n_hc, n_hc).to(device, dtype=torch.bfloat16),
        alpha_pre=scale[0].item(),
        alpha_post=scale[1].item(),
        alpha_comb=scale[2].item(),
    )

    # === OUR IMPLEMENTATION (single_shot_inference) ===
    print("\n=== OUR IMPLEMENTATION ===")

    # mHC pre_block
    x_in, ctx = attn_mhc.pre_block(X)
    print(f"x_in: |x_in|={x_in.abs().max():.4f} mean={x_in.float().abs().mean():.6f}")

    # RMSNorm
    norm = RMSNorm(H, device=device)
    norm.weight = all_w[f"{prefix}input_layernorm.weight"].to(device, dtype=torch.float32)
    x_norm = norm.forward(x_in)
    print(f"x_norm: |x|={x_norm.abs().max():.4f} mean={x_norm.float().abs().mean():.6f}")

    # Q projection: q_a -> q_a_norm -> q_b -> q_b_norm
    c_Q = _nvfp4_proj(x_norm, all_w, f"{prefix}self_attn.q_a_proj")
    # q_a_norm is applied only when its weight exists in the checkpoint;
    # otherwise c_Q is passed on unnormalised.
    q_norm_w = all_w.get(f"{prefix}self_attn.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w)
    print(f"c_Q: |c_Q|={c_Q.abs().max():.4f} mean={c_Q.float().abs().mean():.6f}")

    q = _nvfp4_proj(c_Q, all_w, f"{prefix}self_attn.q_b_proj")
    # q_b_norm (unweighted RMSNorm)
    q = _rms_normalize(q)
    q_heads = q.reshape(1, n_h, hd)
    print(f"q_heads: |q|={q_heads.abs().max():.4f} mean={q_heads.float().abs().mean():.6f}")

    # KV projection
    kv = _nvfp4_proj(x_norm, all_w, f"{prefix}self_attn.kv_proj")
    # kv_norm, likewise only when its weight is present.
    kv_norm_w = all_w.get(f"{prefix}self_attn.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w)
    kv_new = kv.reshape(1, 1, hd)
    print(f"kv: |kv|={kv_new.abs().max():.4f} mean={kv_new.float().abs().mean():.6f}")

    # Apply RoPE
    q_heads = apply_rope_partial(q_heads, pos, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, pos, rope_cos, rope_sin, hd, rd)
    print(f"After RoPE: |q|={q_heads.abs().max():.4f} |kv|={kv_new.abs().max():.4f}")

    # Attention (single token, trivially 1.0)
    q_in = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
    k_in = kv_new.permute(1, 0, 2)  # (1, 1, hd)
    k_exp = k_in.expand(n_h, -1, -1)  # the single KV head is shared by all query heads
    v_exp = k_exp.clone()  # K=V in DSV4
    attn_out = F.scaled_dot_product_attention(q_in, k_exp, v_exp, scale=1.0 / math.sqrt(hd))
    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
    print(f"attn_out: |o|={attn_out.abs().max():.4f} mean={attn_out.float().abs().mean():.6f}")

    # Inverse RoPE
    attn_out = apply_inverse_rope(attn_out, pos, rope_cos, rope_sin, hd, rd)
    print(f"After inverse RoPE: |o|={attn_out.abs().max():.4f}")

    # Output projection: wo_a (grouped BMM) + wo_b
    o_groups = cfg.get("num_output_groups", 16)
    o_rank = cfg.get("output_group_dim", 1024)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd
    attn_flat = attn_out.reshape(1, n_h * hd)
    attn_grouped = attn_flat.reshape(1, o_groups, group_input_dim)
    # o_a_proj is used as a plain bf16 matrix here, not through nvfp4_linear.
    oa_w = all_w[f"{prefix}self_attn.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_bmm = attn_grouped.permute(1, 0, 2)  # (o_groups, 1, group_input_dim)
    grouped_out = torch.bmm(attn_bmm, oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    print(f"grouped_out: |o|={grouped_flat.abs().max():.4f}")

    F_attn = _nvfp4_proj(grouped_flat, all_w, f"{prefix}self_attn.o_b_proj")
    print(f"F_attn: |F|={F_attn.abs().max():.4f} mean={F_attn.float().abs().mean():.6f}")

    # mHC post_block
    X_mid = attn_mhc.post_block(X, F_attn, ctx)
    print(f"X_mid: |X|={X_mid.abs().max():.4f}")
    print("\nLayer 0 attention sub-block complete.")
# Run the layer-0 walkthrough only when executed as a script, not on import.
if __name__ == "__main__":
    main()