Files
nvfp4-megamoe-kernel/tests/validate_layer.py

330 lines
14 KiB
Python

#!/usr/bin/env python3
"""Per-layer validation: compare forward_layer output with a step-by-step reference.
This test takes a seeded random embedding, processes it through a SINGLE layer,
prints intermediate values from a standalone PyTorch reference, and compares
the production forward_layer output with the reference's post-attention
residual (X_mid).
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python3 tests/validate_layer.py --layer 0
"""
import os, sys, json, math, argparse, torch
from pathlib import Path
# Make the kernel repo root (the parent of this tests/ directory) importable so
# that `single_shot_inference` can be imported when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Default checkpoint location. Set DSV4_CHECKPOINT_DIR to point at a checkpoint
# elsewhere instead of editing this file.
CHECKPOINT_DIR = os.environ.get(
    "DSV4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
)
# =====================================================================
# NVFP4 dequantization
# =====================================================================
# Magnitudes of the 8 non-negative FP4 (E2M1) codes; bit 3 of a 4-bit code is
# the sign bit.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def dequant_nvfp4(weight, weight_scale, weight_scale_2, *, block_size=16):
    """Dequantize a packed NVFP4 weight matrix to bfloat16.

    Args:
        weight: ``(out, in // 2)`` integer tensor (uint8 or int8). Each byte
            packs two FP4 codes: the low nibble is the even input column and
            the high nibble the following odd column.
        weight_scale: 2-D per-block scale tensor with ``in // block_size``
            columns; each entry scales ``block_size`` consecutive inputs.
        weight_scale_2: global scale, broadcast against ``weight_scale``.
        block_size: number of input columns sharing one ``weight_scale``
            entry (16 for NVFP4).

    Returns:
        ``(out, in)`` bfloat16 tensor ``fp4(code) * weight_scale * weight_scale_2``.

    Raises:
        ValueError: if ``weight_scale`` is not 2-D with exactly one column per
            ``block_size`` unpacked input columns.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    if weight_scale.dim() != 2 or weight_scale.shape[1] * block_size != in_features:
        raise ValueError(
            f"weight_scale shape {tuple(weight_scale.shape)} does not match "
            f"{in_features} input features with block_size={block_size}"
        )
    # 16-entry signed table indexed directly by the 4-bit code:
    # codes 0-7 -> +LUT, codes 8-15 -> -LUT (sign bit set; code 8 is -0.0).
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    signed_lut = torch.cat([lut, -lut])
    low = (weight & 0x0F).long()
    # Mask after the shift so int8 storage (arithmetic shift) cannot sign-extend.
    high = ((weight >> 4) & 0x0F).long()
    w_f = torch.stack([signed_lut[low], signed_lut[high]], dim=-1).reshape(out_dim, in_features)
    scale_f = weight_scale.float() * weight_scale_2.float()
    return (w_f * scale_f.repeat_interleave(block_size, dim=1)).bfloat16()


def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Compute ``x @ W.T`` (no bias) with ``W`` dequantized from NVFP4.

    The full bfloat16 weight is materialized on every call (reference path,
    not a fused kernel), so ``x`` should be bfloat16 to match its dtype.
    """
    w = dequant_nvfp4(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, w)
def rmsnorm(x, weight, eps=1e-6):
    """Weighted RMSNorm over the last dimension (intended to match HF DeepseekV4RMSNorm).

    The statistics and scaling are computed in float32; the result is always
    returned as bfloat16.
    """
    xf = x.float()
    mean_sq = torch.mean(xf * xf, dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    normalized = xf * inv_rms
    return (normalized * weight.float()).to(torch.bfloat16)
def unweighted_rmsnorm(x, eps=1e-6):
    """RMSNorm without a learned scale (intended to match HF DeepseekV4UnweightedRMSNorm).

    Normalizes over the last dimension in float32 and returns bfloat16.
    """
    xf = x.float()
    inv_rms = torch.rsqrt(xf.square().mean(dim=-1, keepdim=True) + eps)
    return torch.mul(xf, inv_rms).to(torch.bfloat16)
def sinkhorn(logits, t_max=20, eps=1e-6):
    """Sinkhorn-Knopp normalization seeded from a softmax (intended to match HF).

    Starts from ``softmax(logits, dim=-1) + eps`` and normalizes over dim -2
    (columns). It then runs ``t_max - 1`` more rounds, each normalizing over
    dim -1 (rows) and then dim -2, so the result always ends on a column
    normalization. ``eps`` is added to every denominator.
    """
    def _normalize(mat, dim):
        return mat / (mat.sum(dim=dim, keepdim=True) + eps)

    result = _normalize(torch.softmax(logits, dim=-1) + eps, -2)
    rounds = t_max - 1
    while rounds > 0:
        result = _normalize(_normalize(result, -1), -2)
        rounds -= 1
    return result
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0,
                     rope_type="default", rope_factor=1.0,
                     original_max_pos=4096, beta_fast=32, beta_slow=1):
    """Precompute RoPE cos/sin tables for positions ``0 .. max_pos - 1``.

    Args:
        max_pos: number of positions to tabulate.
        rope_dim: number of rotary channels (pairs are ``rope_dim // 2``).
        device: device the returned tables are moved to.
        theta: RoPE base.
        rope_type: ``"yarn"`` enables YaRN frequency scaling (only when
            ``rope_factor > 1``); any other value uses plain RoPE.
        rope_factor: YaRN context-extension factor.
        original_max_pos: pre-extension context length used to locate the
            YaRN correction range.
        beta_fast, beta_slow: YaRN rotation counts bounding the ramp.

    Returns:
        ``(cos, sin)``: float32 tensors of shape ``(max_pos, rope_dim // 2)``
        on ``device``; column ``i`` holds the angle for rotary pair ``i``.
    """
    half = rope_dim // 2
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    if rope_type == "yarn" and rope_factor > 1.0:
        # YaRN: dimensions that complete more than beta_fast rotations over the
        # original context keep their frequency (extrapolation). Those with
        # fewer than beta_slow rotations are divided by rope_factor
        # (interpolation). A linear ramp over the dimension index blends the
        # two, so the result is continuous.
        def correction_dim(num_rotations):
            return (rope_dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))
                    / (2 * math.log(theta)))

        low = max(math.floor(correction_dim(beta_fast)), 0)
        high = min(math.ceil(correction_dim(beta_slow)), rope_dim - 1)
        if low == high:
            high += 0.001  # avoid a zero-width ramp
        ramp = ((torch.arange(half, dtype=torch.float32) - low) / (high - low)).clamp(0.0, 1.0)
        # ramp == 0 -> original frequency, ramp == 1 -> fully interpolated.
        freqs = freqs / rope_factor * ramp + freqs * (1.0 - ramp)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
def _rotate_rope_tail(x, positions, cos_cache, sin_cache, rope_dim, sin_sign):
    """Rotate interleaved (even, odd) channel pairs in the last ``rope_dim`` channels.

    ``x`` must be 3-D ``(T, heads, width)``. With ``sin_sign > 0`` each pair is
    rotated by the cached angle; with ``sin_sign < 0`` it is rotated by the
    negated angle. The leading ``width - rope_dim`` channels are copied through
    unchanged, and the rotated channels are written back as bfloat16.
    """
    _, _, width = x.shape
    n_nope = width - rope_dim
    c = cos_cache[positions].unsqueeze(1)
    s = sin_cache[positions].unsqueeze(1)
    if sin_sign < 0:
        s = -s
    tail = x[:, :, n_nope:].float()
    ev, od = tail[..., 0::2], tail[..., 1::2]
    rotated = torch.empty_like(tail)
    rotated[..., 0::2] = ev * c - od * s
    rotated[..., 1::2] = ev * s + od * c
    out = x.clone()
    out[:, :, n_nope:] = rotated.to(torch.bfloat16)
    return out


def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply RoPE to the trailing ``rope_dim`` channels of ``x`` (T, heads, width).

    ``positions`` indexes rows of the cos/sin caches. ``head_dim`` is accepted
    for call-site compatibility but unused; the width is read from ``x``.
    """
    return _rotate_rope_tail(x, positions, cos_cache, sin_cache, rope_dim, 1)


def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo RoPE on the trailing ``rope_dim`` channels of ``o`` (rotate by -angle).

    ``head_dim`` is accepted for call-site compatibility but unused.
    """
    return _rotate_rope_tail(o, positions, cos_cache, sin_cache, rope_dim, -1)
def _maybe_load_mhc(block, w, prefix):
    """Load ``{prefix}.fn``, ``.base`` and ``.scale`` into ``block`` when all three keys exist.

    If any key is missing, ``block`` keeps whatever state its constructor gave it.
    """
    keys = [f"{prefix}.fn", f"{prefix}.base", f"{prefix}.scale"]
    if all(k in w for k in keys):
        block.load_from_checkpoint(w[keys[0]], w[keys[1]], w[keys[2]])


def _report_match(X_prod, X_ref, n_hc, threshold=0.99):
    """Print per-stream cosine similarity and max abs diff; return True only if every stream passes.

    Both tensors are indexed as ``[0, stream, :]``. Comparisons are done in
    float32. A NaN similarity counts as a failure, because ``nan >= threshold``
    is False.
    """
    if X_prod.shape != X_ref.shape:
        print(f"\n ⚠️ Shape mismatch: production {tuple(X_prod.shape)} vs reference {tuple(X_ref.shape)}")
        return False
    prod = X_prod.float()
    ref = X_ref.float()
    cos_sims = [
        torch.nn.functional.cosine_similarity(prod[0, s].unsqueeze(0), ref[0, s].unsqueeze(0)).item()
        for s in range(n_hc)
    ]
    max_diff = (prod - ref).abs().max().item()
    print()
    for s, cs in enumerate(cos_sims):
        print(f" Stream {s} cosine similarity: {cs:.6f}")
    print(f" Max diff: {max_diff:.6f}")
    ok = all(cs >= threshold for cs in cos_sims)
    if ok:
        print(" ✅ Match!")
    else:
        print(" ⚠️ MISMATCH! Production and reference differ significantly!")
    return ok


def validate_layer(li, all_weights, cfg, device='cuda:0'):
    """Validate one layer: production ``forward_layer`` vs. a step-by-step reference.

    Builds the RoPE tables, mHC blocks, norms and KV cache that
    ``forward_layer`` takes. It runs ``forward_layer`` on a seeded random
    residual for one token at position 0, then recomputes the attention
    mHC pre-block, attention, and mHC post-block by hand. Finally it compares
    the two results stream by stream.

    Args:
        li: layer index.
        all_weights: checkpoint key -> tensor for this layer (already on
            ``device``).
        cfg: parsed ``config.json``.
        device: torch device string.

    Returns:
        The reference residual after the attention sub-block (``X_mid``),
        bfloat16, shape ``(1, n_hc, hidden_size)``.
    """
    from single_shot_inference import forward_layer, mHCBlock, RMSNorm, SimpleKVCache
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    H = cfg["hidden_size"]
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_group_dim", 1024)
    n_hc = 4  # number of mHC residual streams (hard-coded)
    eps = cfg.get('rms_norm_eps', 1e-6)
    pre = f"model.layers.{li}.self_attn"
    w = all_weights  # already filtered to this layer and on device

    # --- RoPE tables (YaRN parameters from config, with fallbacks) ---
    rope_params = cfg.get("rope_parameters", {})
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device,
        theta=rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0)),
        rope_type=rope_params.get("rope_type", "yarn"),
        rope_factor=rope_params.get("factor", 16.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # --- Production modules ---
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    _maybe_load_mhc(attn_mhc, w, f"model.layers.{li}.attn_hc")
    _maybe_load_mhc(ffn_mhc, w, f"model.layers.{li}.ffn_hc")
    attn_norm = RMSNorm(H, eps=eps, device=device)
    an_key = f"model.layers.{li}.input_layernorm.weight"
    if an_key in w:
        attn_norm.weight = w[an_key].to(device=device, dtype=torch.float32)
    ffn_norm = RMSNorm(H, eps=eps, device=device)
    fn_key = f"model.layers.{li}.post_attention_layernorm.weight"
    if fn_key in w:
        ffn_norm.weight = w[fn_key].to(device=device, dtype=torch.float32)
    kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=device)

    # --- Input: seeded random residual, single token at position 0 ---
    torch.manual_seed(42)
    X_l = torch.randn(1, n_hc, H, dtype=torch.bfloat16, device=device) * 0.5
    positions = torch.tensor([0], dtype=torch.long, device=device)
    token_id = torch.tensor([671], dtype=torch.long, device=device)  # "The"

    # --- Production forward_layer ---
    X_prod = forward_layer(X_l.clone(), w, li, cfg, rope_cos, rope_sin,
                           attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                           kv_cache, token_id, positions)
    print(f"Production: |X_next|={X_prod.abs().max().item():.4f}")
    print(f" Stream norms: {[X_prod[0, s, :].float().norm().item() for s in range(n_hc)]}")

    # ============================================================
    # Step-by-step reference (same math, no wrappers)
    # ============================================================
    X_l_ref = X_l.clone()
    # --- mHC pre_block (attention) ---
    fn = w[f"model.layers.{li}.attn_hc.fn"].float().to(device)
    base = w[f"model.layers.{li}.attn_hc.base"].float().to(device)
    scale = w[f"model.layers.{li}.attn_hc.scale"].float().to(device)
    # Unweighted RMSNorm over the flattened (n_hc * H) residual.
    X_flat = X_l_ref.reshape(1, n_hc * H).float()
    rms_inv = X_flat.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
    flat = (X_flat * rms_inv).to(torch.bfloat16)
    # Project, then split into pre (n_hc), post (n_hc) and comb (n_hc * n_hc) logits.
    proj = torch.nn.functional.linear(flat.float(), fn).float()
    pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
    pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc])
    pre_s, post_s, comb_s = scale.unbind(0)
    A_l = torch.sigmoid(pre_w * pre_s + pre_b) + 1e-6       # (1, n_hc) input mixing weights
    C_l = 2.0 * torch.sigmoid(post_w * post_s + post_b)       # (1, n_hc) output gains in (0, 2)
    B_l = sinkhorn((comb_w * comb_s + comb_b).reshape(1, n_hc, n_hc))
    # Collapse the n_hc streams into a single hidden vector for attention.
    x_in = (A_l.unsqueeze(-1) * X_l_ref.float()).sum(dim=1).to(torch.bfloat16)
    print(f"\n mHC A_l: {A_l[0].tolist()}")
    print(f" mHC C_l: {C_l[0].tolist()}")
    print(f" B row sums: {B_l[0].sum(-1).tolist()}")

    # --- RMSNorm (same eps as the production attn_norm above) ---
    x_normed = rmsnorm(x_in, w[f"model.layers.{li}.input_layernorm.weight"].to(device), eps=eps)
    # --- Q projection: q_a -> q_a_norm -> q_b -> unweighted norm ---
    c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"])
    c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"].to(device))
    q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"])
    q = unweighted_rmsnorm(q)
    # --- KV projection: one shared head of width hd ---
    kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"])
    kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"].to(device))
    q_heads = q.reshape(1, n_h, hd)
    kv_new = kv.reshape(1, 1, hd)
    # RoPE (q_heads is computed for completeness; with a single token it does
    # not affect the attention output below).
    q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
    # Attention over a single key: softmax weight is 1, so every head outputs
    # the shared value. Assumes no attention sink / extra logits; TODO confirm.
    attn_out = kv_new.expand(1, n_h, hd)
    # Inverse RoPE on the output's rotary channels.
    attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)

    # --- Output projection: grouped o_a (bf16, per group) then NVFP4 o_b ---
    attn_flat = attn_out.reshape(1, n_h * hd)
    attn_grouped = attn_flat.reshape(1, o_groups, (n_h // o_groups) * hd)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16().to(device)
    oa_3d = oa_w.reshape(o_groups, o_rank, (n_h // o_groups) * hd)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"])
    print(f" |F_attn|={F_attn.abs().max().item():.4f}")

    # --- mHC post_block: X_mid = C_l * F_attn + B_l.T @ X_l ---
    BX = torch.bmm(B_l.transpose(-1, -2), X_l_ref.float())
    CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1)
    X_mid = (CF.float() + BX).to(torch.bfloat16)
    print(f" |X_mid|={X_mid.abs().max().item():.4f}")
    print(f" Stream norms (mid): {[X_mid[0, s, :].float().norm().item() for s in range(n_hc)]}")

    # NOTE(review): the production value is labelled X_next (whole-layer output)
    # while this reference stops at X_mid (after the attention residual). If
    # forward_layer also applies the FFN/MoE half, a mismatch below is expected
    # and does not by itself indicate an attention bug. Confirm what
    # forward_layer returns.
    print(" (reference stops at X_mid; confirm forward_layer returns the same span)")
    _report_match(X_prod, X_mid, n_hc)
    return X_mid
def main():
    """CLI entry point: load one layer's tensors from the checkpoint and validate that layer.

    Only tensors whose key starts with ``model.layers.{layer}.``, plus
    ``model.embed_tokens.weight``, are read from disk and moved to
    ``--device``. With a ``model.safetensors.index.json`` only the shards it
    lists for those keys are opened; otherwise every ``*.safetensors`` file in
    the directory is scanned.
    """
    from safetensors import safe_open

    parser = argparse.ArgumentParser(description="Validate a single decoder layer against a reference.")
    parser.add_argument('--layer', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--checkpoint-dir', type=str, default=CHECKPOINT_DIR,
                        help="directory containing config.json and the *.safetensors shards")
    args = parser.parse_args()

    cdir = Path(args.checkpoint_dir)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    li = args.layer
    prefix = f"model.layers.{li}."
    extra_keys = {"model.embed_tokens.weight"}

    def wanted(key):
        return key.startswith(prefix) or key in extra_keys

    print(f"Loading weights for layer {li}...")
    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_names = sorted({shard for key, shard in weight_map.items() if wanted(key)})
    else:
        # Unsharded (or index-less) checkpoint: listing keys is cheap, so scan all files.
        shard_names = sorted(path.name for path in cdir.glob("*.safetensors"))

    all_w = {}
    for shard_name in shard_names:
        shard_path = cdir / shard_name
        if not shard_path.exists():
            print(f"  WARNING: shard {shard_name} not found; skipping")
            continue
        # safe_open reads only the requested tensors instead of the whole shard.
        with safe_open(str(shard_path), framework="pt", device="cpu") as f:
            for k in f.keys():
                if wanted(k):
                    all_w[k] = f.get_tensor(k).to(device=args.device)
    print(f" {len(all_w)} weights loaded")
    if not any(k.startswith(prefix) for k in all_w):
        raise SystemExit(f"No weights found for layer {li} under {cdir}")
    validate_layer(li, all_w, cfg, device=args.device)


if __name__ == "__main__":
    main()