541 lines
23 KiB
Python
541 lines
23 KiB
Python
#!/usr/bin/env python3
|
|
"""Minimal end-to-end test: process "The" through DSV4-Pro, verify logits.
|
|
|
|
Tests:
|
|
1. RoPE → inverse RoPE round-trip (exact in FP32 arithmetic; the BF16 tensor stored between the two passes adds a small, expected error)
|
|
2. Single token through layer 0 (shapes, finiteness, reasonable magnitudes)
|
|
3. Full model logits for "The" (finite, not degenerate)
|
|
|
|
Usage (on B200):
|
|
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
|
cd /root/dsv4-nvfp4-workspace/kernel
|
|
python3 tests/test_minimal_e2e.py
|
|
"""
|
|
import os, sys, math, json
|
|
import torch
|
|
from pathlib import Path
|
|
|
|
# Put the kernel root (the directory above tests/) first on sys.path so that
# project modules such as `dsv4` and `single_shot_inference` import cleanly
# when this file is run as a script.
sys.path.insert(
    0,
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
)

# Checkpoint location and the number of GPUs layers are spread across.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
NUM_GPUS = 8

# =====================================================================
# Shared helpers
# =====================================================================

# Magnitudes of the 8 non-negative FP4 (E2M1) codes, indexed by the low 3
# bits of a nibble; bit 3 of the nibble carries the sign.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
|
|
|
def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Expand a packed NVFP4 weight matrix into a dense BF16 matrix.

    Each byte of ``weight`` holds two 4-bit codes. The low nibble becomes
    the even output column and the high nibble the odd one. Within a
    nibble, bit 3 is the sign and bits 0-2 index ``FP4_LUT``.

    Args:
        weight: packed codes; row count is the output dim and each row holds
            half as many bytes as there are input features.
        weight_scale: per-block scales; each entry is repeated over 16
            consecutive input columns.
        weight_scale_2: second-level scale multiplied into ``weight_scale``.

    Returns:
        BF16 tensor of shape ``(weight.shape[0], 2 * weight.shape[1])``.
    """
    rows = weight.shape[0]
    cols = 2 * weight.shape[1]
    table = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def _decode(nibbles):
        # Signed FP32 value of each 4-bit code.
        codes = nibbles.to(torch.int8)
        magnitude = table[(codes & 0x07).long()]
        sign = torch.where((codes >> 3).bool(), -1.0, 1.0)
        return magnitude * sign

    even_cols = _decode(weight & 0x0F)
    odd_cols = _decode(weight >> 4)
    dense = torch.stack([even_cols, odd_cols], dim=-1).reshape(rows, cols)

    block_scales = weight_scale.float() * weight_scale_2.float()
    per_column = block_scales.repeat_interleave(16, dim=1)
    return (dense * per_column).bfloat16()
|
|
|
|
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Bias-free linear layer ``x @ W.T`` with ``W`` stored as NVFP4.

    The weight is dequantized to BF16 on every call via
    ``dequant_nvfp4_weight``; nothing is cached.
    """
    return torch.nn.functional.linear(
        x,
        dequant_nvfp4_weight(weight, weight_scale, weight_scale_2),
    )
|
|
|
|
class RMSNorm:
    """Root-mean-square normalization over the last dimension.

    ``weight`` is an FP32 per-channel scale, initialised to ones; callers in
    this file overwrite it with checkpoint values. ``forward`` computes in
    FP32 and returns BF16.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)
        self.eps = eps

    def forward(self, x):
        """Return ``x * rsqrt(mean(x**2, -1) + eps) * weight`` as BF16."""
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        scaled = x32 * inv_rms * self.weight
        return scaled.to(torch.bfloat16)
|
|
|
|
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0):
    """Build FP32 cos/sin caches for RoPE.

    Row ``p`` holds ``cos``/``sin`` of ``p * theta ** (-2i / rope_dim)`` for
    frequency index ``i``.

    CRITICAL: Must be FP32, not BF16! BF16 quantization destroys the
    cos²+sin²=1 identity needed for the inverse RoPE round-trip. BF16
    cos²+sin² can be as low as 0.996, causing ~3% error.

    Args:
        max_pos: number of positions (rows) in each table.
        rope_dim: number of rotated channels; must be even because channels
            are rotated in (even, odd) pairs.
        device: device the returned tables are moved to.
        theta: RoPE base frequency.

    Returns:
        ``(cos, sin)``, each FP32 of shape ``(max_pos, rope_dim // 2)``.

    Raises:
        ValueError: if ``rope_dim`` is odd.
    """
    if rope_dim % 2:
        raise ValueError(f"rope_dim must be even, got {rope_dim}")
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    # Computed first, then moved, so the FP32 values do not depend on `device`.
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
|
|
|
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply partial GPT-J interleaved RoPE to the last ``rope_dim`` channels.

    The leading ``hd - rope_dim`` channels of every head are copied through
    unchanged. The trailing ``rope_dim`` channels are rotated as
    (even, odd) pairs. The rotation is computed in FP32 and cast back to
    ``x.dtype``, so FP32 inputs are not quantized and BF16 inputs behave as
    before.

    Args:
        x: tensor of shape ``(T, n_heads, hd)``; not modified.
        positions: integer tensor of shape ``(T,)`` indexing the caches.
        cos_cache, sin_cache: tables of shape ``(max_pos, rope_dim // 2)``.
        head_dim: unused (the width is taken from ``x.shape[-1]``); kept for
            call-site compatibility.
        rope_dim: number of trailing channels to rotate.

    Returns:
        New tensor with the same shape and dtype as ``x``.
    """
    _, _, hd = x.shape  # also enforces a rank-3 input
    nope = hd - rope_dim
    cos = cos_cache[positions].unsqueeze(1)  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)

    x_rope = x[:, :, nope:].float()
    x_even, x_odd = x_rope[..., 0::2], x_rope[..., 1::2]
    # stack(..., dim=-1).flatten(-2) re-interleaves as [e0, o0, e1, o1, ...].
    rotated = torch.stack(
        (x_even * cos - x_odd * sin, x_even * sin + x_odd * cos), dim=-1
    ).flatten(-2)

    result = x.clone()
    result[:, :, nope:] = rotated.to(x.dtype)
    return result
|
|
|
|
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply inverse (conjugate) partial RoPE to the last ``rope_dim`` channels.

    This undoes ``apply_rope_partial`` by rotating each (even, odd) channel
    pair by the negative angle. The leading ``hd - rope_dim`` channels pass
    through unchanged. The math is done in FP32 and cast back to
    ``o.dtype``, so FP32 inputs are not quantized.

    Args:
        o: tensor of shape ``(T, n_heads, hd)``; not modified.
        positions: integer tensor of shape ``(T,)`` indexing the caches.
        cos_cache, sin_cache: tables of shape ``(max_pos, rope_dim // 2)``.
        head_dim: unused (the width is taken from ``o.shape[-1]``); kept for
            call-site compatibility.
        rope_dim: number of trailing channels to rotate back.

    Returns:
        New tensor with the same shape and dtype as ``o``.
    """
    _, _, hd = o.shape  # also enforces a rank-3 input
    nope = hd - rope_dim
    cos = cos_cache[positions].unsqueeze(1)  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)

    o_rope = o[:, :, nope:].float()
    o_even, o_odd = o_rope[..., 0::2], o_rope[..., 1::2]
    # stack(..., dim=-1).flatten(-2) re-interleaves as [e0, o0, e1, o1, ...].
    restored = torch.stack(
        (o_even * cos + o_odd * sin, -o_even * sin + o_odd * cos), dim=-1
    ).flatten(-2)

    result = o.clone()
    result[:, :, nope:] = restored.to(o.dtype)
    return result
|
|
|
|
def load_weights_to_cpu(checkpoint_dir):
    """Load every safetensors shard of a checkpoint into a single CPU dict.

    Shard names come from ``model.safetensors.index.json`` (its
    ``weight_map``) when that file exists and is non-empty. Otherwise every
    ``*.safetensors`` file in the directory is loaded. Shards are read in
    sorted-name order; a key present in several shards keeps the value from
    the last one read.

    Args:
        checkpoint_dir: path to the checkpoint directory.

    Returns:
        dict mapping tensor name -> CPU tensor.

    Raises:
        FileNotFoundError: if a shard named by the index is missing, or no
            shard files are found at all. Failing here gives a clear message;
            silently skipping shards only surfaced later as missing-key errors.
    """
    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"

    weight_map = {}
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})

    if weight_map:
        shard_paths = [cdir / name for name in sorted(set(weight_map.values()))]
        missing = [p.name for p in shard_paths if not p.exists()]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} shard(s) listed in {index_path} are missing, "
                f"e.g. {missing[0]}")
    else:
        # No usable index: load whatever shards are present on disk.
        shard_paths = sorted(cdir.glob("*.safetensors"))

    if not shard_paths:
        raise FileNotFoundError(f"no .safetensors shards found in {cdir}")

    # Imported only once the shard list is validated, so path errors are
    # reported before the dependency is needed.
    from safetensors.torch import load_file

    all_weights = {}
    for path in shard_paths:
        all_weights.update(load_file(str(path)))
    return all_weights
|
|
|
|
def get_layer_weights(all_weights, li, device):
    """Copy every tensor belonging to decoder layer ``li`` onto ``device``.

    Keys are matched on the prefix ``model.layers.{li}.``. The trailing dot
    stops layer 1 from also matching layers 10, 11, ... .
    Insertion order of ``all_weights`` is preserved.
    """
    wanted = f"model.layers.{li}."
    selected = {}
    for name, tensor in all_weights.items():
        if not name.startswith(wanted):
            continue
        selected[name] = tensor.to(device=device, non_blocking=True)
    return selected
|
|
|
|
# =====================================================================
|
|
# Test 1: RoPE round-trip
|
|
# =====================================================================
|
|
def test_rope_roundtrip():
    """Check that inverse RoPE undoes RoPE on random BF16 activations.

    Runs on cuda:0 with 128 heads of width 512, whose last 64 channels are
    rotated, at positions 0, 1, 10 and 100. A position passes when the max
    absolute reconstruction error is below 0.05.

    Returns:
        True when every position passes.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Test 1: RoPE → inverse RoPE round-trip")
    print(banner)

    dev = 'cuda:0'
    head_dim, rope_dim, num_heads = 512, 64, 128
    cos_tab, sin_tab = build_rope_cache(8192, rope_dim, dev)

    outcomes = []
    for p in (0, 1, 10, 100):
        # Re-seeded every iteration, so each position sees the same input.
        torch.manual_seed(42)
        sample = torch.randn(1, num_heads, head_dim, dtype=torch.bfloat16, device=dev)
        where = torch.tensor([p], dtype=torch.long, device=dev)

        roped = apply_rope_partial(sample, where, cos_tab, sin_tab, head_dim, rope_dim)
        restored = apply_inverse_rope(roped, where, cos_tab, sin_tab, head_dim, rope_dim)

        # The rotations run in FP32, but `roped` is stored as BF16 between
        # the two passes, so a small error (~1e-2) is expected.
        # 0.05 is an absolute bound on the max error, not a percentage.
        max_err = (sample.float() - restored.float()).abs().max().item()
        passed = max_err < 0.05
        outcomes.append(passed)
        print(f"  pos={p:4d}: max_diff={max_err:.2e}  {'✅' if passed else '❌'}")

    ok = all(outcomes)
    print("  Note: BF16 round-trip error of ~1-2% is EXPECTED (not a bug)")
    print(f"  Result: {'✅ PASS' if ok else '❌ FAIL'}")
    return ok
|
|
|
|
|
|
# =====================================================================
|
|
# Test 2: Single token through layer 0
|
|
# =====================================================================
|
|
def test_layer0():
    """Push the token "The" through decoder layer 0 and check for NaN/Inf.

    Everything runs on cuda:0 in this order:
      1. Embedding and mHC state initialisation.
      2. Attention: NVFP4 q_a/q_b and kv projections with optional RMS norms,
         partial RoPE, single-token SDPA with optional attention sinks,
         inverse RoPE, then the grouped o_a BMM and NVFP4 o_b.
      3. mHC post-block.
      4. FFN using only the shared expert (routed experts are skipped).
      5. mHC post-block.
    Intermediate magnitudes are printed for manual inspection.

    Returns:
        True when the layer-0 output state is free of NaN and Inf.
    """
    print("\n" + "="*60)
    print("Test 2: Single token through layer 0")
    print("="*60)
    device = 'cuda:0'
    li = 0

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)

    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    H = cfg["hidden_size"]
    eps = cfg.get('rms_norm_eps', 1e-6)
    n_hc = 4  # number of mHC streams (hard-coded, not read from config)
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    print(f"  Config: {n_h} heads, hd={hd}, rope_dim={rd}, H={H}, "
          f"{o_groups} groups, o_rank={o_rank}")

    from dsv4.layers.mhc import mHCLayer
    from single_shot_inference import mHCBlock
    from transformers import AutoTokenizer

    def proj(w, x, prefix):
        """NVFP4 linear using ``<prefix>.weight`` / ``.weight_scale`` / ``.weight_scale_2``."""
        return nvfp4_linear(x, w[f"{prefix}.weight"], w[f"{prefix}.weight_scale"],
                            w[f"{prefix}.weight_scale_2"])

    def rms_norm(x, weight):
        """RMS-normalize the last dim in FP32 (config eps), scale, return BF16."""
        xf = x.float()
        return (xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
                * weight.float()).bfloat16()

    def load_mhc(kind):
        """Build layer ``li``'s mHCBlock from its ``<kind>.fn/base/scale`` tensors."""
        block = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
        block.load_from_checkpoint(
            all_weights[f"model.layers.{li}.{kind}.fn"],
            all_weights[f"model.layers.{li}.{kind}.base"],
            all_weights[f"model.layers.{li}.{kind}.scale"])
        return block

    def load_norm(name):
        """Build an RMSNorm holding layer ``li``'s ``<name>.weight`` in FP32."""
        norm = RMSNorm(H, eps=eps, device=device)
        norm.weight = all_weights[f"model.layers.{li}.{name}.weight"].to(
            device=device, dtype=torch.float32)
        return norm

    print("  Loading weights...")
    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
    w = get_layer_weights(all_weights, li, device)

    rope_cos, rope_sin = build_rope_cache(
        8192, rd, device, theta=cfg.get("rope_theta", 10000.0))
    embed = torch.nn.Embedding.from_pretrained(
        all_weights["model.embed_tokens.weight"].bfloat16().to(device))

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    # Without add_special_tokens=False the tokenizer may prepend BOS, and
    # input_ids[0] would then be BOS rather than "The".
    input_ids = tokenizer.encode("The", add_special_tokens=False)
    if not input_ids:
        raise RuntimeError("tokenizer returned no ids for 'The'")
    tid = torch.tensor([input_ids[0]], dtype=torch.long, device=device)
    print(f"  Token: {tid.item()} = '{tokenizer.decode([tid.item()])}'")

    emb = embed(tid)
    print(f"  Embedding: |emb|={emb.float().abs().max():.3f}")

    # mHC init
    X = mHCLayer.init_state(emb, n_hc)
    print(f"  mHC state: |X|={X.float().abs().max():.3f}")

    attn_mhc, ffn_mhc = load_mhc("attn_hc"), load_mhc("ffn_hc")
    attn_norm = load_norm("input_layernorm")
    ffn_norm = load_norm("post_attention_layernorm")

    # === ATTENTION ===
    x_in, attn_ctx = attn_mhc.pre_block(X)
    x_normed = attn_norm.forward(x_in)
    pre = f"model.layers.{li}.self_attn"

    # Q: q_a -> optional q_a_norm -> q_b
    c_Q = proj(w, x_normed, f"{pre}.q_a_proj")
    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = rms_norm(c_Q, q_norm_w)
    q = proj(w, c_Q, f"{pre}.q_b_proj")
    print(f"  Q: shape={q.shape} |Q|={q.float().abs().max():.3f}")

    # KV: a single hd-wide vector per token, shared by all heads as K and V
    kv = proj(w, x_normed, f"{pre}.kv_proj")
    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = rms_norm(kv, kv_norm_w)
    print(f"  KV: shape={kv.shape} |KV|={kv.float().abs().max():.3f}")

    # Reshape + RoPE
    positions = torch.tensor([0], dtype=torch.long, device=device)
    q_roped = apply_rope_partial(q.reshape(1, n_h, hd), positions, rope_cos, rope_sin, hd, rd)
    kv_roped = apply_rope_partial(kv.reshape(1, 1, hd), positions, rope_cos, rope_sin, hd, rd)

    # SDPA with heads as the batch dim: (n_h, 1, hd) queries over shared KV
    kv_heads = kv_roped.expand(n_h, -1, -1).contiguous()
    q_input = q_roped.permute(1, 0, 2)
    scale = 1.0 / math.sqrt(hd)
    sdpa = torch.nn.functional.scaled_dot_product_attention

    sinks = w.get(f"{pre}.sinks")
    if sinks is not None:
        # Sink: an extra all-zero key/value slot whose logit is the per-head
        # sink value, injected through the additive float attention mask.
        zero_slot = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
        kv_ws = torch.cat([kv_heads, zero_slot], dim=1)  # (n_h, 2, hd)
        sink_bias = torch.zeros(n_h, 1, 2, dtype=torch.bfloat16, device=device)
        # assumes one sink scalar per head — TODO confirm checkpoint shape
        sink_bias[:, 0, -1] = sinks.reshape(n_h)
        attn_out = sdpa(q_input, kv_ws, kv_ws, attn_mask=sink_bias, scale=scale)
        print(f"  SDPA (with sinks): |out|={attn_out.float().abs().max():.3f}")
    else:
        attn_out = sdpa(q_input, kv_heads, kv_heads, scale=scale, is_causal=False)
        print(f"  SDPA (no sinks): |out|={attn_out.float().abs().max():.3f}")

    attn_out = attn_out.permute(1, 0, 2)  # (1, n_h, hd)

    # Inverse RoPE
    attn_inv = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
    print(f"  After inverse RoPE: |out|={attn_inv.float().abs().max():.3f}")

    # Output projection: wo_a (grouped BMM) + wo_b (NVFP4)
    attn_grouped = attn_inv.reshape(1, o_groups, group_input_dim)
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
    F_attn = proj(w, grouped_flat, f"{pre}.o_b_proj")
    print(f"  F_attn: shape={F_attn.shape} |F_attn|={F_attn.float().abs().max():.3f}")

    X_mid = attn_mhc.post_block(X, F_attn, attn_ctx)
    print(f"  X_mid: |X_mid|={X_mid.float().abs().max():.3f}")

    # === FFN (shared expert only) ===
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
    x_ffn_normed = ffn_norm.forward(x_ffn)
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    gate = proj(w, x_ffn_normed, f"{se_pre}.gate_proj")
    up = proj(w, x_ffn_normed, f"{se_pre}.up_proj")
    hidden = (torch.nn.functional.silu(gate.float()).clamp(-10, 10)
              * up.float().clamp(-10, 10)).bfloat16()
    shared_out = proj(w, hidden, f"{se_pre}.down_proj")
    print(f"  Shared expert: |out|={shared_out.float().abs().max():.3f}")

    X_next = ffn_mhc.post_block(X_mid, shared_out, ffn_ctx)
    has_nan = torch.isnan(X_next).any().item()
    has_inf = torch.isinf(X_next).any().item()
    ok = not has_nan and not has_inf
    print(f"  X_next: |X_next|={X_next.float().abs().max():.3f} nan={has_nan} inf={has_inf}")
    print(f"  Result: {'✅ PASS' if ok else '❌ FAIL'}")

    del w, all_weights
    torch.cuda.empty_cache()
    return ok
|
|
|
|
|
|
# =====================================================================
|
|
# Test 3: Full model logits
|
|
# =====================================================================
|
|
def test_full_logits():
    """Run the token "The" through every layer and sanity-check the logits.

    Layer ``li`` runs on ``cuda:{li % NUM_GPUS}``. Its weights are copied to
    that GPU just before use and released afterwards. Only each layer's
    shared expert is evaluated (routed experts are skipped for speed), so
    the logits are a smoke test rather than reference values.

    Returns:
        True when the logits are finite and their std exceeds 0.5
        (i.e. not degenerate).
    """
    print("\n" + "="*60)
    print("Test 3: Full model logits for 'The'")
    print("="*60)
    device = 'cuda:0'

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)

    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    eps = cfg.get('rms_norm_eps', 1e-6)
    rope_theta = cfg.get("rope_theta", 10000.0)
    n_hc = 4  # number of mHC streams (hard-coded, not read from config)
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd
    scale = 1.0 / math.sqrt(hd)
    sdpa = torch.nn.functional.scaled_dot_product_attention

    from dsv4.layers.mhc import mHCLayer
    from single_shot_inference import mHCBlock
    from transformers import AutoTokenizer

    def proj(w, x, prefix):
        """NVFP4 linear using ``<prefix>.weight`` / ``.weight_scale`` / ``.weight_scale_2``."""
        return nvfp4_linear(x, w[f"{prefix}.weight"], w[f"{prefix}.weight_scale"],
                            w[f"{prefix}.weight_scale_2"])

    def rms_norm(x, weight):
        """RMS-normalize the last dim in FP32 (config eps), scale, return BF16."""
        xf = x.float()
        return (xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
                * weight.float()).bfloat16()

    def load_mhc(li, kind, dev):
        """Build layer ``li``'s mHCBlock from its ``<kind>.fn/base/scale`` tensors."""
        block = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
        block.load_from_checkpoint(
            all_weights[f"model.layers.{li}.{kind}.fn"],
            all_weights[f"model.layers.{li}.{kind}.base"],
            all_weights[f"model.layers.{li}.{kind}.scale"])
        return block

    def load_norm(li, name, dev):
        """Build an RMSNorm holding layer ``li``'s ``<name>.weight`` in FP32."""
        norm = RMSNorm(H, eps=eps, device=dev)
        norm.weight = all_weights[f"model.layers.{li}.{name}.weight"].to(
            device=dev, dtype=torch.float32)
        return norm

    print("  Loading weights to CPU...")
    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)

    embed_w = all_weights["model.embed_tokens.weight"]
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(device))
    # Falls back to the embedding matrix when there is no lm_head.weight.
    lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to(device)
    final_norm_w = all_weights.get("model.norm.weight")
    if final_norm_w is not None:
        final_norm_w = final_norm_w.to(device)

    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    # Without add_special_tokens=False the tokenizer may prepend BOS, and
    # input_ids[0] would then be BOS rather than "The".
    input_ids = tokenizer.encode("The", add_special_tokens=False)
    if not input_ids:
        raise RuntimeError("tokenizer returned no ids for 'The'")
    tid = torch.tensor([input_ids[0]], dtype=torch.long, device=device)
    positions = torch.tensor([0], dtype=torch.long, device=device)
    print(f"  Token: {tid.item()} = '{tokenizer.decode([tid.item()])}'")

    X = mHCLayer.init_state(embed(tid), n_hc)

    rope_caches = {}  # device -> (cos, sin); built once per GPU, not per layer

    for li in range(n_layers):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        X = X.to(dev)
        torch.cuda.set_device(gpu)

        w = get_layer_weights(all_weights, li, dev)
        if dev not in rope_caches:
            rope_caches[dev] = build_rope_cache(8192, rd, dev, theta=rope_theta)
        rope_cos, rope_sin = rope_caches[dev]
        positions_dev = positions.to(dev)

        # Build mHC + norms
        attn_mhc = load_mhc(li, "attn_hc", dev)
        ffn_mhc = load_mhc(li, "ffn_hc", dev)
        attn_norm = load_norm(li, "input_layernorm", dev)
        ffn_norm = load_norm(li, "post_attention_layernorm", dev)

        # ATTENTION
        x_in, attn_ctx = attn_mhc.pre_block(X)
        x_normed = attn_norm.forward(x_in)
        pre = f"model.layers.{li}.self_attn"

        c_Q = proj(w, x_normed, f"{pre}.q_a_proj")
        q_norm_w = w.get(f"{pre}.q_a_norm.weight")
        if q_norm_w is not None:
            c_Q = rms_norm(c_Q, q_norm_w)
        q = proj(w, c_Q, f"{pre}.q_b_proj")

        kv = proj(w, x_normed, f"{pre}.kv_proj")
        kv_norm_w = w.get(f"{pre}.kv_norm.weight")
        if kv_norm_w is not None:
            kv = rms_norm(kv, kv_norm_w)

        q_roped = apply_rope_partial(q.reshape(1, n_h, hd), positions_dev,
                                     rope_cos, rope_sin, hd, rd)
        kv_roped = apply_rope_partial(kv.reshape(1, 1, hd), positions_dev,
                                      rope_cos, rope_sin, hd, rd)

        # Heads as the batch dim; the single KV vector serves as both K and V.
        kv_heads = kv_roped.expand(n_h, -1, -1).contiguous()
        q_input = q_roped.permute(1, 0, 2)

        sinks = w.get(f"{pre}.sinks")
        if sinks is not None:
            # Extra all-zero key/value slot whose logit is the sink value.
            zero_slot = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=dev)
            kv_ws = torch.cat([kv_heads, zero_slot], dim=1)
            sink_bias = torch.zeros(n_h, 1, 2, dtype=torch.bfloat16, device=dev)
            # assumes one sink scalar per head — TODO confirm checkpoint shape
            sink_bias[:, 0, -1] = sinks.reshape(n_h)
            attn_out = sdpa(q_input, kv_ws, kv_ws, attn_mask=sink_bias, scale=scale)
        else:
            attn_out = sdpa(q_input, kv_heads, kv_heads, scale=scale, is_causal=False)
        attn_out = attn_out.permute(1, 0, 2)

        attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)

        # Output projection
        attn_grouped = attn_out.reshape(1, o_groups, group_input_dim)
        oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
        grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
        grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
        F_attn = proj(w, grouped_flat, f"{pre}.o_b_proj")

        X_mid = attn_mhc.post_block(X, F_attn, attn_ctx)

        # FFN (shared expert only for speed)
        x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
        x_ffn_normed = ffn_norm.forward(x_ffn)
        se_pre = f"model.layers.{li}.mlp.shared_experts"
        gate = proj(w, x_ffn_normed, f"{se_pre}.gate_proj")
        up = proj(w, x_ffn_normed, f"{se_pre}.up_proj")
        hidden = (torch.nn.functional.silu(gate.float()).clamp(-10, 10)
                  * up.float().clamp(-10, 10)).bfloat16()
        shared_out = proj(w, hidden, f"{se_pre}.down_proj")
        X = ffn_mhc.post_block(X_mid, shared_out, ffn_ctx)

        if li % 10 == 0 or li == n_layers - 1:
            print(f"  L{li:2d}: |X|={X.float().abs().max():.3f}")

        del w
        torch.cuda.empty_cache()

    # Logits
    X = X.to(device)
    # NOTE(review): reads only mHC stream 0 as the final hidden state
    # (assumes X is (T, n_hc, H)); confirm against the reference output head.
    x_out = X[:, 0, :]
    if final_norm_w is not None:
        x_out = rms_norm(x_out, final_norm_w)

    logits = torch.nn.functional.linear(x_out, lm_w)
    logits_f = logits.float()
    has_nan = torch.isnan(logits).any().item()
    has_inf = torch.isinf(logits).any().item()
    lmin, lmax = logits_f.min().item(), logits_f.max().item()
    lmean = logits_f.mean().item()
    lstd = logits_f.std().item()

    print(f"\n  Logits: min={lmin:.3f} max={lmax:.3f} mean={lmean:.3f} std={lstd:.3f}")
    print(f"  nan={has_nan} inf={has_inf}")

    finite = not has_nan and not has_inf
    # Degeneracy check: near-uniform logits (tiny std) count as a failure.
    spread_ok = lstd > 0.5
    if finite:
        top5_vals, top5_ids = torch.topk(logits[0], 5)
        top5_str = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})'
                            for t, v in zip(top5_ids, top5_vals))
        print(f"  Top-5: {top5_str}")
        print(f"  Logit spread: {'✅' if spread_ok else '❌'} (std={lstd:.3f})")

    ok = finite and spread_ok
    print(f"  Result: {'✅ PASS' if ok else '❌ FAIL'}")
    return ok
|
|
|
|
|
|
# =====================================================================
|
|
# Main
|
|
# =====================================================================
|
|
if __name__ == "__main__":
    import traceback

    print("DSV4 Minimal End-to-End Test")
    print("="*60)

    # Cheapest first: RoPE needs no weights, layer 0 loads the checkpoint
    # once, and the full model walks every layer across NUM_GPUS devices.
    tests = [
        ("rope_roundtrip", test_rope_roundtrip),
        ("layer0", test_layer0),
        ("full_logits", test_full_logits),
    ]
    results = {}
    for name, test_fn in tests:
        try:
            results[name] = bool(test_fn())
        except Exception:
            # Top-level boundary: report the crash as a failure and keep
            # going so the summary still covers every test.
            traceback.print_exc()
            results[name] = False

    print("\n" + "="*60)
    print("SUMMARY")
    print("="*60)
    for name, passed in results.items():
        print(f"  {name}: {'✅ PASS' if passed else '❌ FAIL'}")
    all_pass = all(results.values())
    print(f"\n  Overall: {'✅ ALL PASS' if all_pass else '❌ SOME FAILED'}")
    # Non-zero exit status so CI and shell callers can detect failures.
    sys.exit(0 if all_pass else 1)
|