Add HF reference test script
This commit is contained in:
28
hf_reference_test.py
Normal file
28
hf_reference_test.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = '/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4'
|
||||
print('Loading tokenizer...', flush=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
print('Loading model...', flush=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name, torch_dtype=torch.bfloat16,
|
||||
device_map='auto', trust_remote_code=True, low_cpu_mem_usage=True
|
||||
)
|
||||
model.eval()
|
||||
print('Model loaded!', flush=True)
|
||||
|
||||
msg = [{'role':'user','content':'The capital of France is'}]
|
||||
ids = tokenizer.apply_chat_template(msg, add_generation_prompt=True, return_tensors='pt').cuda()
|
||||
print(f'Input: {ids.shape} tokens: {repr(tokenizer.decode(ids[0]))}', flush=True)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(ids).logits[0, -1]
|
||||
top10 = torch.topk(logits, 10)
|
||||
print('HF Top-10:', flush=True)
|
||||
for i, (tid, val) in enumerate(zip(top10.indices, top10.values)):
|
||||
print(f' {i+1}. {repr(tokenizer.decode([tid.item()]))} (id={tid.item()}, logit={val.item():.3f})', flush=True)
|
||||
|
||||
# Generate 10 tokens
|
||||
out = model.generate(ids, max_new_tokens=10, do_sample=False)
|
||||
print(f'Generated: {repr(tokenizer.decode(out[0]))}', flush=True)
|
||||
250
tests/compare_layer0.py
Normal file
250
tests/compare_layer0.py
Normal file
@@ -0,0 +1,250 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Compare our single-shot inference with the HuggingFace reference for layer 0.
|
||||
|
||||
This script processes a single token through just layer 0 and compares the output
|
||||
with a pure PyTorch reference implementation that matches the HF model exactly.
|
||||
|
||||
Usage (on B200):
|
||||
python3 tests/compare_layer0.py
|
||||
"""
|
||||
import os, sys, json, math, torch
|
||||
from pathlib import Path
|
||||
|
||||
# Add kernel to path
|
||||
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel')
|
||||
|
||||
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
def load_weights():
|
||||
from safetensors.torch import load_file
|
||||
cdir = Path(CHECKPOINT_DIR)
|
||||
index_path = cdir / "model.safetensors.index.json"
|
||||
weight_map = {}
|
||||
if index_path.exists():
|
||||
with open(index_path) as f:
|
||||
weight_map = json.load(f).get("weight_map", {})
|
||||
shard_names = set(weight_map.values()) if weight_map else {
|
||||
f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
|
||||
}
|
||||
all_w = {}
|
||||
for shard_name in sorted(shard_names):
|
||||
if not (cdir / shard_name).exists():
|
||||
continue
|
||||
data = load_file(str(cdir / shard_name))
|
||||
for k, v in data.items():
|
||||
if k.startswith("model.layers.0.") or k in ["model.embed_tokens.weight", "model.norm.weight", "lm_head.weight"]:
|
||||
all_w[k] = v
|
||||
return all_w
|
||||
|
||||
# =====================================================================
|
||||
# FP4 dequant
|
||||
# =====================================================================
|
||||
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
||||
|
||||
def dequant_nvfp4(weight, weight_scale, weight_scale_2):
|
||||
out_dim = weight.shape[0]
|
||||
in_packed = weight.shape[1]
|
||||
in_features = in_packed * 2
|
||||
low = (weight & 0x0F).to(torch.int8)
|
||||
high = (weight >> 4).to(torch.int8)
|
||||
low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long()
|
||||
high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long()
|
||||
lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
|
||||
low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0)
|
||||
high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)
|
||||
w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)
|
||||
scale_f = weight_scale.float() * weight_scale_2.float()
|
||||
scale_expanded = scale_f.repeat_interleave(16, dim=1)
|
||||
return (w_f * scale_expanded).bfloat16()
|
||||
|
||||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
|
||||
w = dequant_nvfp4(weight, weight_scale, weight_scale_2)
|
||||
return torch.nn.functional.linear(x, w)
|
||||
|
||||
# =====================================================================
|
||||
# Reference: pure PyTorch layer 0
|
||||
# =====================================================================
|
||||
def reference_layer0(embedding, w, cfg):
|
||||
"""Process one token through layer 0 using pure PyTorch (matching HF)."""
|
||||
li = 0
|
||||
pre = f"model.layers.{li}.self_attn"
|
||||
n_h = cfg["num_attention_heads"] # 128
|
||||
hd = cfg["head_dim"] # 512
|
||||
rd = cfg.get("qk_rope_head_dim", 64) # 64
|
||||
H = cfg["hidden_size"] # 7168
|
||||
o_groups = cfg.get("o_groups", 16)
|
||||
o_rank = cfg.get("o_group_dim", 1024)
|
||||
n_hc = 4
|
||||
heads_per_group = n_h // o_groups
|
||||
|
||||
# Init mHC state
|
||||
X = embedding.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, 4, H)
|
||||
|
||||
# ============ mHC (attention) ============
|
||||
# Match HF DeepseekV4HyperConnection.forward
|
||||
fn = w[f"model.layers.{li}.attn_hc.fn"] # (24, 28672)
|
||||
base = w[f"model.layers.{li}.attn_hc.base"] # (24,)
|
||||
scale = w[f"model.layers.{li}.attn_hc.scale"] # (3,)
|
||||
|
||||
# Unweighted RMSNorm on flattened residual
|
||||
X_flat = X.reshape(1, n_hc * H).float()
|
||||
rms_inv = X_flat.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
flat = X_flat * rms_inv
|
||||
|
||||
# F.linear(flat, fn) and split [pre(4), post(4), comb(16)]
|
||||
proj = torch.nn.functional.linear(flat.to(torch.bfloat16), fn.float().to(DEVICE)).float()
|
||||
pre_w, post_w, comb_w = proj.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
|
||||
|
||||
# Apply scale and bias
|
||||
pre_b, post_b, comb_b = base.split([n_hc, n_hc, n_hc * n_hc])
|
||||
pre_scale, post_scale, comb_scale = scale.unbind(0)
|
||||
|
||||
pre_vals = torch.sigmoid(pre_w * pre_scale + pre_b) + 1e-6 # A_l
|
||||
post_vals = 2.0 * torch.sigmoid(post_w * post_scale + post_b) # C_l
|
||||
|
||||
# Sinkhorn on comb
|
||||
comb_logits = (comb_w * comb_scale + comb_b).reshape(1, n_hc, n_hc)
|
||||
comb = torch.softmax(comb_logits, dim=-1) + 1e-6
|
||||
comb = comb / (comb.sum(dim=-2, keepdim=True) + 1e-6)
|
||||
for _ in range(19): # 20 total
|
||||
comb = comb / (comb.sum(dim=-1, keepdim=True) + 1e-6)
|
||||
comb = comb / (comb.sum(dim=-2, keepdim=True) + 1e-6)
|
||||
|
||||
# collapsed = (pre * streams).sum(dim=streams)
|
||||
x_in = (pre_vals.unsqueeze(-1) * X.float()).sum(dim=1).to(torch.bfloat16) # (1, H)
|
||||
B_l = comb # (1, 4, 4)
|
||||
C_l = post_vals # (1, 4)
|
||||
|
||||
print(f" A_l: {pre_vals[0].tolist()}")
|
||||
print(f" C_l: {C_l[0].tolist()}")
|
||||
print(f" B row sums: {B_l[0].sum(dim=-1).tolist()}")
|
||||
print(f" B col sums: {B_l[0].sum(dim=-2).tolist()}")
|
||||
|
||||
# ============ RMSNorm ============
|
||||
x_normed = x_in.float()
|
||||
rms_inv = x_normed.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
norm_w = w[f"model.layers.{li}.input_layernorm.weight"].to(DEVICE).float()
|
||||
x_normed = (x_normed * rms_inv * norm_w).to(torch.bfloat16)
|
||||
|
||||
# ============ Q projection ============
|
||||
c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"])
|
||||
# q_a_norm (weighted RMSNorm)
|
||||
q_norm_w = w[f"{pre}.q_a_norm.weight"].to(DEVICE).float()
|
||||
c_Q_f = c_Q.float()
|
||||
c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
c_Q = (c_Q_f * c_Q_rms * q_norm_w).bfloat16()
|
||||
|
||||
q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"])
|
||||
# q_b_norm (unweighted RMSNorm)
|
||||
q_f = q.float()
|
||||
q_rms = q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
q = (q_f * q_rms).bfloat16()
|
||||
|
||||
# ============ KV projection ============
|
||||
kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"])
|
||||
kv_norm_w = w[f"{pre}.kv_norm.weight"].to(DEVICE).float()
|
||||
kv_f = kv.float()
|
||||
kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
kv = (kv_f * kv_rms * kv_norm_w).bfloat16()
|
||||
|
||||
print(f" |c_Q|={c_Q.abs().max().item():.4f} |q|={q.abs().max().item():.4f} |kv|={kv.abs().max().item():.4f}")
|
||||
|
||||
# ============ Attention (self-attention for single token) ============
|
||||
q_heads = q.reshape(1, n_h, hd) # (1, n_h, hd)
|
||||
kv_heads = kv.reshape(1, 1, hd) # (1, 1, hd) — 1 KV head
|
||||
|
||||
# For single token, self-attention is trivially identity (weight=1 on self)
|
||||
# V = K (DSV4 MQA), so attn_out = V = K for single token
|
||||
attn_out = kv_heads.expand(1, n_h, hd) # (1, n_h, hd) — just V
|
||||
|
||||
# Inverse RoPE would be applied here, but for single token with no RoPE (position 0, cos=1, sin=0),
|
||||
# RoPE is identity and inverse RoPE is also identity.
|
||||
|
||||
# ============ Output projection ============
|
||||
attn_flat = attn_out.reshape(1, n_h * hd)
|
||||
attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd)
|
||||
oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16().to(DEVICE)
|
||||
oa_3d = oa_w.reshape(o_groups, o_rank, heads_per_group * hd)
|
||||
attn_for_bmm = attn_grouped.permute(1, 0, 2)
|
||||
grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
|
||||
grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
|
||||
F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"])
|
||||
|
||||
print(f" |F_attn|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}")
|
||||
|
||||
# ============ mHC post_block ============
|
||||
# X_next = C_l * F_attn + B_l.T @ X
|
||||
BX = torch.bmm(B_l.transpose(-1, -2), X.float())
|
||||
CF = C_l.unsqueeze(-1) * F_attn.unsqueeze(1)
|
||||
X_mid = (CF.float() + BX).to(torch.bfloat16)
|
||||
|
||||
print(f" |X_mid|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}")
|
||||
|
||||
# ============ FFN (shared expert only for simplicity) ============
|
||||
# FFN mHC
|
||||
fn_ffn = w[f"model.layers.{li}.ffn_hc.fn"]
|
||||
base_ffn = w[f"model.layers.{li}.ffn_hc.base"]
|
||||
scale_ffn = w[f"model.layers.{li}.ffn_hc.scale"]
|
||||
|
||||
X_flat_ffn = X_mid.reshape(1, n_hc * H).float()
|
||||
rms_inv_ffn = X_flat_ffn.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
flat_ffn = X_flat_ffn * rms_inv_ffn
|
||||
|
||||
proj_ffn = torch.nn.functional.linear(flat_ffn.to(torch.bfloat16), fn_ffn.float().to(DEVICE)).float()
|
||||
pre_w_f, post_w_f, comb_w_f = proj_ffn.split([n_hc, n_hc, n_hc * n_hc], dim=-1)
|
||||
|
||||
pre_b_f, post_b_f, comb_b_f = base_ffn.split([n_hc, n_hc, n_hc * n_hc])
|
||||
pre_s_f, post_s_f, comb_s_f = scale_ffn.unbind(0)
|
||||
|
||||
pre_vals_f = torch.sigmoid(pre_w_f * pre_s_f + pre_b_f) + 1e-6
|
||||
post_vals_f = 2.0 * torch.sigmoid(post_w_f * post_s_f + post_b_f)
|
||||
|
||||
comb_logits_f = (comb_w_f * comb_s_f + comb_b_f).reshape(1, n_hc, n_hc)
|
||||
comb_f = torch.softmax(comb_logits_f, dim=-1) + 1e-6
|
||||
comb_f = comb_f / (comb_f.sum(dim=-2, keepdim=True) + 1e-6)
|
||||
for _ in range(19):
|
||||
comb_f = comb_f / (comb_f.sum(dim=-1, keepdim=True) + 1e-6)
|
||||
comb_f = comb_f / (comb_f.sum(dim=-2, keepdim=True) + 1e-6)
|
||||
|
||||
x_ffn = (pre_vals_f.unsqueeze(-1) * X_mid.float()).sum(dim=1).to(torch.bfloat16)
|
||||
|
||||
# FFN RMSNorm
|
||||
norm_w_ffn = w[f"model.layers.{li}.post_attention_layernorm.weight"].to(DEVICE).float()
|
||||
x_ffn_n = x_ffn.float()
|
||||
rms_ffn = x_ffn_n.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
|
||||
x_ffn_n = (x_ffin_n * rms_ffn * norm_w_ffn).to(torch.bfloat16)
|
||||
|
||||
# Shared expert
|
||||
se_pre = f"model.layers.{li}.mlp.shared_experts"
|
||||
gate = nvfp4_linear(x_ffn_n, w[f"{se_pre}.gate_proj.weight"], w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"])
|
||||
up = nvfp4_linear(x_ffn_n, w[f"{se_pre}.up_proj.weight"], w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"])
|
||||
hidden = (torch.nn.functional.silu(gate.float()) * up.float()).bfloat16()
|
||||
shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"], w[f"{se_pre}.down_proj.weight_scale"], w[f"{se_pre}.down_proj.weight_scale_2"])
|
||||
|
||||
# mHC post (FFN)
|
||||
BX_f = torch.bmm(comb_f.transpose(-1, -2), X_mid.float())
|
||||
CF_f = post_vals_f.unsqueeze(-1) * shared_out.unsqueeze(1)
|
||||
X_next = (CF_f.float() + BX_f).to(torch.bfloat16)
|
||||
|
||||
print(f" |X_next|={X_next.abs().max().item():.4f}")
|
||||
return X_next
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
print("Loading weights...")
|
||||
w = load_weights()
|
||||
|
||||
# Embed "The"
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
|
||||
tid = tokenizer.encode("The")[-1]
|
||||
embed_w = w["model.embed_tokens.weight"].bfloat16().to(DEVICE)
|
||||
embed = torch.nn.functional.embedding(torch.tensor([tid], device=DEVICE), embed_w)
|
||||
|
||||
print(f"\nProcessing 'The' (id={tid}) through layer 0:")
|
||||
X_out = reference_layer0(embed, w, cfg)
|
||||
print(f"\nOutput: |X|={X_out.abs().max().item():.4f}")
|
||||
Reference in New Issue
Block a user