From db74a887ab01633a883c1d35d2c01edf1445c00f Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sun, 31 May 2026 09:14:03 +0000
Subject: [PATCH] Add minimal e2e test + fix MoE expert loop bug (indentation)

single_shot_inference.py: in moe_forward(), re-indent the per-expert
gate/up projections, the clamped SwiGLU and the down projection (plus
expert_outputs.append) so that, per the subject line, they run inside
the `for i, eid in enumerate(expert_ids)` loop.

tests/test_minimal_e2e.py: new stand-alone script with three checks:
  1. RoPE -> inverse RoPE round-trip at positions 0, 1, 10 and 100.
     Runs in float32, because bfloat16 rounding of cos/sin alone breaks
     the 1e-4 threshold at every position other than 0.
  2. One token through layer 0 (attention + shared expert only); passes
     when the resulting state contains no NaN/Inf.
  3. One token through every layer (shared expert only), then the final
     norm and LM head; passes when the logits contain no NaN/Inf.  The
     logit-std check is printed but does not affect the result.

Both model tests import mHCBlock from dsv4.layers.mhc before using it.

---
Reviewer notes (ignored by git am):
  * This patch was re-flowed from a copy whose line breaks and leading
    whitespace had been collapsed.  The indentation in the
    single_shot_inference.py hunk is a best-effort reconstruction and
    must be checked against the target file before applying.
  * The blob ids on the `index` lines are those of the original commit;
    the test file content below differs from that blob.
  * Assumes dsv4.layers.mhc exports `mHCBlock` (with `init_state`); the
    original imported `mHCLayer` but instantiated `mHCBlock`.  Confirm
    the class name.

 single_shot_inference.py  |  44 ++--
 tests/test_minimal_e2e.py | 511 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 533 insertions(+), 22 deletions(-)
 create mode 100644 tests/test_minimal_e2e.py

diff --git a/single_shot_inference.py b/single_shot_inference.py
index f75e0696..0a22a46c 100644
--- a/single_shot_inference.py
+++ b/single_shot_inference.py
@@ -524,31 +524,31 @@ def moe_forward(x, w, li, cfg, token_id, device):
     if not SKIP_ROUTED_MOE:
         for i, eid in enumerate(expert_ids):
             eid_int = eid.item()
-        epre = f"model.layers.{li}.mlp.experts.{eid_int}"
+            epre = f"model.layers.{li}.mlp.experts.{eid_int}"
 
-        gate = nvfp4_linear(x,
-                            w[f"{epre}.gate_proj.weight"],
-                            w[f"{epre}.gate_proj.weight_scale"],
-                            w[f"{epre}.gate_proj.weight_scale_2"])
-        up = nvfp4_linear(x,
-                          w[f"{epre}.up_proj.weight"],
-                          w[f"{epre}.up_proj.weight_scale"],
-                          w[f"{epre}.up_proj.weight_scale_2"])
+            gate = nvfp4_linear(x,
+                                w[f"{epre}.gate_proj.weight"],
+                                w[f"{epre}.gate_proj.weight_scale"],
+                                w[f"{epre}.gate_proj.weight_scale_2"])
+            up = nvfp4_linear(x,
+                              w[f"{epre}.up_proj.weight"],
+                              w[f"{epre}.up_proj.weight_scale"],
+                              w[f"{epre}.up_proj.weight_scale_2"])
 
-        # SwiGLU with clamping (paper §4.2.3)
-        silu_out = torch.nn.functional.silu(gate.float())
-        if swiglu_limit is not None:
-            silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit)
-            up_clamped = up.float().clamp(-swiglu_limit, swiglu_limit)
-        else:
-            up_clamped = up.float()
-        hidden = (silu_out * up_clamped).bfloat16()
+            # SwiGLU with clamping (paper §4.2.3)
+            silu_out = torch.nn.functional.silu(gate.float())
+            if swiglu_limit is not None:
+                silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit)
+                up_clamped = up.float().clamp(-swiglu_limit, swiglu_limit)
+            else:
+                up_clamped = up.float()
+            hidden = (silu_out * up_clamped).bfloat16()
 
-        down = nvfp4_linear(hidden,
-                            w[f"{epre}.down_proj.weight"],
-                            w[f"{epre}.down_proj.weight_scale"],
-                            w[f"{epre}.down_proj.weight_scale_2"])
-        expert_outputs.append(down)
+            down = nvfp4_linear(hidden,
+                                w[f"{epre}.down_proj.weight"],
+                                w[f"{epre}.down_proj.weight_scale"],
+                                w[f"{epre}.down_proj.weight_scale_2"])
+            expert_outputs.append(down)
 
         # Weighted combine + scaling
         routed_out = torch.zeros_like(x)
diff --git a/tests/test_minimal_e2e.py b/tests/test_minimal_e2e.py
new file mode 100644
index 00000000..167966e1
--- /dev/null
+++ b/tests/test_minimal_e2e.py
@@ -0,0 +1,511 @@
+#!/usr/bin/env python3
+"""Minimal end-to-end test: process "The" through DSV4-Pro, verify logits.
+
+Tests:
+  1. RoPE → inverse RoPE round-trip (fp32; should be exact at any single position)
+  2. Single token through layer 0 (shapes, finiteness, reasonable magnitudes)
+  3. Full model logits for "The" (finite, not degenerate)
+
+Usage (on B200):
+    source /root/dsv4-nvfp4-workspace/venv/bin/activate
+    cd /root/dsv4-nvfp4-workspace/kernel
+    python3 tests/test_minimal_e2e.py
+"""
+import os, sys, math, json
+import torch
+from pathlib import Path
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
+NUM_GPUS = 8
+
+# =====================================================================
+# Shared helpers
+# =====================================================================
+FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
+
+def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
+    out_dim = weight.shape[0]
+    in_features = weight.shape[1] * 2
+    low = (weight & 0x0F).to(torch.int8)
+    high = (weight >> 4).to(torch.int8)
+    low_sign, low_idx = (low >> 3).bool(), (low & 0x07).long()
+    high_sign, high_idx = (high >> 3).bool(), (high & 0x07).long()
+    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
+    low_f = lut[low_idx] * torch.where(low_sign, -1.0, 1.0)
+    high_f = lut[high_idx] * torch.where(high_sign, -1.0, 1.0)
+    w_f = torch.stack([low_f, high_f], dim=-1).reshape(out_dim, in_features)
+    scale_f = weight_scale.float() * weight_scale_2.float()
+    scale_expanded = scale_f.repeat_interleave(16, dim=1)
+    return (w_f * scale_expanded).bfloat16()
+
+def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
+    w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
+    return torch.nn.functional.linear(x, w)
+
+class RMSNorm:
+    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
+        self.eps = eps
+        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)
+    def forward(self, x):
+        x_f = x.float()
+        rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
+        return (x_f * rms * self.weight).to(torch.bfloat16)
+
+def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, dtype=torch.bfloat16):
+    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
+    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
+    return torch.cos(angles).to(dtype).to(device), torch.sin(angles).to(dtype).to(device)
+
+def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
+    T, n_h, hd = x.shape
+    nope = hd - rope_dim
+    cos = cos_cache[positions].unsqueeze(1)
+    sin = sin_cache[positions].unsqueeze(1)
+    out = x.clone()
+    x_rope = x[:, :, nope:]
+    out[:, :, nope:][..., 0::2] = x_rope[..., 0::2] * cos - x_rope[..., 1::2] * sin
+    out[:, :, nope:][..., 1::2] = x_rope[..., 0::2] * sin + x_rope[..., 1::2] * cos
+    return out
+
+def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
+    T, n_h, hd = o.shape
+    nope = hd - rope_dim
+    cos = cos_cache[positions].unsqueeze(1)
+    sin = sin_cache[positions].unsqueeze(1)
+    out = o.clone()
+    o_rope = o[:, :, nope:]
+    out[:, :, nope:][..., 0::2] = o_rope[..., 0::2] * cos + o_rope[..., 1::2] * sin
+    out[:, :, nope:][..., 1::2] = -o_rope[..., 0::2] * sin + o_rope[..., 1::2] * cos
+    return out
+
+def load_weights_to_cpu(checkpoint_dir):
+    from safetensors.torch import load_file
+    cdir = Path(checkpoint_dir)
+    index_path = cdir / "model.safetensors.index.json"
+    weight_map = {}
+    if index_path.exists():
+        with open(index_path) as f:
+            weight_map = json.load(f).get("weight_map", {})
+    shard_names = set(weight_map.values()) if weight_map else {
+        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
+    }
+    all_weights = {}
+    for shard_name in sorted(shard_names):
+        if not (cdir / shard_name).exists():
+            continue
+        data = load_file(str(cdir / shard_name))
+        all_weights.update(data)
+    return all_weights
+
+def get_layer_weights(all_weights, li, device):
+    prefix = f"model.layers.{li}."
+    return {k: v.to(device=device, non_blocking=True) for k, v in all_weights.items() if k.startswith(prefix)}
+
+# =====================================================================
+# Test 1: RoPE round-trip
+# =====================================================================
+def test_rope_roundtrip():
+    print("\n" + "="*60)
+    print("Test 1: RoPE → inverse RoPE round-trip")
+    print("="*60)
+    device = 'cuda:0'
+    hd, rd, n_h = 512, 64, 128
+    cos, sin = build_rope_cache(8192, rd, device, dtype=torch.float32)  # fp32: bf16 rounding alone exceeds 1e-4
+    all_pass = True
+
+    for pos_val in [0, 1, 10, 100]:
+        torch.manual_seed(42)
+        x = torch.randn(1, n_h, hd, dtype=torch.float32, device=device)
+        pos = torch.tensor([pos_val], dtype=torch.long, device=device)
+
+        x_roped = apply_rope_partial(x, pos, cos, sin, hd, rd)
+        x_recovered = apply_inverse_rope(x_roped, pos, cos, sin, hd, rd)
+
+        diff = (x.float() - x_recovered.float()).abs().max().item()
+        ok = diff < 1e-4
+        all_pass &= ok
+        print(f" pos={pos_val:4d}: max_diff={diff:.2e} {'✅' if ok else '❌'}")
+
+    print(f" Result: {'✅ PASS' if all_pass else '❌ FAIL'}")
+    return all_pass
+
+
+# =====================================================================
+# Test 2: Single token through layer 0
+# =====================================================================
+def test_layer0():
+    print("\n" + "="*60)
+    print("Test 2: Single token through layer 0")
+    print("="*60)
+    device = 'cuda:0'
+
+    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
+        cfg = json.load(f)
+
+    n_h = cfg["num_attention_heads"]
+    hd = cfg["head_dim"]
+    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
+    H = cfg["hidden_size"]
+    n_hc = 4
+    o_rank = cfg.get("output_group_dim", 1024)
+    o_groups = cfg.get("num_output_groups", 16)
+    heads_per_group = n_h // o_groups
+    group_input_dim = heads_per_group * hd
+
+    print(f" Config: {n_h} heads, hd={hd}, rope_dim={rd}, H={H}, "
+          f"{o_groups} groups, o_rank={o_rank}")
+
+    print(" Loading weights...")
+    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
+    w = get_layer_weights(all_weights, 0, device)
+
+    rope_cos, rope_sin = build_rope_cache(8192, rd, device)
+    embed_w = all_weights.get("model.embed_tokens.weight")
+    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(device))
+
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
+    input_ids = tokenizer.encode("The")
+    tid = torch.tensor([input_ids[0]], dtype=torch.long, device=device)
+    print(f" Token: {tid.item()} = '{tokenizer.decode([tid.item()])}'")
+
+    emb = embed(tid)
+    print(f" Embedding: |emb|={emb.float().abs().max():.3f}")
+
+    # mHC init
+    from dsv4.layers.mhc import mHCBlock
+    X = mHCBlock.init_state(emb, n_hc)
+    print(f" mHC state: |X|={X.float().abs().max():.3f}")
+
+    # Build mHC + norms for layer 0
+    li = 0
+    attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
+    attn_mhc.load_from_checkpoint(
+        all_weights[f"model.layers.{li}.attn_hc.fn"],
+        all_weights[f"model.layers.{li}.attn_hc.base"],
+        all_weights[f"model.layers.{li}.attn_hc.scale"])
+    ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=device)
+    ffn_mhc.load_from_checkpoint(
+        all_weights[f"model.layers.{li}.ffn_hc.fn"],
+        all_weights[f"model.layers.{li}.ffn_hc.base"],
+        all_weights[f"model.layers.{li}.ffn_hc.scale"])
+    attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=device)
+    attn_norm.weight = all_weights[f"model.layers.{li}.input_layernorm.weight"].to(device=device, dtype=torch.float32)
+    ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=device)
+    ffn_norm.weight = all_weights[f"model.layers.{li}.post_attention_layernorm.weight"].to(device=device, dtype=torch.float32)
+
+    # === ATTENTION ===
+    x_in, attn_ctx = attn_mhc.pre_block(X)
+    x_normed = attn_norm.forward(x_in)
+
+    pre = f"model.layers.{li}.self_attn"
+
+    # Q: q_a → q_a_norm → q_b
+    c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"],
+                       w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"])
+    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
+    if q_norm_w is not None:
+        c_Q_f = c_Q.float()
+        c_Q = (c_Q_f * c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * q_norm_w.float()).bfloat16()
+    q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"],
+                     w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"])
+    print(f" Q: shape={q.shape} |Q|={q.float().abs().max():.3f}")
+
+    # KV
+    kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"],
+                      w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"])
+    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
+    if kv_norm_w is not None:
+        kv_f = kv.float()
+        kv = (kv_f * kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * kv_norm_w.float()).bfloat16()
+    print(f" KV: shape={kv.shape} |KV|={kv.float().abs().max():.3f}")
+
+    # Reshape + RoPE
+    q_heads = q.reshape(1, n_h, hd)
+    kv_new = kv.reshape(1, 1, hd)
+    positions = torch.tensor([0], dtype=torch.long, device=device)
+    q_roped = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
+    kv_roped = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
+
+    # SDPA (1 token, K=V)
+    k_exp = kv_roped.expand(n_h, -1, -1).contiguous()
+    v_exp = kv_roped.expand(n_h, -1, -1).contiguous()
+    q_input = q_roped.permute(1, 0, 2)
+    scale = 1.0 / math.sqrt(hd)
+
+    sink_key = f"{pre}.sinks"
+    if sink_key in w:
+        sinks = w[sink_key].to(device=device)
+        sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
+        sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=device)
+        k_ws = torch.cat([k_exp, sink_k], dim=1) # (n_h, 2, hd)
+        v_ws = torch.cat([v_exp, sink_v], dim=1)
+        # Sink bias: add logit to virtual position
+        sink_bias = torch.zeros(n_h, 1, 2, dtype=torch.bfloat16, device=device)
+        for h in range(n_h):
+            sink_bias[h, :, -1] = sinks[h]
+        attn_out = torch.nn.functional.scaled_dot_product_attention(
+            q_input, k_ws, v_ws, attn_mask=sink_bias, scale=scale)
+        print(f" SDPA (with sinks): |out|={attn_out.float().abs().max():.3f}")
+    else:
+        attn_out = torch.nn.functional.scaled_dot_product_attention(
+            q_input, k_exp, v_exp, scale=scale, is_causal=False)
+        print(f" SDPA (no sinks): |out|={attn_out.float().abs().max():.3f}")
+
+    attn_out = attn_out.permute(1, 0, 2) # (1, n_h, hd)
+
+    # Inverse RoPE
+    attn_inv = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
+    print(f" After inverse RoPE: |out|={attn_inv.float().abs().max():.3f}")
+
+    # Output projection: wo_a (grouped BMM) + wo_b (NVFP4)
+    attn_flat = attn_inv.reshape(1, n_h * hd)
+    attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd)
+    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
+    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
+    attn_for_bmm = attn_grouped.permute(1, 0, 2)
+    grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))
+    grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
+    F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"],
+                          w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"])
+    print(f" F_attn: shape={F_attn.shape} |F_attn|={F_attn.float().abs().max():.3f}")
+
+    X_mid = attn_mhc.post_block(X, F_attn, attn_ctx)
+    print(f" X_mid: |X_mid|={X_mid.float().abs().max():.3f}")
+
+    # === FFN (shared expert only) ===
+    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
+    x_ffn_normed = ffn_norm.forward(x_ffn)
+
+    se_pre = f"model.layers.{li}.mlp.shared_experts"
+    gate = nvfp4_linear(x_ffn_normed, w[f"{se_pre}.gate_proj.weight"],
+                        w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"])
+    up = nvfp4_linear(x_ffn_normed, w[f"{se_pre}.up_proj.weight"],
+                      w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"])
+    hidden = (torch.nn.functional.silu(gate.float()).clamp(-10, 10) * up.float().clamp(-10, 10)).bfloat16()
+    shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"],
+                              w[f"{se_pre}.down_proj.weight_scale"], w[f"{se_pre}.down_proj.weight_scale_2"])
+    print(f" Shared expert: |out|={shared_out.float().abs().max():.3f}")
+
+    X_next = ffn_mhc.post_block(X_mid, shared_out, ffn_ctx)
+    has_nan = torch.isnan(X_next).any().item()
+    has_inf = torch.isinf(X_next).any().item()
+    print(f" X_next: |X_next|={X_next.float().abs().max():.3f} nan={has_nan} inf={has_inf}")
+    print(f" Result: {'✅ PASS' if not has_nan and not has_inf else '❌ FAIL'}")
+
+    del w, all_weights
+    torch.cuda.empty_cache()
+    return not has_nan and not has_inf
+
+
+# =====================================================================
+# Test 3: Full model logits
+# =====================================================================
+def test_full_logits():
+    print("\n" + "="*60)
+    print("Test 3: Full model logits for 'The'")
+    print("="*60)
+    device = 'cuda:0'
+
+    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
+        cfg = json.load(f)
+
+    n_layers = cfg["num_hidden_layers"]
+    H = cfg["hidden_size"]
+    n_h = cfg["num_attention_heads"]
+    hd = cfg["head_dim"]
+    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
+    n_hc = 4
+    o_rank = cfg.get("output_group_dim", 1024)
+    o_groups = cfg.get("num_output_groups", 16)
+    heads_per_group = n_h // o_groups
+    group_input_dim = heads_per_group * hd
+
+    print(" Loading weights to CPU...")
+    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
+
+    embed_w = all_weights.get("model.embed_tokens.weight")
+    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to(device))
+    lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to(device)
+    final_norm_w = all_weights.get("model.norm.weight")
+    if final_norm_w is not None:
+        final_norm_w = final_norm_w.to(device)
+
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
+    input_ids = tokenizer.encode("The")
+    tid = torch.tensor([input_ids[0]], dtype=torch.long, device=device)
+    positions = torch.tensor([0], dtype=torch.long, device=device)
+    print(f" Token: {tid.item()} = '{tokenizer.decode([tid.item()])}'")
+
+    emb = embed(tid)
+    from dsv4.layers.mhc import mHCBlock
+    X = mHCBlock.init_state(emb, n_hc)
+
+    for li in range(n_layers):
+        gpu = li % NUM_GPUS
+        dev = f"cuda:{gpu}"
+        X = X.to(dev)
+        torch.cuda.set_device(gpu)
+
+        w = get_layer_weights(all_weights, li, dev)
+        rope_cos, rope_sin = build_rope_cache(8192, rd, dev)
+        positions_dev = positions.to(dev)
+
+        # Build mHC + norms
+        attn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
+        attn_mhc.load_from_checkpoint(
+            all_weights[f"model.layers.{li}.attn_hc.fn"],
+            all_weights[f"model.layers.{li}.attn_hc.base"],
+            all_weights[f"model.layers.{li}.attn_hc.scale"])
+        ffn_mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
+        ffn_mhc.load_from_checkpoint(
+            all_weights[f"model.layers.{li}.ffn_hc.fn"],
+            all_weights[f"model.layers.{li}.ffn_hc.base"],
+            all_weights[f"model.layers.{li}.ffn_hc.scale"])
+        attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
+        attn_norm.weight = all_weights[f"model.layers.{li}.input_layernorm.weight"].to(device=dev, dtype=torch.float32)
+        ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev)
+        ffn_norm.weight = all_weights[f"model.layers.{li}.post_attention_layernorm.weight"].to(device=dev, dtype=torch.float32)
+
+        # ATTENTION
+        x_in, attn_ctx = attn_mhc.pre_block(X)
+        x_normed = attn_norm.forward(x_in)
+        pre = f"model.layers.{li}.self_attn"
+
+        c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"],
+                           w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"])
+        q_norm_w = w.get(f"{pre}.q_a_norm.weight")
+        if q_norm_w is not None:
+            c_Q_f = c_Q.float()
+            c_Q = (c_Q_f * c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * q_norm_w.float()).bfloat16()
+        q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"],
+                         w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"])
+
+        kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"],
+                          w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"])
+        kv_norm_w = w.get(f"{pre}.kv_norm.weight")
+        if kv_norm_w is not None:
+            kv_f = kv.float()
+            kv = (kv_f * kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * kv_norm_w.float()).bfloat16()
+
+        q_heads = q.reshape(1, n_h, hd)
+        kv_new = kv.reshape(1, 1, hd)
+        q_roped = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
+        kv_roped = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)
+
+        k_exp = kv_roped.expand(n_h, -1, -1).contiguous()
+        v_exp = kv_roped.expand(n_h, -1, -1).contiguous()
+        q_input = q_roped.permute(1, 0, 2)
+        scale = 1.0 / math.sqrt(hd)
+
+        sink_key = f"{pre}.sinks"
+        if sink_key in w:
+            sinks = w[sink_key].to(device=dev)
+            sink_k = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=dev)
+            sink_v = torch.zeros(n_h, 1, hd, dtype=torch.bfloat16, device=dev)
+            k_ws = torch.cat([k_exp, sink_k], dim=1)
+            v_ws = torch.cat([v_exp, sink_v], dim=1)
+            sink_bias = torch.zeros(n_h, 1, 2, dtype=torch.bfloat16, device=dev)
+            for h in range(n_h):
+                sink_bias[h, :, -1] = sinks[h]
+            attn_out = torch.nn.functional.scaled_dot_product_attention(
+                q_input, k_ws, v_ws, attn_mask=sink_bias, scale=scale)
+        else:
+            attn_out = torch.nn.functional.scaled_dot_product_attention(
+                q_input, k_exp, v_exp, scale=scale, is_causal=False)
+        attn_out = attn_out.permute(1, 0, 2)
+
+        attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)
+
+        # Output projection
+        attn_flat = attn_out.reshape(1, n_h * hd)
+        attn_grouped = attn_flat.reshape(1, o_groups, heads_per_group * hd)
+        oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()
+        oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
+        grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))
+        grouped_flat = grouped_out.permute(1, 0, 2).reshape(1, o_groups * o_rank)
+        F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"],
+                              w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"])
+
+        X_mid = attn_mhc.post_block(X, F_attn, attn_ctx)
+
+        # FFN (shared expert only for speed)
+        x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
+        x_ffn_normed = ffn_norm.forward(x_ffn)
+        se_pre = f"model.layers.{li}.mlp.shared_experts"
+        gate = nvfp4_linear(x_ffn_normed, w[f"{se_pre}.gate_proj.weight"],
+                            w[f"{se_pre}.gate_proj.weight_scale"], w[f"{se_pre}.gate_proj.weight_scale_2"])
+        up = nvfp4_linear(x_ffn_normed, w[f"{se_pre}.up_proj.weight"],
+                          w[f"{se_pre}.up_proj.weight_scale"], w[f"{se_pre}.up_proj.weight_scale_2"])
+        hidden = (torch.nn.functional.silu(gate.float()).clamp(-10, 10) * up.float().clamp(-10, 10)).bfloat16()
+        shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"],
+                                  w[f"{se_pre}.down_proj.weight_scale"], w[f"{se_pre}.down_proj.weight_scale_2"])
+        X = ffn_mhc.post_block(X_mid, shared_out, ffn_ctx)
+
+        if li % 10 == 0 or li == n_layers - 1:
+            print(f" L{li:2d}: |X|={X.float().abs().max():.3f}")
+
+        del w
+        torch.cuda.empty_cache()
+
+    # Logits
+    X = X.to('cuda:0')
+    x_out = X[:, 0, :]
+    if final_norm_w is not None:
+        xf = x_out.float()
+        rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
+        x_out = (xf * rms * final_norm_w.float()).bfloat16()
+
+    logits = torch.nn.functional.linear(x_out, lm_w)
+    has_nan = torch.isnan(logits).any().item()
+    has_inf = torch.isinf(logits).any().item()
+    lmin, lmax = logits.float().min().item(), logits.float().max().item()
+    lmean = logits.float().mean().item()
+    lstd = logits.float().std().item()
+
+    print(f"\n Logits: min={lmin:.3f} max={lmax:.3f} mean={lmean:.3f} std={lstd:.3f}")
+    print(f" nan={has_nan} inf={has_inf}")
+
+    if not has_nan and not has_inf:
+        top5_vals, top5_ids = torch.topk(logits[0], 5)
+        top5_str = ' '.join([f'{tokenizer.decode([t.item()])}({v.item():.1f})'
+                             for t, v in zip(top5_ids, top5_vals)])
+        print(f" Top-5: {top5_str}")
+
+        # Check: logits should have reasonable spread (not uniform)
+        spread_ok = lstd > 0.5
+        print(f" Logit spread: {'✅' if spread_ok else '❌'} (std={lstd:.3f})")
+
+    ok = not has_nan and not has_inf
+    print(f" Result: {'✅ PASS' if ok else '❌ FAIL'}")
+    return ok
+
+
+# =====================================================================
+# Main
+# =====================================================================
+if __name__ == "__main__":
+    print("DSV4 Minimal End-to-End Test")
+    print("="*60)
+
+    results = {}
+
+    # Test 1: RoPE round-trip (fast, no weights)
+    results["rope_roundtrip"] = test_rope_roundtrip()
+
+    # Test 2: Layer 0
+    results["layer0"] = test_layer0()
+
+    # Test 3: Full model logits
+    results["full_logits"] = test_full_logits()
+
+    print("\n" + "="*60)
+    print("SUMMARY")
+    print("="*60)
+    for name, passed in results.items():
+        print(f" {name}: {'✅ PASS' if passed else '❌ FAIL'}")
+    all_pass = all(results.values())
+    print(f"\n Overall: {'✅ ALL PASS' if all_pass else '❌ SOME FAILED'}")