198 lines
8.0 KiB
Python
198 lines
8.0 KiB
Python
#!/usr/bin/env python3
"""Layer-by-layer comparison between our single_shot_inference and HF reference.

This test processes a single token through LAYER 0 using BOTH implementations
and compares the intermediate values to identify the exact point of divergence.

The "reference" implementation follows the HuggingFace DeepseekV4ForCausalLM
source code exactly, but uses our NVFP4 dequantization for the weights.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python tests/layer_compare.py
"""
|
|
import os, sys, json, math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from pathlib import Path
|
|
|
|
sys.path.insert(0, "/root/dsv4-nvfp4-workspace/kernel")
|
|
from single_shot_inference import (
|
|
dequant_nvfp4_weight, nvfp4_linear, RMSNorm,
|
|
apply_rope_partial, apply_inverse_rope, build_rope_cache,
|
|
SimpleKVCache, mHCBlock
|
|
)
|
|
|
|
# Checkpoint directory. main() reads config.json, model.safetensors.index.json,
# the weight shards named in that index, and the tokenizer files from here.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
|
|
|
def _load_tensors(path, keys, device):
    """Read only the named tensors from one safetensors shard onto ``device``.

    The shard is opened with ``safe_open`` and tensors are fetched by name, so
    tensors that are not requested are never materialised.

    Args:
        path: Path of the ``.safetensors`` shard.
        keys: Iterable of tensor names to read from that shard.
        device: Target device passed to ``Tensor.to``.

    Returns:
        Dict mapping each requested name to its tensor on ``device``.
    """
    from safetensors import safe_open

    with safe_open(str(path), framework="pt", device="cpu") as shard:
        return {key: shard.get_tensor(key).to(device) for key in keys}


def _nvfp4_project(x, weights, name):
    """Apply ``nvfp4_linear`` using the tensors stored under ``name``.

    ``name`` is the full key stem; the three tensors are looked up as
    ``<name>.weight``, ``<name>.weight_scale`` and ``<name>.weight_scale_2``.
    """
    return nvfp4_linear(
        x,
        weights[f"{name}.weight"],
        weights[f"{name}.weight_scale"],
        weights[f"{name}.weight_scale_2"],
    )


def _rms_normalize(x, weight=None, eps=1e-6):
    """RMS-normalise ``x`` over its last dimension and return bfloat16.

    The arithmetic is done in float32. When ``weight`` is given, the
    normalised values are multiplied by ``weight.float()`` before the cast.
    """
    x_f = x.float()
    normed = x_f * x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        normed = normed * weight.float()
    return normed.bfloat16()


@torch.no_grad()
def main() -> None:
    """Trace one token through the layer-0 attention sub-block and print stats.

    Reads the config, weight index, layer-0 tensors, embedding table and
    tokenizer from ``CHECKPOINT_DIR``, then runs the token "The" at position 0
    through: embedding -> mHC pre_block -> input RMSNorm -> Q / KV projections
    with their norms -> partial RoPE -> single-token attention -> inverse RoPE
    -> grouped output projection -> mHC post_block. After each stage the
    max-abs (and usually mean-abs) of the intermediate tensor is printed.

    Only the ``single_shot_inference`` code path is exercised in this function.
    Everything runs on ``cuda:0`` and under ``torch.no_grad()`` because this is
    a forward-only diagnostic.
    """
    from transformers import AutoTokenizer

    cdir = Path(CHECKPOINT_DIR)
    with open(cdir / "config.json", encoding="utf-8") as f:
        cfg = json.load(f)
    with open(cdir / "model.safetensors.index.json", encoding="utf-8") as f:
        wm = json.load(f)["weight_map"]

    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    n_hc = 4
    device = "cuda:0"

    # Load layer 0 weights: group the keys by the shard the index assigns them
    # to, so each shard is opened once and only the needed tensors are read.
    print("Loading layer 0 weights...")
    prefix = "model.layers.0."
    keys_by_shard = {}
    for key in wm:
        if key.startswith(prefix):
            keys_by_shard.setdefault(wm[key], []).append(key)
    all_w = {}
    for shard, keys in keys_by_shard.items():
        all_w.update(_load_tensors(cdir / shard, keys, device))

    # Load embedding (just this one tensor from its shard)
    embed_key = "model.embed_tokens.weight"
    embed_w = _load_tensors(cdir / wm[embed_key], [embed_key], device)[embed_key].bfloat16()
    tok = AutoTokenizer.from_pretrained(str(cdir))

    # Process token "The": take the last id so any leading special token that
    # encode() may prepend is skipped.
    tid = torch.tensor([tok.encode("The")[-1]], dtype=torch.long, device=device)
    pos = torch.tensor([0], dtype=torch.long, device=device)

    # Build RoPE cache from the config's rope_parameters (YaRN settings apply
    # only if the config specifies them; otherwise the defaults below are used).
    rope_params = cfg.get("rope_parameters", {})
    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device, theta=rope_params.get("rope_theta", 10000.0),
        rope_type=rope_params.get("rope_type", "default"),
        rope_factor=rope_params.get("factor", 1.0),
        original_max_pos=rope_params.get("original_max_position_embeddings", 4096),
        beta_fast=rope_params.get("beta_fast", 32),
        beta_slow=rope_params.get("beta_slow", 1),
    )

    # Embed
    emb = F.embedding(tid, embed_w)  # (1, H)
    print(f"Embedding: |emb|={emb.abs().max():.4f}")

    # Init mHC state
    X = mHCBlock.init_state(emb, n_hc)  # (1, 4, H)

    # Load mHC: fn / base / scale each pack the pre, post and comb parts in
    # that order; the first two parts are n rows each, the rest is comb.
    fn = all_w[f"{prefix}attn_hc.fn"]
    base = all_w[f"{prefix}attn_hc.base"]
    scale = all_w[f"{prefix}attn_hc.scale"]
    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
    n = n_hc
    attn_mhc.load_weights(
        W_pre=fn[0:n].to(device, dtype=torch.float32),
        W_post=fn[n:2*n].to(device, dtype=torch.float32),
        W_comb=fn[2*n:].to(device, dtype=torch.float32),
        S_pre=base[0:n].reshape(1, n).to(device, dtype=torch.bfloat16),
        S_post=base[n:2*n].reshape(n, 1).to(device, dtype=torch.bfloat16),
        S_comb=base[2*n:].reshape(n, n).to(device, dtype=torch.bfloat16),
        alpha_pre=scale[0].item(),
        alpha_post=scale[1].item(),
        alpha_comb=scale[2].item(),
    )

    # === OUR IMPLEMENTATION (single_shot_inference) ===
    print("\n=== OUR IMPLEMENTATION ===")

    # mHC pre_block
    x_in, ctx = attn_mhc.pre_block(X)
    print(f"x_in: |x_in|={x_in.abs().max():.4f} mean={x_in.float().abs().mean():.6f}")

    # RMSNorm
    norm = RMSNorm(H, device=device)
    norm.weight = all_w[f"{prefix}input_layernorm.weight"].to(device, dtype=torch.float32)
    x_norm = norm.forward(x_in)
    print(f"x_norm: |x|={x_norm.abs().max():.4f} mean={x_norm.float().abs().mean():.6f}")

    attn = f"{prefix}self_attn"

    # Q projection: q_a -> q_a_norm -> q_b -> q_b_norm
    c_Q = _nvfp4_project(x_norm, all_w, f"{attn}.q_a_proj")
    # q_a_norm is applied only when the checkpoint provides its weight;
    # otherwise c_Q is passed through untouched.
    q_norm_w = all_w.get(f"{attn}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w)
    print(f"c_Q: |c_Q|={c_Q.abs().max():.4f} mean={c_Q.float().abs().mean():.6f}")

    q = _nvfp4_project(c_Q, all_w, f"{attn}.q_b_proj")
    # q_b_norm (unweighted RMSNorm)
    q = _rms_normalize(q)
    q_heads = q.reshape(1, n_h, hd)
    print(f"q_heads: |q|={q_heads.abs().max():.4f} mean={q_heads.float().abs().mean():.6f}")

    # KV projection
    kv = _nvfp4_project(x_norm, all_w, f"{attn}.kv_proj")
    # kv_norm, again only when its weight is present
    kv_norm_w = all_w.get(f"{attn}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w)
    kv_new = kv.reshape(1, 1, hd)
    print(f"kv: |kv|={kv_new.abs().max():.4f} mean={kv_new.float().abs().mean():.6f}")

    # Apply RoPE
    q_heads = apply_rope_partial(q_heads, pos, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, pos, rope_cos, rope_sin, hd, rd)
    print(f"After RoPE: |q|={q_heads.abs().max():.4f} |kv|={kv_new.abs().max():.4f}")

    # Attention (single token, so the softmax weight is trivially 1.0)
    q_in = q_heads.permute(1, 0, 2)  # (n_h, 1, hd)
    k_in = kv_new.permute(1, 0, 2)   # (1, 1, hd)
    k_exp = k_in.expand(n_h, -1, -1)
    v_exp = k_exp.clone()  # K=V in DSV4
    attn_out = F.scaled_dot_product_attention(q_in, k_exp, v_exp, scale=1.0/math.sqrt(hd))
    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)
    print(f"attn_out: |o|={attn_out.abs().max():.4f} mean={attn_out.float().abs().mean():.6f}")

    # Inverse RoPE
    attn_out = apply_inverse_rope(attn_out, pos, rope_cos, rope_sin, hd, rd)
    print(f"After inverse RoPE: |o|={attn_out.abs().max():.4f}")

    # Output projection: wo_a (grouped BMM) + wo_b
    o_groups = cfg.get("num_output_groups", 16)
    o_rank = cfg.get("output_group_dim", 1024)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    # Split the flattened heads into o_groups chunks and give each chunk its
    # own (o_rank, group_input_dim) slice of o_a_proj via a batched matmul.
    attn_grouped = attn_out.reshape(1, o_groups, group_input_dim)
    oa_w = all_w[f"{attn}.o_a_proj.weight"].bfloat16()
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_bmm = attn_grouped.permute(1, 0, 2)
    grouped_out = torch.bmm(attn_bmm, oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    print(f"grouped_out: |o|={grouped_flat.abs().max():.4f}")

    F_attn = _nvfp4_project(grouped_flat, all_w, f"{attn}.o_b_proj")
    print(f"F_attn: |F|={F_attn.abs().max():.4f} mean={F_attn.float().abs().mean():.6f}")

    # mHC post_block
    X_mid = attn_mhc.post_block(X, F_attn, ctx)
    print(f"X_mid: |X|={X_mid.abs().max():.4f}")

    print("\nLayer 0 attention sub-block complete.")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the layer-0 trace only when executed as a script, not on import.
    main()
|