330 lines
14 KiB
Python
330 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""Per-layer validation: compare forward_layer output with a step-by-step reference.
|
|
|
|
This test takes a known embedding, processes it through a SINGLE layer,
|
|
and compares the output at each intermediate step between the production
|
|
forward_layer function and a standalone PyTorch reference.
|
|
|
|
Usage (on B200):
|
|
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
|
cd /root/dsv4-nvfp4-workspace/kernel
|
|
python3 tests/validate_layer.py --layer 0
|
|
"""
|
|
import os, sys, json, math, argparse, torch
|
|
from pathlib import Path
|
|
|
|
# Put the parent of this file's directory at the front of the import path;
# presumably this is what lets `single_shot_inference` (imported inside
# validate_layer) resolve when the script is run from the tests directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Directory that main() reads config.json, the safetensors index and the
# weight shards from.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"

# =====================================================================
# NVFP4 dequantization
# =====================================================================
# Magnitude table for a 4-bit code: dequant_nvfp4 indexes it with the low
# three bits of each code and takes the sign from bit 3.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
|
|
|
def dequant_nvfp4(weight, weight_scale, weight_scale_2):
    """Unpack a 4-bit-packed weight matrix and apply its two scale factors.

    Every element of ``weight`` is split into two 4-bit codes, low nibble
    first and high nibble second, so the result has twice as many columns
    as ``weight``.  Within a code, bit 3 is the sign and bits 0-2 index
    ``FP4_LUT``.  The product ``weight_scale * weight_scale_2`` is repeated
    16 times along dim 1, i.e. one scale per 16 consecutive unpacked columns.

    Args:
        weight: 2-D integer tensor ``(rows, packed_cols)`` (presumably uint8
            -- TODO confirm against the checkpoint).
        weight_scale: per-block scales; presumably
            ``(rows, 2 * packed_cols / 16)`` so that the repeated scale lines
            up with the unpacked weight.
        weight_scale_2: second scale multiplied into ``weight_scale``
            (presumably a per-tensor scalar -- TODO confirm).

    Returns:
        bfloat16 tensor of shape ``(rows, 2 * packed_cols)``.
    """
    rows, packed_cols = weight.shape[0], weight.shape[1]
    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibbles):
        # Bit 3 carries the sign, bits 0-2 select the magnitude.
        codes = nibbles.to(torch.int8)
        is_negative = (codes >> 3).bool()
        table_index = (codes & 0x07).long()
        return magnitudes[table_index] * torch.where(is_negative, -1.0, 1.0)

    lo_values = decode(weight & 0x0F)
    hi_values = decode(weight >> 4)

    # Interleave (low, high) pairs so each packed byte yields two adjacent columns.
    unpacked = torch.stack([lo_values, hi_values], dim=-1).reshape(rows, packed_cols * 2)

    block_scale = weight_scale.float() * weight_scale_2.float()
    per_column_scale = block_scale.repeat_interleave(16, dim=1)
    return (unpacked * per_column_scale).bfloat16()
|
|
|
|
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Bias-free linear layer whose weight is stored in packed NVFP4 form.

    The weight is first expanded to a dense bfloat16 matrix with
    ``dequant_nvfp4`` and then applied with ``torch.nn.functional.linear``,
    so ``x`` must have a dtype that ``linear`` accepts together with a
    bfloat16 weight.

    Args:
        x: activations whose last dimension matches the dequantized weight's
            column count.
        weight, weight_scale, weight_scale_2: forwarded unchanged to
            ``dequant_nvfp4``.

    Returns:
        The result of ``linear(x, dequantized_weight)``.
    """
    return torch.nn.functional.linear(
        x,
        dequant_nvfp4(weight, weight_scale, weight_scale_2),
    )
|
|
|
|
def rmsnorm(x, weight, eps=1e-6):
    """Weighted RMS normalisation over the last dimension.

    Per the original author this is meant to mirror HF ``DeepseekV4RMSNorm``.
    The computation runs in float32 and the result is always cast to
    bfloat16, regardless of the dtype of ``x``.

    Args:
        x: input tensor; normalised along its last dimension.
        weight: per-feature gain, broadcast against ``x``.
        eps: value added to the mean square before the reciprocal square root.

    Returns:
        bfloat16 tensor with the same shape as ``x``.
    """
    hidden = x.float()
    mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    scaled = hidden * inv_rms * weight.float()
    return scaled.to(torch.bfloat16)
|
|
|
|
def unweighted_rmsnorm(x, eps=1e-6):
    """RMS normalisation over the last dimension without a learned gain.

    Per the original author this is meant to mirror HF
    ``DeepseekV4UnweightedRMSNorm``.  The computation runs in float32 and the
    result is always cast to bfloat16.

    Args:
        x: input tensor; normalised along its last dimension.
        eps: value added to the mean square before the reciprocal square root.

    Returns:
        bfloat16 tensor with the same shape as ``x``.
    """
    hidden = x.float()
    mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    return (hidden * inv_rms).to(torch.bfloat16)
|
|
|
|
|
|
def sinkhorn(logits, t_max=20, eps=1e-6):
    """Sinkhorn-Knopp style normalisation starting from a softmax.

    Per the original author this is meant to match the HF implementation.
    The matrix starts as ``softmax(logits, dim=-1) + eps``.  It is normalised
    once along dim -2, then ``t_max - 1`` further rounds each normalise along
    dim -1 followed by dim -2, so the last step always normalises dim -2.

    Args:
        logits: tensor whose last two dimensions form the matrix to balance.
        t_max: total number of dim -2 normalisation passes.
        eps: added to the softmax output and to every normalising sum.

    Returns:
        Tensor of the same shape as ``logits``.
    """
    def normalise(matrix, axis):
        return matrix / (matrix.sum(dim=axis, keepdim=True) + eps)

    balanced = normalise(torch.softmax(logits, dim=-1) + eps, -2)
    for _ in range(1, t_max):
        balanced = normalise(balanced, -1)
        balanced = normalise(balanced, -2)
    return balanced
|
|
|
|
|
|
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0,
                     rope_type="default", rope_factor=1.0,
                     original_max_pos=4096, beta_fast=32, beta_slow=1):
    """Precompute cosine and sine tables for rotary position embeddings.

    Base frequencies are ``theta ** -(2i / rope_dim)`` for the even indices
    ``2i`` below ``rope_dim``.  When ``rope_type == "yarn"`` and
    ``rope_factor > 1`` each frequency is rescaled according to its
    wavelength ``2 * pi / freq``:

    * shorter than ``original_max_pos / (2 * beta_fast)``: kept unchanged;
    * longer than ``original_max_pos / (2 * beta_slow)``: divided by
      ``rope_factor``;
    * otherwise: a blend of the two controlled by ``smooth``.

    Args:
        max_pos: number of positions (rows) in the tables.
        rope_dim: rotary dimension; the tables get one column per even index
            below it.
        device: device the returned tables are moved to (they are computed on
            the default device first).
        theta: base of the frequency progression.
        rope_type: ``"yarn"`` enables the rescaling above; any other value
            leaves the base frequencies untouched.
        rope_factor: context-extension factor used by the rescaling.
        original_max_pos, beta_fast, beta_slow: parameters of the wavelength
            thresholds and of the blend.

    Returns:
        ``(cos, sin)`` float32 tensors; each row is a position and each
        column a frequency.
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    freqs = 1.0 / (theta ** exponents)

    if rope_type == "yarn" and rope_factor > 1.0:
        wavelen = 2 * math.pi / freqs
        keep_below = original_max_pos / (beta_fast * 2.0)
        scale_above = original_max_pos / (beta_slow * 2.0)

        # NOTE(review): the ramp is kept exactly as originally written so the
        # reference does not drift from what it is compared against.  It is
        # not confined to [0, 1] (with factor=16, beta_fast=32, beta_slow=1
        # and original_max_pos=4096 it spans roughly [-0.03, 0.10] over the
        # blended band) -- confirm this is the intended interpolation.
        smooth = (original_max_pos / (wavelen * beta_slow) - rope_factor) / (
            rope_factor * (beta_fast / beta_slow - 1))
        blended = (1 - smooth) * freqs / rope_factor + smooth * freqs

        # Same three-way branch as a per-frequency if/elif/else, evaluated
        # for all frequencies at once.
        freqs = torch.where(
            wavelen < keep_below,
            freqs,
            torch.where(wavelen > scale_above, freqs / rope_factor, blended),
        )

    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
|
|
|
|
|
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Rotate the last ``rope_dim`` features of every head; leave the rest as is.

    The trailing ``rope_dim`` features are treated as interleaved
    (even, odd) pairs and each pair is rotated by the angle looked up for the
    token's position.  The rotation is computed in float32 and rounded to
    bfloat16 before being written into a copy of ``x``.

    Args:
        x: 3-D tensor; the last dimension is the head width.
        positions: integer index tensor used to select rows of the caches.
        cos_cache, sin_cache: tables indexed by position (presumably
            ``(max_pos, rope_dim // 2)`` as produced by ``build_rope_cache``).
        head_dim: accepted for interface compatibility but not used; the head
            width is read from ``x.shape``.
        rope_dim: number of trailing features that are rotated.

    Returns:
        A new tensor shaped and typed like ``x``; ``x`` itself is not modified.
    """
    _, _, width = x.shape
    rope_start = width - rope_dim

    # Insert a singleton head axis so the angles broadcast over dim 1 of x.
    cos_t = cos_cache[positions].unsqueeze(1)
    sin_t = sin_cache[positions].unsqueeze(1)

    tail = x[:, :, rope_start:].float()
    first = tail[..., 0::2]
    second = tail[..., 1::2]

    # Stack (even, odd) results pairwise and flatten to restore interleaving.
    rotated = torch.stack(
        [first * cos_t - second * sin_t, first * sin_t + second * cos_t],
        dim=-1,
    ).flatten(-2)

    out = x.clone()
    out[:, :, rope_start:] = rotated.to(torch.bfloat16)
    return out
|
|
|
|
|
|
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo the rotary rotation on the last ``rope_dim`` features of every head.

    This applies the rotation by the negated angle: for each interleaved
    (even, odd) pair it computes ``even * cos + odd * sin`` and
    ``-even * sin + odd * cos``.  The arithmetic runs in float32 and is
    rounded to bfloat16 before being written into a copy of ``o``.

    Args:
        o: 3-D tensor; the last dimension is the head width.
        positions: integer index tensor used to select rows of the caches.
        cos_cache, sin_cache: tables indexed by position (presumably
            ``(max_pos, rope_dim // 2)`` as produced by ``build_rope_cache``).
        head_dim: accepted for interface compatibility but not used; the head
            width is read from ``o.shape``.
        rope_dim: number of trailing features that are rotated back.

    Returns:
        A new tensor shaped and typed like ``o``; ``o`` itself is not modified.
    """
    _, _, width = o.shape
    rope_start = width - rope_dim

    # Insert a singleton head axis so the angles broadcast over dim 1 of o.
    cos_t = cos_cache[positions].unsqueeze(1)
    sin_t = sin_cache[positions].unsqueeze(1)

    tail = o[:, :, rope_start:].float()
    first = tail[..., 0::2]
    second = tail[..., 1::2]

    # Stack (even, odd) results pairwise and flatten to restore interleaving.
    unrotated = torch.stack(
        [first * cos_t + second * sin_t, -first * sin_t + second * cos_t],
        dim=-1,
    ).flatten(-2)

    out = o.clone()
    out[:, :, rope_start:] = unrotated.to(torch.bfloat16)
    return out
|
|
|
|
|
|
@torch.no_grad()
def validate_layer(li, all_weights, cfg, device='cuda:0'):
    """Validate a single layer by running both forward_layer and a step-by-step reference.

    A fixed-seed random residual of shape ``(1, n_hc, hidden_size)`` is pushed
    through the production ``forward_layer`` and, independently, through the
    attention half of the layer written out inline (mHC pre-block, RMSNorm,
    Q/KV projections, RoPE, single-token attention, grouped output
    projection, mHC post-block).  Intermediate magnitudes are printed and the
    two results are compared with a cosine similarity on stream 0 and a
    global max absolute difference.

    The whole function runs under ``torch.no_grad()`` because nothing here
    needs gradients.

    Args:
        li: index of the layer; used to build the ``model.layers.{li}.``
            weight names.
        all_weights: mapping from checkpoint tensor name to tensor.
        cfg: model configuration dictionary (as loaded from ``config.json``).
        device: device on which inputs, caches and helper modules are created.

    Returns:
        The reference ``X_mid`` tensor (bfloat16), i.e. the residual streams
        after the attention half of the layer.

    Raises:
        KeyError: if a weight or config entry that the reference path indexes
            directly is missing.
    """
    from single_shot_inference import forward_layer, mHCBlock, RMSNorm, SimpleKVCache

    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    H = cfg["hidden_size"]
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_group_dim", 1024)
    n_hc = 4
    # Read once so the production norms and the reference input norm agree.
    norm_eps = cfg.get('rms_norm_eps', 1e-6)
    pre = f"model.layers.{li}.self_attn"

    # Get weights
    w = all_weights  # Already filtered and on device

    # Build RoPE caches
    rope_params = cfg.get("rope_parameters", {})
    rope_type = rope_params.get("rope_type", "yarn")
    rope_factor = rope_params.get("factor", 16.0)
    rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0))
    original_max_pos = rope_params.get("original_max_position_embeddings", 4096)
    beta_fast = rope_params.get("beta_fast", 32)
    beta_slow = rope_params.get("beta_slow", 1)
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device, theta=rope_theta,
        rope_type=rope_type, rope_factor=rope_factor,
        original_max_pos=original_max_pos,
        beta_fast=beta_fast, beta_slow=beta_slow
    )

    # Create mHC blocks (checkpoint parameters are loaded only when all three
    # tensors are present; otherwise the block keeps its constructor state).
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    fn_key = f"model.layers.{li}.attn_hc.fn"
    base_key = f"model.layers.{li}.attn_hc.base"
    scale_key = f"model.layers.{li}.attn_hc.scale"
    if fn_key in w and base_key in w and scale_key in w:
        attn_mhc.load_from_checkpoint(w[fn_key], w[base_key], w[scale_key])
    fn_key = f"model.layers.{li}.ffn_hc.fn"
    base_key = f"model.layers.{li}.ffn_hc.base"
    scale_key = f"model.layers.{li}.ffn_hc.scale"
    if fn_key in w and base_key in w and scale_key in w:
        ffn_mhc.load_from_checkpoint(w[fn_key], w[base_key], w[scale_key])

    attn_norm = RMSNorm(H, eps=norm_eps, device=device)
    an_key = f"model.layers.{li}.input_layernorm.weight"
    if an_key in w:
        attn_norm.weight = w[an_key].to(device=device, dtype=torch.float32)

    ffn_norm = RMSNorm(H, eps=norm_eps, device=device)
    fn_key = f"model.layers.{li}.post_attention_layernorm.weight"
    if fn_key in w:
        ffn_norm.weight = w[fn_key].to(device=device, dtype=torch.float32)

    kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=device)

    # Create input: random embedding (seeded so runs are reproducible)
    torch.manual_seed(42)
    X_l = torch.randn(1, n_hc, H, dtype=torch.bfloat16, device=device) * 0.5
    positions = torch.tensor([0], dtype=torch.long, device=device)
    token_id = torch.tensor([671], dtype=torch.long, device=device)  # "The"

    # Run forward_layer (production code)
    X_prod = forward_layer(X_l.clone(), w, li, cfg, rope_cos, rope_sin,
                           attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                           kv_cache, token_id, positions)

    print(f"Production: |X_next|={X_prod.abs().max().item():.4f}")
    print(f" Stream norms: {[X_prod[0,s,:].float().norm().item() for s in range(n_hc)]}")

    # ============================================================
    # Step-by-step reference (same math, no wrappers)
    # ============================================================
    X_l_ref = X_l.clone()

    # --- mHC pre_block (attention) ---
    fn = w[f"model.layers.{li}.attn_hc.fn"].float().to(device)
    base = w[f"model.layers.{li}.attn_hc.base"].float().to(device)
    scale = w[f"model.layers.{li}.attn_hc.scale"].float().to(device)

    # Unweighted RMSNorm on flattened residual
    X_flat = X_l_ref.reshape(1, n_hc * H).float()
    rms_inv = X_flat.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
    flat = (X_flat * rms_inv).to(torch.bfloat16)

    # F.linear + split [pre(4), post(4), comb(16)]
    proj = torch.nn.functional.linear(flat.float(), fn).float()
    pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
    pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc])
    pre_s, post_s, comb_s = scale.unbind(0)

    A_l = torch.sigmoid(pre_w * pre_s + pre_b) + 1e-6
    C_l = 2.0 * torch.sigmoid(post_w * post_s + post_b)
    B_l = sinkhorn((comb_w * comb_s + comb_b).reshape(1, n_hc, n_hc))

    # Collapse the n_hc residual streams into one layer input, weighted by A_l.
    x_in = (A_l.unsqueeze(-1) * X_l_ref.float()).sum(dim=1).to(torch.bfloat16)
    print(f"\n mHC A_l: {A_l[0].tolist()}")
    print(f" mHC C_l: {C_l[0].tolist()}")
    print(f" B row sums: {B_l[0].sum(-1).tolist()}")

    # --- RMSNorm ---
    # Uses the same eps the production attn_norm above was constructed with.
    x_normed = rmsnorm(x_in, w[f"model.layers.{li}.input_layernorm.weight"].to(device), eps=norm_eps)

    # --- Q projection ---
    # With a single token the query does not influence the reference result
    # below; the projections are still evaluated, so missing Q weights surface
    # as a KeyError.
    c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"])
    c_Q = rmsnorm(c_Q, w[f"{pre}.q_a_norm.weight"].to(device))
    q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"])
    q = unweighted_rmsnorm(q)

    # --- KV projection ---
    kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"])
    kv = rmsnorm(kv, w[f"{pre}.kv_norm.weight"].to(device))

    q_heads = q.reshape(1, n_h, hd)
    kv_new = kv.reshape(1, 1, hd)

    # RoPE
    q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)

    # Attention (single token → self-attention is identity)
    attn_out = kv_new.expand(1, n_h, hd)

    # Inverse RoPE
    attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)

    # Output projection: per-group matmul with o_a_proj (used as a plain
    # bfloat16 weight), then the NVFP4 o_b_proj over the concatenated groups.
    attn_flat = attn_out.reshape(1, n_h * hd)
    attn_grouped = attn_flat.reshape(1, o_groups, (n_h // o_groups) * hd)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16().to(device)
    oa_3d = oa_w.reshape(o_groups, o_rank, (n_h // o_groups) * hd)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"])

    print(f" |F_attn|={F_attn.abs().max().item():.4f}")

    # mHC post_block: X_mid = C_l * F_attn + B_l.T @ X_l
    BX = torch.bmm(B_l.transpose(-1, -2), X_l_ref.float())
    CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1)
    X_mid = (CF.float() + BX).to(torch.bfloat16)

    print(f" |X_mid|={X_mid.abs().max().item():.4f}")
    print(f" Stream norms (mid): {[X_mid[0,s,:].float().norm().item() for s in range(n_hc)]}")

    # Compare X_mid with production
    # NOTE(review): X_prod is whatever forward_layer returns (labelled X_next
    # in the print above) while X_mid is the reference state after the
    # attention half only.  If forward_layer also applies the FFN half these
    # are different quantities -- confirm what is meant to be compared.
    X_prod_mid_stream0 = X_prod[0, 0, :].float()
    X_ref_mid_stream0 = X_mid[0, 0, :].float()
    cos_sim = torch.nn.functional.cosine_similarity(X_prod_mid_stream0.unsqueeze(0), X_ref_mid_stream0.unsqueeze(0)).item()
    max_diff = (X_prod - X_mid).abs().max().item()
    print(f"\n Stream 0 cosine similarity: {cos_sim:.6f}")
    print(f" Max diff: {max_diff:.6f}")

    if cos_sim < 0.99:
        print(" ⚠️ MISMATCH! Production and reference differ significantly!")
    else:
        print(" ✅ Match!")

    return X_mid
|
|
|
|
|
|
def main():
    """Command-line entry point: load one layer's weights and validate it.

    Reads ``config.json`` and ``model.safetensors.index.json`` from
    ``CHECKPOINT_DIR``, loads every shard that holds a tensor of the
    requested layer (or the token embedding), moves the matching tensors to
    the requested device and calls ``validate_layer``.

    Command-line options:
        --layer:  index of the layer to validate (default 0).
        --device: torch device string (default ``cuda:0``).

    Raises:
        SystemExit: if no tensor for the requested layer could be loaded.
    """
    from safetensors.torch import load_file

    p = argparse.ArgumentParser()
    p.add_argument('--layer', type=int, default=0)
    p.add_argument('--device', type=str, default='cuda:0')
    args = p.parse_args()

    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)

    print(f"Loading weights for layer {args.layer}...")
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
    else:
        print(f" WARNING: {index_path} not found; cannot locate any shard")

    li = args.layer
    prefix = f"model.layers.{li}."
    # One predicate drives both shard selection and tensor filtering, so no
    # shard is loaded only for tensors that would be discarded afterwards.
    extra_keys = ("model.embed_tokens.weight",)

    def wanted(name):
        return name.startswith(prefix) or name in extra_keys

    # Find which shards contain our layer
    needed_shards = {shard for key, shard in weight_map.items() if wanted(key)}

    all_w = {}
    for shard_name in sorted(needed_shards):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            print(f" WARNING: shard {shard_name} is listed in the index but missing on disk; skipping")
            continue
        data = load_file(str(shard_path))
        for k, v in data.items():
            if wanted(k):
                all_w[k] = v.to(device=args.device, non_blocking=True)
        # Drop the host copy before the next shard is read to limit peak memory.
        del data

    print(f" {len(all_w)} weights loaded")
    if not all_w:
        raise SystemExit(
            f"No weights found for layer {li} under {cdir}; "
            "check CHECKPOINT_DIR and the safetensors index."
        )
    validate_layer(li, all_w, cfg, device=args.device)
|
|
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|