#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.

This is a reference implementation that exercises the production kernel stack
end-to-end. It is intended to serve as ground truth when integrating into
vLLM or SGLang.

Architecture (paper §2):
    X_l   → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
    X_mid → mHC.pre_block → RMSNorm → FFN(MoE)  → F_ffn  → mHC.post_block → X_{l+1}

Components exercised:
    - mHC (Manifold-Constrained Hyper-Connections) — proper Sinkhorn-Knopp
    - Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head)
    - Partial RoPE (last 64 dims, GPT-J interleaved)
    - Production FMHA kernel (6-warp multi-tile, C API + ctypes) for longer KV
      lengths; short KV lengths use a plain PyTorch matmul/softmax path
    - Inverse RoPE on attention output (paper §2.3.3)
    - Grouped output projection (wo_a BMM + wo_b NVFP4)
    - Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp)
    - Shared expert (NVFP4 gate/up/down)
    - RMSNorm (pre-norm before each sub-block)
    - KV cache across decode steps

Attention type simplification for this single-shot test:
    For short sequences (seq_len ≤ sliding_window = 128), all attention types
    (CSA/HCA/SWA) reduce to dense attention over the full KV cache. CSA's
    compressed branch and indexer are only needed once seq_len exceeds the
    sliding window. HCA is dense over compressed entries, and at short
    sequence lengths the compressed sequence is trivially small. We therefore
    use dense MQA attention over the full KV cache for all layers, which is
    mathematically equivalent for short sequences.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python3 single_shot_inference.py
"""
import argparse
import json
import math
import os
import sys
import time
from pathlib import Path

import torch

# =====================================================================
# Configuration
# =====================================================================


def parse_args(argv=None):
    """Parse command-line options for the single-shot run.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the diagnostic / generation options.
    """
    p = argparse.ArgumentParser(description='DSV4 Single-Shot Inference')
    p.add_argument('--no-inverse-rope', action='store_true',
                   help='Skip inverse RoPE on attention output')
    p.add_argument('--skip-moe', action='store_true',
                   help='Only use shared expert (skip routed)')
    # SKIP_MHC below reads this flag; without it the script died with
    # AttributeError before doing any work.
    p.add_argument('--skip-mhc', action='store_true',
                   help='Bypass mHC and use simple residual connections (diagnostic)')
    p.add_argument('--no-thinking', action='store_true',
                   help='Force model to skip thinking (use <|EOT|> instead of thinking tokens)')
    p.add_argument('--max-tokens', type=int, default=512,
                   help='Max new tokens to generate')
    p.add_argument('--prompt', type=str, default=None,
                   help='Override prompt')
    return p.parse_args(argv)


_args = parse_args()

CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
MAX_NEW_TOKENS = _args.max_tokens
SYSTEM_PROMPT = ""  # Empty system prompt for testing
PROMPT = _args.prompt or "The capital of France is"
NUM_GPUS = 8
SKIP_ROUTED_MOE = _args.skip_moe  # If True, only use shared expert (debug)
# If False, skip inverse RoPE on attention output (diagnostic).
# When True: applies inverse RoPE at query position → converts absolute→relative.
# When False: leaves relative position encoding intact for output projection.
# DSV4 partial RoPE only affects last 64/512 dims; first 448 are always un-RoPE'd.
INVERSE_ROPE = not _args.no_inverse_rope
SKIP_MHC = _args.skip_mhc  # If True, bypass mHC and use simple residual connections (diagnostic)
MHC_DIAG = True  # If True, print per-layer mHC diagnostics (B_l row/col sums, C_l values)

print(f"Config: INVERSE_ROPE={INVERSE_ROPE}, SKIP_ROUTED_MOE={SKIP_ROUTED_MOE}, MAX_NEW_TOKENS={MAX_NEW_TOKENS}")
# =====================================================================
# NVFP4 dequantization — matches checkpoint format exactly
# =====================================================================

FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # E2M1 magnitudes

# Signed E2M1 table indexed directly by the 4-bit code (bit 3 = sign).
# Codes 8..15 are the negated magnitudes, so code 0b1000 decodes to -0.0,
# exactly like the magnitude * sign formulation.
_FP4_SIGNED_LUT = torch.cat([FP4_LUT, -FP4_LUT]).float()


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Dequantize an NVFP4 weight matrix to BF16.

    Args:
        weight: (out_dim, in_dim // 2) packed codes, two FP4 values per byte;
            the low nibble holds the even input column, the high nibble the odd.
        weight_scale: (out_dim, in_dim // 16) per-16-element block scales
            (E4M3 in the checkpoint; any dtype with ``.float()`` works).
        weight_scale_2: second-level scale, multiplied into ``weight_scale``
            with broadcasting (per-row (out_dim, 1) or a scalar — TODO confirm
            which the checkpoint stores).

    Returns:
        (out_dim, in_dim) BF16 tensor.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2

    # Mask after the shift so signed (int8) storage cannot sign-extend into
    # the table index.
    codes_lo = (weight & 0x0F).long()
    codes_hi = ((weight >> 4) & 0x0F).long()

    lut = _FP4_SIGNED_LUT.to(device=weight.device)
    w_f = torch.stack([lut[codes_lo], lut[codes_hi]], dim=-1).reshape(out_dim, in_features)

    scale_f = weight_scale.float() * weight_scale_2.float()
    # One block scale covers 16 consecutive input columns.
    scale_expanded = scale_f.repeat_interleave(16, dim=1)
    return (w_f * scale_expanded).bfloat16()


def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """BF16 ``F.linear(x, W)`` where W is dequantized from NVFP4 on the fly.

    The full weight is materialized on every call; nothing is cached.
    """
    w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, w)


# =====================================================================
# RMSNorm — matches dsv4/layers/norm.py
# =====================================================================

class RMSNorm:
    """Weighted RMSNorm computed in FP32, returning BF16.

    ``weight`` starts as ones and may be overwritten with checkpoint values.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.eps = eps
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)

    def forward(self, x):
        """x: (T, H) BF16 → (T, H) BF16"""
        x_f = x.float()
        rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x_f * rms * self.weight).to(torch.bfloat16)
# =====================================================================
# mHC — proper Sinkhorn-Knopp implementation
# =====================================================================

class mHCBlock:
    """Thin adapter around ``dsv4.layers.mhc.mHCLayer`` for single-shot inference.

    All mHC math (including the Sinkhorn-Knopp projection, run for
    ``sinkhorn_iters`` iterations) lives in the production ``mHCLayer``; this
    class only maps the checkpoint tensor layout onto ``mHCLayer.load_weights``.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        # Imported lazily: dsv4 is only needed once an mHCBlock is constructed.
        from dsv4.layers.mhc import mHCLayer
        self._impl = mHCLayer(
            hidden_dim=hidden_dim, n_hc=n_hc, t_max_sinkhorn=sinkhorn_iters,
            device=device, dtype=torch.bfloat16)
        self.device = device
        self.n_hc = n_hc
        self.hidden_dim = hidden_dim

    def load_from_checkpoint(self, fn, base, scale):
        """Load from checkpoint tensors.

        Checkpoint layout (verified against HuggingFace DeepseekV4HyperConnection):
            fn:    (24, 28672) — rows ordered [pre(4), post(4), comb(16)]
            base:  (24,)       — ordered [pre(4), post(4), comb(16)]
            scale: (3,)        — [alpha_pre, alpha_post, alpha_comb]

        The HuggingFace model does:
            pre_w, post_w, comb_w = F.linear(flat, fn).split([4, 4, 16])
            pre_b, post_b, comb_b = base.split([4, 4, 16])
            pre_scale, post_scale, comb_scale = scale.unbind(0)
        """
        n = self.n_hc
        dev = self.device

        # fn rows: [pre(n), post(n), comb(n*n)] — FP32 projection weights.
        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
        W_post = fn[n:2 * n].to(device=dev, dtype=torch.float32).contiguous()
        W_comb = fn[2 * n:].to(device=dev, dtype=torch.float32).contiguous()

        # base: same ordering as fn, reshaped to the shapes mHCLayer expects.
        S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()
        S_post = base[n:2 * n].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous()
        S_comb = base[2 * n:].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous()

        alpha_pre = scale[0].item()
        alpha_post = scale[1].item()
        alpha_comb = scale[2].item()

        self._impl.load_weights(
            W_pre=W_pre, W_post=W_post, W_comb=W_comb,
            S_pre=S_pre, S_post=S_post, S_comb=S_comb,
            alpha_pre=alpha_pre, alpha_post=alpha_post, alpha_comb=alpha_comb)

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Build the initial mHC residual state via ``mHCLayer.init_state``."""
        from dsv4.layers.mhc import mHCLayer
        return mHCLayer.init_state(embeddings, n_hc)

    def pre_block(self, X_l):
        """Delegate to ``mHCLayer.pre_block``; returns ``(x_in, ctx)``."""
        return self._impl.pre_block(X_l)

    def post_block(self, X_l, F_out, ctx):
        """Delegate to ``mHCLayer.post_block``."""
        return self._impl.post_block(X_l, F_out, ctx)


# =====================================================================
# RoPE — partial, GPT-J interleaved, last rope_dim dims
# =====================================================================

def _yarn_correction_dim(num_rotations, rope_dim, theta, original_max_pos):
    """Fractional rotary-pair index whose wavelength completes ``num_rotations``
    full turns over ``original_max_pos`` positions (YaRN correction dim)."""
    return (rope_dim * math.log(original_max_pos / (num_rotations * 2 * math.pi))) / (
        2 * math.log(theta))


def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, rope_type="default",
                     rope_factor=1.0, original_max_pos=4096, beta_fast=32, beta_slow=1):
    """Build cos/sin caches for partial RoPE.

    CRITICAL: FP32, not BF16! BF16 quantization destroys the cos²+sin²=1
    identity needed for inverse RoPE (BF16 cos²+sin² can be 0.996, a ~3%
    round-trip error that accumulates across 61 layers).

    With ``rope_type == "yarn"`` and ``rope_factor > 1`` the frequencies are
    rescaled with YaRN's "NTK-by-parts" ramp. Rotary pairs that turn more than
    ``beta_fast`` times over ``original_max_pos`` are left unscaled, and pairs
    that turn fewer than ``beta_slow`` times are divided by ``rope_factor``.
    Pairs in between are linearly blended.

    NOTE(review): YaRN's attention-temperature (mscale) factor is NOT applied
    here; if the model expects one it must be applied elsewhere — confirm.

    Returns:
        (cos_cache, sin_cache), each (max_pos, rope_dim // 2) FP32 on ``device``.
    """
    half = rope_dim // 2
    # Base frequencies: 1 / theta^(2i/d)
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))

    if rope_type == "yarn" and rope_factor > 1.0:
        low = math.floor(_yarn_correction_dim(beta_fast, rope_dim, theta, original_max_pos))
        high = math.ceil(_yarn_correction_dim(beta_slow, rope_dim, theta, original_max_pos))
        low, high = max(low, 0), min(high, rope_dim - 1)
        if low == high:
            high += 0.001  # avoid division by zero in the ramp
        # ramp = 0 → keep original frequency; ramp = 1 → fully interpolated.
        ramp = ((torch.arange(half, dtype=torch.float32) - low) / (high - low)).clamp(0.0, 1.0)
        freqs = freqs * (1.0 - ramp) + (freqs / rope_factor) * ramp

    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)


def _rotate_rope_dims(x, positions, cos_cache, sin_cache, rope_dim, inverse):
    """Rotate the last ``rope_dim`` dims of ``x`` (T, n_h, hd) in GPT-J
    interleaved pairs, in FP32; ``inverse`` applies the conjugate rotation.
    The first ``hd - rope_dim`` dims are copied through unchanged and the
    result keeps ``x``'s dtype."""
    nope = x.shape[-1] - rope_dim
    cos = cos_cache[positions].unsqueeze(1)  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1)
    if inverse:
        sin = -sin  # conjugate rotation R(p)^-1 = R(-p)

    x_rope = x[:, :, nope:].float()
    x_even = x_rope[..., 0::2]
    x_odd = x_rope[..., 1::2]

    rope_out = torch.empty_like(x_rope)
    rope_out[..., 0::2] = x_even * cos - x_odd * sin
    rope_out[..., 1::2] = x_even * sin + x_odd * cos

    result = x.clone()
    result[:, :, nope:] = rope_out.to(x.dtype)
    return result


def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply partial GPT-J interleaved RoPE to the last rope_dim dims of each head.

    x: (T, n_h, hd); positions: (T,) indices into the caches. ``head_dim`` is
    accepted for interface compatibility; the head size is taken from ``x``.
    Computes in FP32 (inverse RoPE requires cos²+sin²=1) and returns x's dtype.
    """
    return _rotate_rope_dims(x, positions, cos_cache, sin_cache, rope_dim, inverse=False)


def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply inverse RoPE (conjugate rotation) to attention output.

    o: (T, n_h, hd); computes in FP32 and returns o's dtype.
    """
    return _rotate_rope_dims(o, positions, cos_cache, sin_cache, rope_dim, inverse=True)


class SimpleKVCache:
    """Per-layer KV cache for decode.

    Stores BF16 K and V accumulated across steps. MQA has one KV head, so each
    buffer is (1, max_seq, head_dim) and the live prefix is ``[:, :len]``.
    """

    def __init__(self, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.len = 0

    def append(self, k_new, v_new):
        """Append K, V, each (1, T, hd).

        Raises:
            RuntimeError: if the append would exceed ``max_seq`` (the cache is
                left unchanged).
        """
        T = k_new.shape[1]
        end = self.len + T
        if end > self.max_seq:
            raise RuntimeError(
                f"KV cache overflow: {self.len} + {T} tokens exceeds max_seq={self.max_seq}")
        self.k[0, self.len:end] = k_new[0]
        self.v[0, self.len:end] = v_new[0]
        self.len = end

    def get(self):
        """Return views of K, V up to the current length, (1, seq_len, hd) each."""
        return self.k[:, :self.len], self.v[:, :self.len]


# =====================================================================
# Weight loading — streams safetensors shards, distributes to 8 GPUs
# =====================================================================

def load_weights_to_cpu(checkpoint_dir):
    """Load all weights from the checkpoint into CPU memory.

    Shard names come from ``model.safetensors.index.json`` when present,
    otherwise a fixed ``model-XXXXX-of-00095`` pattern is assumed. Missing
    shards are skipped with a warning.

    Returns:
        dict[key] → tensor on CPU
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})

    shard_names = set(weight_map.values()) if weight_map else {
        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
    }
    print(f"Loading {len(shard_names)} shards to CPU...")

    all_weights = {}
    loaded = 0
    missing = []
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue
        all_weights.update(load_file(str(shard_path)))
        loaded += 1
        if loaded % 20 == 0:
            print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")

    if missing:
        # Previously silent; a missing shard otherwise surfaces much later as a KeyError.
        preview = ", ".join(missing[:5]) + (" ..." if len(missing) > 5 else "")
        print(f" WARNING: {len(missing)} shard(s) not found in {cdir}: {preview}")
    print(f" Done: {len(all_weights)} tensors on CPU")
    return all_weights


def get_layer_weights(all_weights, li, device):
    """Return {key: tensor on ``device``} for every key under ``model.layers.{li}.``.

    The trailing dot in the prefix keeps layer 1 from matching layer 10.
    """
    prefix = f"model.layers.{li}."
    return {
        key: tensor.to(device=device, non_blocking=True)
        for key, tensor in all_weights.items()
        if key.startswith(prefix)
    }


def cache_all_layer_weights(all_weights, n_layers, devices):
    """Copy every layer's weights to its GPU once, round-robin by layer index.

    This avoids a CPU→GPU transfer per token; each layer's weights stay on
    ``devices[li % len(devices)]`` for the whole run.

    Returns:
        dict[layer index] → dict[key] → tensor on that layer's device
    """
    print(f" Caching layer weights to GPUs...")
    cached = {}
    for li in range(n_layers):
        dev = devices[li % len(devices)]
        cached[li] = get_layer_weights(all_weights, li, dev)
        if (li + 1) % 10 == 0:
            print(f" {li+1}/{n_layers} layers cached")
    print(f" All {n_layers} layers cached to GPUs")
    return cached


# =====================================================================
# Single layer forward
# =====================================================================

def _rms_normalize(x, weight=None, eps=1e-6):
    """FP32 RMS normalization of the last dim, optionally scaled by ``weight``;
    returns BF16."""
    x_f = x.float()
    out = x_f * x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    if weight is not None:
        out = out * weight.float()
    return out.bfloat16()


def _dense_attention_with_sinks(q_input, k_full, v_full, scale, sinks):
    """Dense MQA attention in PyTorch with optional per-head attention sinks.

    q_input: (n_h, T, hd); k_full / v_full: (1, seq_len, hd); sinks: (n_h,)
    or None. Sinks are an extra logit column that takes part in the softmax
    and is then DROPPED before the V multiplication (paper D5c / HuggingFace
    reference) — they are not a dummy KV entry.

    Returns:
        (T, n_h, hd) attention output.
    """
    n_h, T, _ = q_input.shape
    k_expanded = k_full.expand(n_h, -1, -1).contiguous()
    v_expanded = v_full.expand(n_h, -1, -1).contiguous()
    scores_raw = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale  # (n_h, T, seq_len)

    if sinks is not None:
        sink_logits = sinks.float().reshape(n_h, 1, 1).expand(-1, T, 1)
        combined_logits = torch.cat([scores_raw, sink_logits], dim=-1)  # (n_h, T, seq_len+1)
        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
        probs = torch.softmax(combined_logits.float(), dim=-1).to(torch.bfloat16)
        attn_weights = probs[..., :-1]  # drop the sink column
    else:
        attn_weights = torch.softmax(scores_raw.float(), dim=-1).to(torch.bfloat16)

    return torch.matmul(attn_weights, v_expanded).permute(1, 0, 2)


def _fmha_attention_with_sinks(q_input, k_full, v_full, scale, sinks):
    """MQA attention through the production FMHA kernel plus a sink correction.

    The kernel returns the per-(head, query) log-sum-exp ``lse``; adding a sink
    logit s rescales the output by exp(lse) / (exp(lse) + exp(s)), computed as
    sigmoid(lse - s) so large lse cannot overflow exp() in FP32. This assumes
    ``lse`` is a natural log — TODO confirm against the kernel.

    Returns:
        (T, n_h, hd) BF16 attention output.
    """
    from dsv4.kernels.attention.fmha_multitile_op import fmha_multitile_decode_raw
    q_4d = q_input.unsqueeze(0).contiguous()
    k_4d = k_full.unsqueeze(0).contiguous()
    v_4d = v_full.unsqueeze(0).transpose(-1, -2).contiguous()
    o_4d, lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
    attn_out = o_4d.squeeze(0).permute(1, 0, 2)
    if sinks is not None:
        lse_2d = lse.squeeze(0).t()  # (T, n_h)
        correction = torch.sigmoid(lse_2d.float() - sinks.float().unsqueeze(0))
        attn_out = (attn_out.float() * correction.unsqueeze(-1)).bfloat16()
    return attn_out.bfloat16()


def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc,
                  attn_norm, ffn_norm, kv_cache, token_id, positions):
    """Forward one layer with mHC + Attention + FFN.

    Architecture (paper §2):
        X_l   → mHC.pre_block(attn) → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
        X_mid → mHC.pre_block(ffn)  → RMSNorm → MoE       → F_ffn  → mHC.post_block → X_{l+1}

    Args:
        X_l: (T, n_hc, H) BF16 — mHC residual state.
        w: dict of this layer's weights (keys ``model.layers.{li}.*``).
        positions: (T,) position indices into ``rope_cos`` / ``rope_sin``.
        token_id: token id tensor forwarded to ``moe_forward`` (hash routing).
        kv_cache: this layer's ``SimpleKVCache``; the new K/V are appended.

    Returns:
        X_next (T, n_hc, H) BF16.
    """
    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    n_hc = 4
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    # ==================================================================
    # ATTENTION SUB-BLOCK
    # ==================================================================
    if SKIP_MHC:
        # Simple residual: skip mHC, use stream 0 directly.
        x_in = X_l[:, 0, :]
        attn_ctx = None
    else:
        x_in, attn_ctx = attn_mhc.pre_block(X_l)  # x_in: (T, H)
        if MHC_DIAG and attn_ctx is not None:
            print(f" L{li} pre_attn: |X_l|={X_l.abs().max().item():.2f} |x_in|={x_in.abs().max().item():.2f}", flush=True)

    x_normed = attn_norm.forward(x_in)  # (T, H) BF16

    # -- Q projection: q_a (low-rank down) → q_a_norm → q_b (low-rank up) --
    c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"],
                       w[f"{pre}.q_a_proj.weight_scale"],
                       w[f"{pre}.q_a_proj.weight_scale_2"])  # (T, dc)
    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_normalize(c_Q, q_norm_w)
    q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"],
                     w[f"{pre}.q_b_proj.weight_scale"],
                     w[f"{pre}.q_b_proj.weight_scale_2"])  # (T, n_h * hd)
    # q_b_norm — unweighted RMSNorm after q_b_proj (paper §2.3.1); no learnable
    # parameters, it keeps Q on a fixed scale to avoid score collapse.
    q = _rms_normalize(q)

    # -- KV projection (MQA: 1 KV head) + KV norm --
    kv = nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"],
                      w[f"{pre}.kv_proj.weight_scale"],
                      w[f"{pre}.kv_proj.weight_scale_2"])  # (T, hd)
    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_normalize(kv, kv_norm_w)

    q_heads = q.reshape(T, n_h, hd)
    kv_new = kv.reshape(T, 1, hd)

    if MHC_DIAG and li < 3:
        print(f" L{li} Q: |q|={q_heads.abs().max().item():.2f} mean={q_heads.float().abs().mean().item():.4f}")
        print(f" L{li} KV: |kv|={kv_new.abs().max().item():.2f} mean={kv_new.float().abs().mean().item():.4f}")

    # -- RoPE on Q and on KV (before caching) at the current positions --
    positions_dev = positions.to(device)
    q_heads = apply_rope_partial(q_heads, positions_dev, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions_dev, rope_cos, rope_sin, hd, rd)

    # K = V in DSV4 MQA: the same RoPE'd tensor is stored as both.
    kv_cache.append(kv_new.permute(1, 0, 2), kv_new.permute(1, 0, 2))  # (1, T, hd)
    k_full, v_full = kv_cache.get()  # (1, seq_len, hd) each
    seq_len = k_full.shape[1]

    q_input = q_heads.permute(1, 0, 2)  # (n_h, T, hd)
    scale = 1.0 / math.sqrt(hd)
    sink_key = f"{pre}.sinks"
    sinks = w[sink_key].to(device=device) if (sink_key in w and seq_len > 0) else None

    # FMHA pads N to the next multiple of 128; for N << 128 padded zero-K
    # entries contribute exp(0)=1 each to the softmax, diluting real weights.
    # Short sequences therefore use the PyTorch path.
    # NOTE(review): just above the threshold a tile can still hold up to ~126
    # padded keys; confirm the kernel masks them before trusting this path.
    if seq_len < 120:
        attn_out = _dense_attention_with_sinks(q_input, k_full, v_full, scale, sinks)
        if MHC_DIAG and li < 3:
            # Diagnostic: attention entropy for head 0 / first query (ignores sinks).
            with torch.no_grad():
                k_expanded = k_full.expand(n_h, -1, -1).contiguous()
                scores = torch.matmul(q_input, k_expanded.transpose(-1, -2)) * scale
                diag_probs = torch.softmax(scores.float(), dim=-1)
                w0 = diag_probs[0, 0]  # (seq_len,)
                top3_pos = torch.topk(w0, min(3, seq_len))
                entropy = -(w0 * (w0 + 1e-10).log()).sum().item()
                print(f" L{li} attn: seq_len={seq_len} entropy={entropy:.2f} top3_pos={top3_pos.indices.tolist()} top3_w={top3_pos.values.tolist()}")
    else:
        attn_out = _fmha_attention_with_sinks(q_input, k_full, v_full, scale, sinks)
    attn_out = attn_out.bfloat16()

    # -- Inverse RoPE on attention output (paper §2.3.3) --
    # With K = V both RoPE'd, applying R(q)^-1 at the query position turns the
    # absolute rotation of each value into a relative one, R(p - q). Only the
    # last rd dims are affected. Toggle via INVERSE_ROPE to compare modes.
    if INVERSE_ROPE:
        attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)

    # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) --
    attn_grouped = attn_out.reshape(T, n_h * hd).reshape(T, o_groups, group_input_dim)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()  # (o_groups * o_rank, group_input_dim)
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped.permute(1, 0, 2), oa_3d.transpose(1, 2))  # (groups, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"],
                          w[f"{pre}.o_b_proj.weight_scale"],
                          w[f"{pre}.o_b_proj.weight_scale_2"])  # (T, H)

    if SKIP_MHC:
        X_mid = X_l[:, 0, :].unsqueeze(1).expand(-1, n_hc, -1) + F_attn.unsqueeze(1) * 0.1
    else:
        X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)  # (T, n_hc, H)
        if MHC_DIAG and attn_ctx is not None:
            B_l, C_l = attn_ctx.B_l, attn_ctx.C_l
            print(f" L{li} attn: |X_l|={X_l.abs().max().item():.2f} |F_attn|={F_attn.abs().max().item():.2f} |B|={B_l.abs().max().item():.4f} |C|={C_l.abs().max().item():.4f} |X_mid|={X_mid.abs().max().item():.2f}")
            # B_l should be doubly stochastic (rows and columns sum to 1).
            B_row_sums = B_l.sum(dim=-1)
            B_col_sums = B_l.sum(dim=-2)
            print(f" B row_sums={B_row_sums[0].tolist()} col_sums={B_col_sums[0].tolist()}")
            print(f" C_l={C_l[0].tolist()}")

    # ==================================================================
    # FFN SUB-BLOCK
    # ==================================================================
    if SKIP_MHC:
        x_ffn = X_mid[:, 0, :]
        ffn_ctx = None
    else:
        x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)  # (T, H)

    x_ffn_normed = ffn_norm.forward(x_ffn)
    F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device)

    if SKIP_MHC:
        X_next = X_mid + F_ffn.unsqueeze(1) * 0.1
    else:
        X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)  # (T, n_hc, H)
        if MHC_DIAG and ffn_ctx is not None:
            B_l_ffn, C_l_ffn = ffn_ctx.B_l, ffn_ctx.C_l
            print(f" L{li} ffn: |X_mid|={X_mid.abs().max().item():.2f} |F_ffn|={F_ffn.abs().max().item():.2f} |B|={B_l_ffn.abs().max().item():.4f} |C|={C_l_ffn.abs().max().item():.4f} |X_next|={X_next.abs().max().item():.2f}", flush=True)

    return X_next


# =====================================================================
# MoE forward — hash + dense routing, SwiGLU with clamping
# =====================================================================

def _swiglu_expert(x, w, prefix, swiglu_limit):
    """One NVFP4 SwiGLU expert (gate/up/down under ``prefix``) with clamping.

    silu(gate) and up are each clamped to [-swiglu_limit, swiglu_limit]
    (paper §4.2.3) unless ``swiglu_limit`` is None.
    """
    gate = nvfp4_linear(x, w[f"{prefix}.gate_proj.weight"],
                        w[f"{prefix}.gate_proj.weight_scale"],
                        w[f"{prefix}.gate_proj.weight_scale_2"])
    up = nvfp4_linear(x, w[f"{prefix}.up_proj.weight"],
                      w[f"{prefix}.up_proj.weight_scale"],
                      w[f"{prefix}.up_proj.weight_scale_2"])
    silu_out = torch.nn.functional.silu(gate.float())
    if swiglu_limit is not None:
        silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit)
        up_f = up.float().clamp(-swiglu_limit, swiglu_limit)
    else:
        up_f = up.float()
    hidden = (silu_out * up_f).bfloat16()
    return nvfp4_linear(hidden, w[f"{prefix}.down_proj.weight"],
                        w[f"{prefix}.down_proj.weight_scale"],
                        w[f"{prefix}.down_proj.weight_scale_2"])


def _route_tokens(x, w, li, token_id, top_k):
    """Pick experts per token.

    Hash layers (``tid2eid`` present, no correction bias) use a fixed
    token-id → expert table with uniform 1/top_k weights.
    NOTE(review): confirm against the reference whether hash layers should
    instead weight experts by gate scores.

    Dense layers score with sqrt(softplus(logits)). ``e_score_correction_bias``
    is added for SELECTION ONLY, and weights come from the unbiased scores,
    normalized to sum to 1.

    Returns:
        (indices (T, top_k) int64, weights (T, top_k) float32)

    Raises:
        ValueError: hash routing with T > 1 but token_id.numel() != T.
    """
    T = x.shape[0]
    tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid"
    e_bias_key = f"model.layers.{li}.mlp.gate.e_score_correction_bias"

    if tid2eid_key in w and e_bias_key not in w:
        tid2eid = w[tid2eid_key]
        flat_ids = token_id.reshape(-1)
        if T == 1:
            flat_ids = flat_ids[:1]  # single token: first id, as before
        elif flat_ids.numel() != T:
            raise ValueError(
                f"hash routing needs one token id per row: got {flat_ids.numel()} ids for T={T}")
        # Previously every row used token_id[0]'s experts when T > 1.
        indices = tid2eid[flat_ids.to(tid2eid.device)].long()
        weights = torch.ones(T, top_k, dtype=torch.float32, device=x.device) / top_k
        return indices, weights

    # gate weight is used as F.linear's weight, i.e. (n_experts, H).
    gate_w = w[f"model.layers.{li}.mlp.gate.weight"]
    logits = torch.nn.functional.linear(x, gate_w.bfloat16())  # (T, n_experts)
    scores = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6)
    selection_logits = scores.clone()
    if e_bias_key in w:
        selection_logits = selection_logits + w[e_bias_key].float().unsqueeze(0)
    _, indices = selection_logits.topk(top_k, dim=-1)  # (T, top_k)
    weights = torch.gather(scores, -1, indices)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return indices, weights


def moe_forward(x, w, li, cfg, token_id, device):
    """Run routed MoE + shared expert.

    Args:
        x: (T, H) BF16 — post-RMSNorm FFN input; each row is routed on its own.
        token_id: token id tensor used by hash-routed layers.
        device: unused; kept for interface compatibility.

    Returns:
        (T, H) BF16: routed_scaling * Σ_k weight_k * expert_k(x) + shared(x).
        With SKIP_ROUTED_MOE only the shared expert contributes.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)

    indices, expert_weights = _route_tokens(x, w, li, token_id, top_k)

    routed_out = torch.zeros_like(x)
    if not SKIP_ROUTED_MOE:
        for t in range(x.shape[0]):
            x_t = x[t:t + 1]
            acc = torch.zeros_like(x_t)
            # Accumulate in top-k order, rounding to BF16 per step as before.
            for eid, wt in zip(indices[t].tolist(), expert_weights[t].tolist()):
                out = _swiglu_expert(x_t, w, f"model.layers.{li}.mlp.experts.{eid}", swiglu_limit)
                acc = acc + (out.float() * wt).bfloat16()
            routed_out[t:t + 1] = acc
    routed_out = (routed_out.float() * routed_scaling).bfloat16()

    se_pre = f"model.layers.{li}.mlp.shared_experts"
    if f"{se_pre}.gate_proj.weight" in w:
        shared_out = _swiglu_expert(x, w, se_pre, swiglu_limit)
    else:
        shared_out = torch.zeros_like(x)

    return routed_out + shared_out
===================================================================== # Main # ===================================================================== def main(): t_start = time.time() print("=" * 70) print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)") print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA") print("=" * 70) with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: cfg = json.load(f) n_layers = cfg["num_hidden_layers"] H = cfg["hidden_size"] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64)) n_hc = 4 print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") # ==== Phase 1: Load weights to CPU ==== print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}") all_weights = load_weights_to_cpu(CHECKPOINT_DIR) t_loaded = time.time() print(f"Weight loading: {t_loaded - t_start:.1f}s") # ==== Build mHC blocks + RMSNorms (small weights, keep on GPU) ==== print("Building mHC blocks and RMSNorms...") attn_mhc_blocks = {} ffn_mhc_blocks = {} attn_norms = {} ffn_norms = {} for li in range(n_layers): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" # mHC blocks (small weights: fn (24, 28672) FP32 ≈ 2.6MB each) for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks), (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]: fn_key = f"{prefix}.fn" base_key = f"{prefix}.base" scale_key = f"{prefix}.scale" if fn_key in all_weights and base_key in all_weights and scale_key in all_weights: mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) mhc.load_from_checkpoint( all_weights[fn_key], all_weights[base_key], all_weights[scale_key]) blocks[li] = mhc else: print(f" WARNING: no mHC weights for {prefix}, using identity fallback") # Fallback: near-identity mHC (small alphas, identity comb) mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) n = n_hc K = n * H mhc._impl.W_pre = 
torch.zeros(n, K, dtype=torch.float32, device=dev) mhc._impl.W_post = torch.zeros(n, K, dtype=torch.float32, device=dev) mhc._impl.W_comb = torch.zeros(n*n, K, dtype=torch.float32, device=dev) mhc._impl.S_pre = torch.zeros(1, n, dtype=torch.bfloat16, device=dev) mhc._impl.S_post = torch.ones(n, 1, dtype=torch.bfloat16, device=dev) * 0.5 mhc._impl.S_comb = torch.eye(n, dtype=torch.bfloat16, device=dev) mhc._impl.alpha_pre = torch.tensor(0.01, dtype=torch.float32, device=dev) mhc._impl.alpha_post = torch.tensor(0.01, dtype=torch.float32, device=dev) mhc._impl.alpha_comb = torch.tensor(0.01, dtype=torch.float32, device=dev) blocks[li] = mhc # RMSNorms attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) an_key = f"model.layers.{li}.input_layernorm.weight" if an_key in all_weights: attn_norm.weight = all_weights[an_key].to(device=dev, dtype=torch.float32) attn_norms[li] = attn_norm ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) fn_key = f"model.layers.{li}.post_attention_layernorm.weight" if fn_key in all_weights: ffn_norm.weight = all_weights[fn_key].to(device=dev, dtype=torch.float32) ffn_norms[li] = ffn_norm print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}") # ==== Global weights (small, keep on gpu0) ==== torch.cuda.set_device(0) embed_w = all_weights.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') final_norm_w = all_weights.get("model.norm.weight") if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0') # Build RoPE caches with YaRN scaling from model config rope_params = cfg.get("rope_parameters", {}) rope_type = rope_params.get("rope_type", "default") rope_factor = rope_params.get("factor", 1.0) rope_theta = rope_params.get("rope_theta", cfg.get("rope_theta", 10000.0)) original_max_pos = rope_params.get("original_max_position_embeddings", 4096) 
beta_fast = rope_params.get("beta_fast", 32) beta_slow = rope_params.get("beta_slow", 1) print(f"RoPE: type={rope_type} factor={rope_factor} theta={rope_theta} " f"orig_max_pos={original_max_pos} beta_fast={beta_fast} beta_slow={beta_slow}", flush=True) rope_caches = {g: build_rope_cache( 8192, rd, f"cuda:{g}", theta=rope_theta, rope_type=rope_type, rope_factor=rope_factor, original_max_pos=original_max_pos, beta_fast=beta_fast, beta_slow=beta_slow ) for g in range(NUM_GPUS)} # ==== KV caches (one per layer on its GPU) ==== kv_caches = {} for li in range(n_layers): kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") # ==== Cache ALL layer weights to GPUs (avoids per-token CPU→GPU transfer) ==== print(f"\n Caching layer weights to GPUs (one-time transfer)...", flush=True) devices = [f"cuda:{g}" for g in range(NUM_GPUS)] layer_weights = cache_all_layer_weights(all_weights, n_layers, devices) print(f" Done. Freeing CPU weights...", flush=True) del all_weights import gc; gc.collect() # ==== Phase 2: Compile FMHA ==== print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}") from dsv4.kernels.attention.production import dsv4_attention torch.cuda.set_device(0) dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0') dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0') try: _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone()) print(" FMHA: compiled OK") except Exception as e: print(f" FMHA error: {e}") t_compiled = time.time() print(f"Compile: {t_compiled - t_loaded:.1f}s") # ==== Phase 2.5: Minimal E2E test ==== print(f"\n{'='*70}\nPhase 2.5: Minimal E2E Test (single token 'The')\n{'='*70}") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w, final_norm_w, tokenizer) # ==== Phase 2.6: Single-layer trace ==== if True: # Always run the trace 
print(f"\n{'='*70}\nPhase 2.6: Single-Layer Trace (layer 0, first prefill token)\n{'='*70}", flush=True) li = 0 dev = f"cuda:0" w = layer_weights[li] pre = f"model.layers.{li}.self_attn" T_dim = 1 positions = torch.tensor([0], dtype=torch.long, device=dev) rope_cos, rope_sin = rope_caches[0] # Start from the embedding tid = torch.tensor([tokenizer.encode("The")[-1]], dtype=torch.long, device=dev) emb = embed(tid) # (1, H) X = mHCBlock.init_state(emb, 4) # (1, 4, H) print(f" X after init_state: |X|={X.abs().max().item():.4f} stream0_mean={X[:,0,:].float().abs().mean().item():.6f}", flush=True) # mHC pre_block attn_mhc = attn_mhc_blocks[0] x_in, ctx = attn_mhc.pre_block(X) print(f" x_in (mHC pre_block): |x_in|={x_in.abs().max().item():.4f} mean={x_in.float().abs().mean().item():.6f}", flush=True) B_l = ctx.B_l C_l = ctx.C_l print(f" B_l row_sums={B_l[0].sum(dim=-1).tolist()}", flush=True) print(f" C_l={C_l[0].tolist()}", flush=True) # RMSNorm a_norm = attn_norms[0] x_normed = a_norm.forward(x_in) print(f" x_normed: |x|={x_normed.abs().max().item():.4f} mean={x_normed.float().abs().mean().item():.6f}", flush=True) # Q projection c_Q = nvfp4_linear(x_normed, w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], w[f"{pre}.q_a_proj.weight_scale_2"]) print(f" c_Q (q_a_proj): |c_Q|={c_Q.abs().max().item():.4f} mean={c_Q.float().abs().mean().item():.6f}", flush=True) # q_a_norm q_norm_w = w.get(f"{pre}.q_a_norm.weight") if q_norm_w is not None: c_Q_f = c_Q.float() c_Q_rms = c_Q_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() c_Q = (c_Q_f * c_Q_rms * q_norm_w.float()).bfloat16() print(f" c_Q after q_a_norm: |c_Q|={c_Q.abs().max().item():.4f}", flush=True) q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], w[f"{pre}.q_b_proj.weight_scale_2"]) q_heads = q.reshape(T_dim, n_h, hd) print(f" q_heads: |q|={q_heads.abs().max().item():.4f} mean={q_heads.float().abs().mean().item():.6f}", flush=True) # KV projection kv = 
nvfp4_linear(x_normed, w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], w[f"{pre}.kv_proj.weight_scale_2"]) print(f" kv (kv_proj): |kv|={kv.abs().max().item():.4f} mean={kv.float().abs().mean().item():.6f}", flush=True) # kv_norm kv_norm_w = w.get(f"{pre}.kv_norm.weight") if kv_norm_w is not None: kv_f = kv.float() kv_rms = kv_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() kv = (kv_f * kv_rms * kv_norm_w.float()).bfloat16() print(f" kv after kv_norm: |kv|={kv.abs().max().item():.4f}", flush=True) kv_new = kv.reshape(T_dim, 1, hd) # (1, 1, hd) print(f" kv_new shape: {kv_new.shape}", flush=True) # Apply RoPE q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd) kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd) print(f" After RoPE: |q|={q_heads.abs().max().item():.4f} |kv|={kv_new.abs().max().item():.4f}", flush=True) # Self-attention (single token, trivially weight=1.0) q_input = q_heads.permute(1, 0, 2) # (n_h, 1, hd) k_input = kv_new.permute(1, 0, 2) # (1, 1, hd) -> expand k_expanded = k_input.expand(n_h, -1, -1).contiguous() v_expanded = k_expanded.clone() # K=V in DSV4 MQA attn_out = torch.nn.functional.scaled_dot_product_attention( q_input, k_expanded, v_expanded, scale=1.0/math.sqrt(hd)) attn_out = attn_out.permute(1, 0, 2) # (1, n_h, hd) print(f" attn_out: |o|={attn_out.abs().max().item():.4f} mean={attn_out.float().abs().mean().item():.6f}", flush=True) # Inverse RoPE if INVERSE_ROPE: attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd) print(f" After inverse RoPE: |o|={attn_out.abs().max().item():.4f}", flush=True) # Output projection o_groups = cfg.get("num_output_groups", 16) o_rank = cfg.get("output_group_dim", 1024) heads_per_group = n_h // o_groups group_input_dim = heads_per_group * hd attn_flat = attn_out.reshape(T_dim, n_h * hd) attn_grouped = attn_flat.reshape(T_dim, o_groups, heads_per_group * hd) oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() oa_3d = 
oa_w.reshape(o_groups, o_rank, group_input_dim) attn_for_bmm = attn_grouped.permute(1, 0, 2) grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) grouped_flat = grouped_out.permute(1, 0, 2).reshape(T_dim, o_groups * o_rank) print(f" grouped_out (wo_a): |o|={grouped_flat.abs().max().item():.4f} mean={grouped_flat.float().abs().mean().item():.6f}", flush=True) F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], w[f"{pre}.o_b_proj.weight_scale_2"]) print(f" F_attn (wo_b): |F|={F_attn.abs().max().item():.4f} mean={F_attn.float().abs().mean().item():.6f}", flush=True) # mHC post_block X_mid = attn_mhc.post_block(X, F_attn, ctx) print(f" X_mid: |X|={X_mid.abs().max().item():.4f} stream0_mean={X_mid[:,0,:].float().abs().mean().item():.6f}", flush=True) print(f" Layer 0 trace complete.", flush=True) # ==== Phase 3: Inference ==== print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}") # DeepSeek V4 chat format: <|begin▁of▁sentence|><|User|>prompt<|Assistant|> # For reasoning models: <|User|>prompt<|Assistant|>fithinking...flanswer # Special token IDs: <|User|>=128803, <|Assistant|>=128804, <|EOT|>=128805 # Thinking tokens: fi=128821, fl=128822 USER_TOKEN = 128803 ASSISTANT_TOKEN = 128804 EOT_TOKEN = 128805 THINK_START = 128821 # fi THINK_END = 128822 # fl # Build input with proper DeepSeek chat format bos_id = tokenizer.bos_token_id or 0 # <|User|> System prompt \n\n User prompt <|Assistant|> input_ids_list = [bos_id, USER_TOKEN] input_ids_list += tokenizer.encode(SYSTEM_PROMPT, add_special_tokens=False) input_ids_list += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) input_ids_list.append(ASSISTANT_TOKEN) input_ids = torch.tensor([input_ids_list], dtype=torch.long).cuda() print(f"DeepSeek chat format. 
Input: {input_ids.shape[1]} tokens", flush=True) print(f"Decoded start: '{tokenizer.decode(input_ids[0][:20])}...'", flush=True) print(f"Decoded end: '...{tokenizer.decode(input_ids[0][-5:])}'", flush=True) generated = input_ids[0].tolist() # ==== Prefill: process prompt tokens to fill KV cache ==== print(f"Prefilling {len(generated)} prompt tokens...", flush=True) for prefill_idx, tid_val in enumerate(generated): t0 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0') emb = embed(tid) # (1, H) on gpu0 X = mHCBlock.init_state(emb, n_hc) # (1, n_hc, H) for li in range(n_layers): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" if X.device != torch.device(dev): X = X.to(dev) torch.cuda.set_device(gpu) w = layer_weights[li] attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = attn_norms[li] f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) X = X.to('cuda:0') torch.cuda.set_device(0) if prefill_idx % 10 == 0: print(f" Token {prefill_idx}/{len(generated)}: {time.time()-t0:.2f}s", flush=True) print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)") # ==== Decode: generate new tokens ==== print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...") all_tokens = generated.copy() for step in range(MAX_NEW_TOKENS): t0 = time.time() tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') decode_pos = len(all_tokens) - 1 positions = torch.tensor([decode_pos], dtype=torch.long, device='cuda:0') emb = embed(tid) # (1, H) on gpu0 X = mHCBlock.init_state(emb, n_hc) # (1, n_hc, H) for li in range(n_layers): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" if X.device != torch.device(dev): X = X.to(dev) torch.cuda.set_device(gpu) w = layer_weights[li] attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = 
attn_norms[li] f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] X = forward_layer(X, w, li, cfg, rc, rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_caches[li], tid, positions) X = X.to('cuda:0') torch.cuda.set_device(0) # Read out stream 0 → RMSNorm → lm_head x_out = X[:, 0, :] # (1, H) if final_norm_w is not None: xf = x_out.float() rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() x_out = (xf * rms * final_norm_w.float()).bfloat16() logits = torch.nn.functional.linear(x_out, lm_w) # Top-5 predictions for debugging # Top-20 predictions for debugging (includes thinking tokens) top20_vals, top20_ids = torch.topk(logits[0], 20) top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top20_ids[:5], top20_vals[:5])]) # Check if thinking tokens are in top-20 thinking_in_top20 = any(tid.item() in [128821, 128822] for tid in top20_ids) top20_ids_set = set(top20_ids.tolist()) next_id = torch.argmax(logits, dim=-1).item() generated.append(next_id) all_tokens.append(next_id) tok_str = tokenizer.decode([next_id]) dt = time.time() - t0 has_nan = torch.isnan(logits.float()).any().item() has_inf = torch.isinf(logits.float()).any().item() lmin, lmax = logits.float().min().item(), logits.float().max().item() x_max = X.abs().max().item() print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) " f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} " f"|X|={x_max:.3f} top5: {top5_str}", flush=True) if thinking_in_top20: for tid_t, val_t in zip(top20_ids, top20_vals): if tid_t.item() in [128821, 128822]: print(f" THINK TOKEN: {tid_t.item()} logit={val_t.item():.3f}", flush=True) if step % 5 == 0: print(f" Top-20: {[(tokenizer.decode([t.item()]), f'{v.item():.2f}') for t, v in zip(top20_ids, top20_vals)]}", flush=True) if has_nan or has_inf: print(" Numerical issue — stopping") break if next_id == tokenizer.eos_token_id: break out = tokenizer.decode(generated, skip_special_tokens=True) total = time.time() - t_start print(f"\n{'='*70}") 
print(f"Input: '{PROMPT}'") print(f"Output: '{out}'") print(f"Total: {total:.1f}s") print(f"{'='*70}") # ===================================================================== # Minimal end-to-end test — single token "The" through the model # ===================================================================== def minimal_e2e_test(layer_weights, cfg, rope_caches, attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms, embed, lm_w, final_norm_w, tokenizer): """Process a single token 'The' through the model and check output logits. This is a focused diagnostic: if the model can't even produce reasonable logits for a single token, something is fundamentally wrong in the pipeline. We check: 1. No NaN/Inf in any layer output 2. Residual stream magnitude stays bounded 3. Top-5 logits are sensible (not all Chinese tokens for English) 4. Logit spread (max - min) is > 1.0 (not uniform) """ n_layers = cfg["num_hidden_layers"] H = cfg["hidden_size"] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64)) n_hc = 4 # Tokenize just "The" tid = torch.tensor(tokenizer.encode("The"), dtype=torch.long, device='cuda:0') if tid.numel() > 1: # If tokenizer adds BOS, take last token print(f" Note: 'The' tokenized to {tid.numel()} tokens, using last one") tid = tid[-1:] print(f" Token ID: {tid.item()} = '{tokenizer.decode(tid.tolist())}'") # Setup positions = torch.tensor([0], dtype=torch.long, device='cuda:0') emb = embed(tid) # (1, H) X = mHCBlock.init_state(emb, n_hc) # (1, n_hc, H) # Track per-layer diagnostics layer_diags = [] for li in range(n_layers): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" if X.device != torch.device(dev): X = X.to(dev) torch.cuda.set_device(gpu) w = layer_weights[li] attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) a_norm = attn_norms[li] f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] kv_cache = SimpleKVCache(head_dim=hd, max_seq=8192, device=dev) X = forward_layer(X, w, li, cfg, rc, 
rs, attn_mhc, ffn_mhc, a_norm, f_norm, kv_cache, tid, positions) # Per-layer diagnostic x_max = X.abs().max().item() has_nan = torch.isnan(X.float()).any().item() has_inf = torch.isinf(X.float()).any().item() # Stream 0 (primary) x0 = X[:, 0, :] x0_mean = x0.float().abs().mean().item() x0_std = x0.float().std().item() layer_diags.append({ 'layer': li, 'gpu': gpu, 'x_max': x_max, 'x0_mean': x0_mean, 'x0_std': x0_std, 'nan': has_nan, 'inf': has_inf }) if has_nan or has_inf: print(f" ❌ Layer {li}: NaN={has_nan} Inf={has_inf} — STOPPING") break X = X.to('cuda:0') torch.cuda.set_device(0) # Final norm + lm_head x_out = X[:, 0, :] if final_norm_w is not None: xf = x_out.float() rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() x_out = (xf * rms * final_norm_w.float()).bfloat16() logits = torch.nn.functional.linear(x_out, lm_w) # Results print(f"\n === Minimal E2E Test Results ===") print(f" Logits: min={logits.float().min().item():.2f} max={logits.float().max().item():.2f} " f"spread={logits.float().max().item() - logits.float().min().item():.2f}") print(f" NaN={torch.isnan(logits.float()).any().item()} " f"Inf={torch.isinf(logits.float()).any().item()}") top10_vals, top10_ids = torch.topk(logits[0], 10) print(f" Top-10 predictions:") for i, (tid_v, val) in enumerate(zip(top10_ids, top10_vals)): tok_str = tokenizer.decode([tid_v.item()]) print(f" {i+1}. 
'{tok_str}' (id={tid_v.item()}, logit={val.item():.3f})") # Print residual stream evolution print(f"\n Residual stream evolution (stream 0):") for d in layer_diags[::5]: # Every 5th layer print(f" L{d['layer']:2d}: |X|={d['x_max']:.1f} " f"mean={d['x0_mean']:.1f} std={d['x0_std']:.1f} " f"nan={d['nan']} inf={d['inf']}") # Always print last if layer_diags: d = layer_diags[-1] print(f" L{d['layer']:2d}: |X|={d['x_max']:.1f} " f"mean={d['x0_mean']:.1f} std={d['x0_std']:.1f} " f"nan={d['nan']} inf={d['inf']}") # Check for reasonable output spread = logits.float().max().item() - logits.float().min().item() if spread < 1.0: print(f" ⚠️ Logit spread {spread:.2f} is very low — model is essentially uniform") else: print(f" ✓ Logit spread {spread:.2f} looks reasonable") return logits, layer_diags if __name__ == "__main__": main()