From d772885d7e40e44c8b1a5fdafeab8b3d444d4044 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 02:45:52 +0000 Subject: [PATCH] single_shot_inference: proper mHC+RMSNorm+inverse RoPE pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major rewrite of single_shot_inference.py: - Replace broken mHC (gentle normalization hack) with proper Sinkhorn-Knopp - Add RMSNorm before each sub-block (attention + FFN) - Add inverse RoPE on attention output (paper §2.3.3) - Fix KV cache: RoPE applied before caching, K=V in DSV4 MQA - Fix MoE: proper dense routing with e_bias, SwiGLU clamping - Proper weight mapping: fn→W_stacked, base→S_pre/S_res/S_post, scale→alphas - Add identity mHC fallback when weights missing - No emergency normalization, no bandaids --- .../CONSULTANT_RECCOMMENDATION_DOC.md | 0 single_shot_inference.py | 891 ++++++++++++------ 2 files changed, 588 insertions(+), 303 deletions(-) rename CONSULTANT_RECCOMMENDATION_DOC.md => archived_plans/CONSULTANT_RECCOMMENDATION_DOC.md (100%) diff --git a/CONSULTANT_RECCOMMENDATION_DOC.md b/archived_plans/CONSULTANT_RECCOMMENDATION_DOC.md similarity index 100% rename from CONSULTANT_RECCOMMENDATION_DOC.md rename to archived_plans/CONSULTANT_RECCOMMENDATION_DOC.md diff --git a/single_shot_inference.py b/single_shot_inference.py index 371c3f32..dec2b412 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1,15 +1,35 @@ #!/usr/bin/env python3 -"""Single-shot DSV4 inference — 8-GPU with mHC + MoE + KV cache. +"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU. -Full model forward pass with all architectural components: -- mHC (Manifold-Constrained Hyper-Connections) -- Q low-rank + KV projection -- RoPE (partial, last 64 dims) -- Production FMHA kernel (tcgen05 MMA) -- Grouped output projection (wo_a BMM + wo_b NVFP4) -- Routed MoE (384 experts, top-6, hash + dense routing) -- Shared expert FFN (SwiGLU with clamping) -- KV cache across decode steps +This is a reference implementation that exercises the production kernel +stack end-to-end. It should be usable as ground truth when integrating +into vLLM or SGLang. + +Architecture (paper §2): + X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid + X_mid → mHC.pre_block → RMSNorm → FFN(MoE) → F_ffn → mHC.post_block → X_{l+1} + +Components exercised: + - mHC (Manifold-Constrained Hyper-Connections) — proper Sinkhorn-Knopp + - Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head) + - Partial RoPE (last 64 dims, GPT-J interleaved) + - Production FMHA kernel (6-warp multi-tile, C API + ctypes) + - Inverse RoPE on attention output (paper §2.3.3) + - Grouped output projection (wo_a BMM + wo_b NVFP4) + - Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp) + - Shared expert (NVFP4 gate/up/down) + - RMSNorm (pre-norm before each sub-block) + - KV cache across decode steps + +Attention type simplification for this single-shot test: + For short sequences (seq_len ≤ sliding_window=128), ALL attention + types (CSA/HCA/SWA) reduce to dense attention over the full KV cache. + CSA's compressed branch and indexer are only needed for long sequences + where seq_len > sliding_window. HCA is dense over compressed entries, + but at short sequence lengths, the compressed sequence is trivially + small. So we use dense MQA attention over the full KV for all layers. + This is mathematically correct for short sequences and exercises the + FMHA kernel properly. Usage (on B200): source /root/dsv4-nvfp4-workspace/venv/bin/activate @@ -20,18 +40,28 @@ import os, sys, time, json, math import torch from pathlib import Path +# ===================================================================== +# Configuration +# ===================================================================== + CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4" MAX_NEW_TOKENS = 10 PROMPT = "The capital of France is" NUM_GPUS = 8 # ===================================================================== -# NVFP4 dequantization +# NVFP4 dequantization — matches checkpoint format exactly # ===================================================================== FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.]) def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2): + """Dequantize NVFP4 weight to BF16. + + weight: (out_dim, in_dim//2) uint8 — 2 FP4 values per byte + weight_scale: (out_dim, in_dim//16) E4M3 — per-16-element block scale + weight_scale_2: (out_dim, 1) float32 — per-row global scale + """ out_dim = weight.shape[0] in_packed = weight.shape[1] in_features = in_packed * 2 @@ -47,221 +77,279 @@ def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2): scale_expanded = scale_f.repeat_interleave(16, dim=1) return (w_f * scale_expanded).bfloat16() + def nvfp4_linear(x, weight, weight_scale, weight_scale_2): + """BF16 linear with NVFP4 dequant.""" w = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2) return torch.nn.functional.linear(x, w) -def bf16_linear(x, weight): - return torch.nn.functional.linear(x, weight.bfloat16()) + +# ===================================================================== +# RMSNorm — matches dsv4/layers/norm.py +# ===================================================================== + +class RMSNorm: + def __init__(self, hidden_size, eps=1e-6, device='cuda:0'): + self.eps = eps + self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device) + + def forward(self, x): + """x: (T, H) BF16 → (T, H) BF16""" + x_f = x.float() + rms = x_f.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt() + return (x_f * rms * self.weight).to(torch.bfloat16) # ===================================================================== -# mHC +# mHC — proper Sinkhorn-Knopp implementation # ===================================================================== class mHCBlock: - def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_repeat=20, device='cuda'): + """Manifold-Constrained Hyper-Connections (paper §2.2). + + This is a standalone implementation matching dsv4/layers/mhc.py + but self-contained for single-shot inference. Uses the BF16 matmul + fallback (no DeepGEMM dependency). + + Weight mapping from checkpoint: + fn: (N_proj, K_proj) FP32 — fused projection [W_pre; W_res; W_post] + base: (N_proj,) — static biases [S_pre; S_res; S_post] + scale: (3,) — gating alphas [alpha_pre, alpha_res, alpha_post] + """ + def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'): + self.d = hidden_dim self.n_hc = n_hc - self.K = n_hc * hidden_dim - self.sinkhorn_repeat = sinkhorn_repeat + self.K_proj = n_hc * hidden_dim # 28672 + self.N_proj = n_hc + n_hc**2 + n_hc # 4 + 16 + 4 = 24 + self.sinkhorn_iters = sinkhorn_iters self.device = device - self.fn = None - self.hc_scale = None - self.hc_base = None - self.rms_eps = 1e-6 - self.hc_pre_eps = 0.0 - self.hc_sinkhorn_eps = 1e-6 - self.hc_post_mult_value = 2.0 - + self.eps = 1e-6 + + # Weights — set by load_from_checkpoint + self.W_stacked = None # (N_proj, K_proj) FP32 + self.S_pre = None # (1, n_hc) FP32 + self.S_res = None # (n_hc, n_hc) FP32 + self.S_post = None # (n_hc, 1) FP32 + self.alpha_pre = None # float + self.alpha_res = None # float + self.alpha_post = None # float + def load_from_checkpoint(self, fn, base, scale): - self.fn = fn.to(device=self.device, dtype=torch.float32).contiguous() - self.hc_base = base.to(device=self.device, dtype=torch.float32).contiguous() - self.hc_scale = scale.to(device=self.device, dtype=torch.float32).contiguous() - - def pre_block(self, residual): + """Load from checkpoint tensors. + + fn: (24, 28672) FP32 + base: (24,) + scale: (3,) — [alpha_pre, alpha_res, alpha_post] + """ n = self.n_hc - K = self.K - T = residual.shape[0] - res_flat = residual.reshape(T, K).float() - mixes = torch.matmul(res_flat, self.fn.t()) - sqrsum = res_flat.square().sum(dim=-1, keepdim=True) - mixes = mixes * torch.rsqrt(sqrsum / K + self.rms_eps) - - pre_logits = mixes[:, :n] * self.hc_scale[0] + self.hc_base[:n] - pre_mix = torch.sigmoid(pre_logits) + self.hc_pre_eps - post_logits = mixes[:, n:2*n] * self.hc_scale[1] + self.hc_base[n:2*n] - post_mix = torch.sigmoid(post_logits) * self.hc_post_mult_value - comb_logits = (mixes[:, 2*n:].reshape(T, n, n) * self.hc_scale[2] - + self.hc_base[2*n:].reshape(1, n, n)) - comb_mix = torch.softmax(comb_logits, dim=-1) + self.hc_sinkhorn_eps - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps) - for _ in range(self.sinkhorn_repeat - 1): - comb_mix = comb_mix / (comb_mix.sum(dim=-1, keepdim=True) + self.hc_sinkhorn_eps) - comb_mix = comb_mix / (comb_mix.sum(dim=-2, keepdim=True) + self.hc_sinkhorn_eps) - - layer_input = (pre_mix.unsqueeze(-1) * residual.float()).sum(dim=1).bfloat16() - return layer_input, (post_mix, comb_mix) - - def post_block(self, residual, F_out, ctx): - post_mix, comb_mix = ctx - mixed_residual = torch.einsum('tij,tjh->tjh', comb_mix, residual.float()) - post_term = post_mix.unsqueeze(-1) * F_out.unsqueeze(1).float() - residual_next = mixed_residual + post_term - # Gentle normalization: RMSNorm but preserving relative magnitudes - # Only active to prevent runaway growth (MoE should handle most balance) - _T = residual_next.shape[0] - rn_f = residual_next.reshape(_T, self.n_hc, -1) - rms = rn_f.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - # Scale RMS so unit norm ≈ 1.0, not squash to sqrt(d) - scale = (rms * math.sqrt(rn_f.shape[-1])).clamp(max=1.0) - return (rn_f * scale).bfloat16() + self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous() + self.S_pre = base[0:n].reshape(1, n).to(device=self.device, dtype=torch.float32) + self.S_res = base[n:n + n*n].reshape(n, n).to(device=self.device, dtype=torch.float32) + self.S_post = base[n + n*n:].reshape(n, 1).to(device=self.device, dtype=torch.float32) + self.alpha_pre = scale[0].item() if scale.numel() > 0 else 0.01 + self.alpha_res = scale[1].item() if scale.numel() > 1 else 0.01 + self.alpha_post = scale[2].item() if scale.numel() > 2 else 0.01 + + def _project_and_rms(self, X_flat): + """Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) BF16. + + X_flat: (T, K_proj) BF16 + """ + T = X_flat.shape[0] + K = self.K_proj + x_f32 = X_flat.float() + d_out = x_f32 @ self.W_stacked.T # (T, N_proj) + sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,) + rms_scale = torch.sqrt(K / (sqr_sum + self.eps)) # (T,) + return (d_out * rms_scale.unsqueeze(-1)).bfloat16() # (T, N_proj) BF16 + + def _dynamic_params(self, X_l): + """Compute A_l, B_l, C_l from residual state. + + X_l: (T, n_hc, d) BF16 + Returns: A_l (T, n_hc), B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc) + """ + T, n, d = X_l.shape + X_flat = X_l.reshape(T, self.K_proj) # (T, K_proj) + proj = self._project_and_rms(X_flat).float() # (T, N_proj) FP32 + + # Split projection + i0, i1, i2, i3 = 0, n, n + n*n, self.N_proj + A_raw = proj[:, i0:i1] + B_raw = proj[:, i1:i2] + C_raw = proj[:, i2:i3] + + # Add biases and scale by gating alphas (paper eqs. 3-5) + A_tilde = self.alpha_pre * A_raw + self.S_pre + B_tilde = self.alpha_res * B_raw + self.S_res.flatten().unsqueeze(0) + C_tilde = self.alpha_post * C_raw + self.S_post.flatten().unsqueeze(0) + + # Apply constraints + A_l = torch.sigmoid(A_tilde) # (T, n_hc) ∈ (0,1) + C_l = 2.0 * torch.sigmoid(C_tilde) # (T, n_hc) ∈ (0,2) + + # B_l: exp → Sinkhorn-Knopp → doubly stochastic + B_exp = torch.exp(B_tilde).reshape(T, n, n) # (T, n_hc, n_hc) + B_l = self._sinkhorn_knopp(B_exp) + + return A_l.bfloat16(), B_l, C_l.bfloat16() + + def _sinkhorn_knopp(self, M, t_max=None): + """Project (T, n, n) positive matrices onto Birkhoff polytope. + + Alternating row/col normalization, t_max=20 iterations. + Result is doubly stochastic: rows and columns each sum to 1. + """ + if t_max is None: + t_max = self.sinkhorn_iters + for _ in range(t_max): + M = M / (M.sum(dim=-1, keepdim=True) + self.eps) # row norm + M = M / (M.sum(dim=-2, keepdim=True) + self.eps) # col norm + return M + + @staticmethod + def init_state(embeddings, n_hc=4): + """Initialise X_0 from token embeddings. + + Convention: broadcast embedding across all n_hc residual streams. + embeddings: (T, d) BF16 + Returns: (T, n_hc, d) BF16 + """ + return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone() + + def pre_block(self, X_l): + """Compute dynamic mixing params and extract layer input. + + X_l: (T, n_hc, d) BF16 + Returns: x_in (T, d) BF16, ctx tuple + """ + A_l, B_l, C_l = self._dynamic_params(X_l) + # x_in = A_l @ X_l — weighted sum of residual streams + x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d) + return x_in, (B_l, C_l) + + def post_block(self, X_l, F_out, ctx): + """Apply mHC residual update (paper eq. 1): + X_{l+1} = B_l @ X_l + C_l ⊗ F_out + + X_l: (T, n_hc, d) BF16 — residual state BEFORE sub-layer + F_out: (T, d) BF16 — sub-layer output + ctx: (B_l, C_l) from pre_block + Returns: X_next (T, n_hc, d) BF16 + """ + B_l, C_l = ctx + # B_l is FP32, X_l is BF16 — bmm upcasts automatically + BX = torch.bmm(B_l, X_l.float()) # (T, n_hc, d) FP32 + CF = C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d) BF16 + return (BX + CF.float()).bfloat16() + # ===================================================================== -# RoPE +# RoPE — partial, GPT-J interleaved, last rope_dim dims # ===================================================================== -def build_rope_cache(max_pos, head_dim, rope_dim, device, theta=10000.0): +def build_rope_cache(max_pos, rope_dim, device, theta=10000.0): + """Build cos/sin caches for partial RoPE. + + Returns: (cos_cache, sin_cache) each (max_pos, rope_dim//2) BF16 + """ half = rope_dim // 2 freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) - return torch.cos(angles).to(device), torch.sin(angles).to(device) + return torch.cos(angles).bfloat16().to(device), torch.sin(angles).bfloat16().to(device) -def apply_rope(x, positions, cos_cache, sin_cache, rope_dim): + +def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim): + """Apply partial GPT-J interleaved RoPE to the last rope_dim dims of each head. + + x: (T, n_h, hd) BF16 — already reshaped to per-head + positions: (T,) int64 + Returns: (T, n_h, hd) BF16 with RoPE applied to last rope_dim dims + """ T, n_h, hd = x.shape nope = hd - rope_dim - cos = cos_cache[positions].unsqueeze(1).to(x.dtype) - sin = sin_cache[positions].unsqueeze(1).to(x.dtype) + half = rope_dim // 2 + + cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16 + sin = sin_cache[positions].unsqueeze(1) + out = x.clone() - out[:, :, nope:][..., 0::2] = x[:, :, nope:][..., 0::2] * cos - x[:, :, nope:][..., 1::2] * sin - out[:, :, nope:][..., 1::2] = x[:, :, nope:][..., 0::2] * sin + x[:, :, nope:][..., 1::2] * cos + x_rope = x[:, :, nope:] # (T, n_h, rope_dim) + x_even = x_rope[:, :, 0::2] # (T, n_h, half) + x_odd = x_rope[:, :, 1::2] + + out[:, :, nope:][..., 0::2] = x_even * cos - x_odd * sin + out[:, :, nope:][..., 1::2] = x_even * sin + x_odd * cos + return out + + +def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): + """Apply inverse RoPE (conjugate rotation) to attention output. + + Paper §2.3.3: after attention, the per-head output carries position + information from the RoPE'd queries/keys. Inverse RoPE removes this + so the output is position-invariant. + + o: (T, n_h, hd) BF16 + positions: (T,) int64 + Returns: (T, n_h, hd) BF16 with inverse RoPE on last rope_dim dims + """ + T, n_h, hd = o.shape + nope = hd - rope_dim + half = rope_dim // 2 + + cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16 + sin = sin_cache[positions].unsqueeze(1) + + out = o.clone() + o_rope = o[:, :, nope:] + o_even = o_rope[:, :, 0::2] + o_odd = o_rope[:, :, 1::2] + + # Conjugate rotation + out[:, :, nope:][..., 0::2] = o_even * cos + o_odd * sin + out[:, :, nope:][..., 1::2] = -o_even * sin + o_odd * cos return out # ===================================================================== -# Simple KV cache — BF16 flat, one tensor per layer +# KV cache — BF16 flat, MQA (1 KV head) # ===================================================================== class SimpleKVCache: """Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps. MQA: 1 KV head, so cache is (1, seq_len, hd) per layer.""" - def __init__(self, n_layers, head_dim, max_seq=8192, device='cuda:0'): + def __init__(self, head_dim, max_seq=8192, device='cuda:0'): self.hd = head_dim self.max_seq = max_seq self.device = device - self.k = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)] - self.v = [torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) for _ in range(n_layers)] - self.len = [0] * n_layers # current sequence length per layer - - def append(self, layer_idx, k_new, v_new): - """Append new K,V. k_new: (1, 1, hd), v_new: (1, 1, hd).""" - pos = self.len[layer_idx] - self.k[layer_idx][0, pos] = k_new[0, 0] - self.v[layer_idx][0, pos] = v_new[0, 0] - self.len[layer_idx] = pos + 1 - - def get(self, layer_idx): - """Get K,V up to current length. Returns (1, seq_len, hd), (1, seq_len, hd).""" - l = self.len[layer_idx] - return self.k[layer_idx][:, :l], self.v[layer_idx][:, :l] + self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) + self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device) + self.len = 0 + + def append(self, k_new, v_new): + """Append K,V. k_new: (1, T, hd), v_new: (1, T, hd).""" + T = k_new.shape[1] + self.k[0, self.len:self.len + T] = k_new[0] + self.v[0, self.len:self.len + T] = v_new[0] + self.len += T + + def get(self): + """Get K,V up to current length. Returns (1, seq_len, hd) each.""" + return self.k[:, :self.len], self.v[:, :self.len] # ===================================================================== -# Routed MoE forward -# ===================================================================== - -def moe_forward(x, w, li, cfg, token_id): - """Run routed MoE + shared expert. - - x: (1, H) BF16 — input after FFN mHC pre_block - Returns: (1, H) BF16 — combined expert output - """ - H = cfg["hidden_size"] - n_experts = cfg["n_routed_experts"] # 384 - top_k = cfg["num_experts_per_tok"] if "num_experts_per_tok" in cfg else 6 - routed_scaling = cfg.get("routed_scaling_factor", 2.5) - swiglu_limit = cfg.get("swiglu_limit", 10.0) - mlp_inter = cfg["moe_intermediate_size"] # 3072 - - # ---- Hash routing ---- - # For decode, first 3 layers use hash routing (token ID lookup) - # Remaining layers use dense routing (weight projection) - is_hash = li < 3 # Hash routing for first 3 layers - - expert_ids = None - expert_weights = None - - if is_hash: - # tid2eid: (vocab_size, top_k) int64 - tid2eid = w[f"model.layers.{li}.mlp.gate.tid2eid"] - tid = token_id.item() if token_id.numel() == 1 else token_id[0].item() - expert_ids = tid2eid[tid] # (top_k,) int64 - expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k - else: - # Dense routing: gate weight (n_experts, H) BF16 - gate_w = w[f"model.layers.{li}.mlp.gate.weight"] # (384, 7168) BF16 - logits = bf16_linear(x, gate_w) # (1, 384) BF16 - # activation = sqrt(softplus(logits)) - activated = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6).bfloat16() - # Top-k - scores, indices = activated.float().topk(top_k, dim=-1) # (1, 6) - expert_ids = indices[0] # (6,) - # Renormalize - expert_weights = (scores[0] / scores[0].sum()).float() - - # ---- Run selected experts ---- - expert_outputs = [] - for i, eid in enumerate(expert_ids): - eid_int = eid.item() - epre = f"model.layers.{li}.mlp.experts.{eid_int}" - - gate = nvfp4_linear(x, w[f"{epre}.gate_proj.weight"], - w[f"{epre}.gate_proj.weight_scale"], - w[f"{epre}.gate_proj.weight_scale_2"]) - up = nvfp4_linear(x, w[f"{epre}.up_proj.weight"], - w[f"{epre}.up_proj.weight_scale"], - w[f"{epre}.up_proj.weight_scale_2"]) - - # SiLU + clamp - silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit) - hidden = (silu_out * up.float()).bfloat16() - - down = nvfp4_linear(hidden, w[f"{epre}.down_proj.weight"], - w[f"{epre}.down_proj.weight_scale"], - w[f"{epre}.down_proj.weight_scale_2"]) - expert_outputs.append(down) - - # Weighted combine + scaling - routed_out = torch.zeros_like(x) - for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)): - routed_out = routed_out + (out.float() * wt).bfloat16() - routed_out = (routed_out.float() * routed_scaling).bfloat16() - - # ---- Shared expert ---- - se_pre = f"model.layers.{li}.mlp.shared_experts" - se_gate_w = w.get(f"{se_pre}.gate_proj.weight") - if se_gate_w is not None: - gate = nvfp4_linear(x, se_gate_w, - w[f"{se_pre}.gate_proj.weight_scale"], - w[f"{se_pre}.gate_proj.weight_scale_2"]) - up = nvfp4_linear(x, w[f"{se_pre}.up_proj.weight"], - w[f"{se_pre}.up_proj.weight_scale"], - w[f"{se_pre}.up_proj.weight_scale_2"]) - silu_out = torch.nn.functional.silu(gate.float()).clamp(-swiglu_limit, swiglu_limit) - hidden = (silu_out * up.float()).bfloat16() - shared_out = nvfp4_linear(hidden, w[f"{se_pre}.down_proj.weight"], - w[f"{se_pre}.down_proj.weight_scale"], - w[f"{se_pre}.down_proj.weight_scale_2"]) - else: - shared_out = torch.zeros_like(x) - - return routed_out + shared_out - - -# ===================================================================== -# Weight loading +# Weight loading — streams safetensors shards, distributes to 8 GPUs # ===================================================================== def load_all_weights(checkpoint_dir, num_layers): + """Load all weights from checkpoint, distribute layers across GPUs. + + Returns: + layer_weights: dict[li] → {key: tensor on cuda:li%8} + global_weights: {key: tensor on cuda:0} + """ from safetensors.torch import load_file cdir = Path(checkpoint_dir) index_path = cdir / "model.safetensors.index.json" @@ -284,7 +372,7 @@ def load_all_weights(checkpoint_dir, num_layers): if loaded % 20 == 0: print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors") print(f" Done: {len(all_weights)} tensors") - + layer_weights = {} global_weights = {} print("Assigning to GPUs...") @@ -302,10 +390,11 @@ def load_all_weights(checkpoint_dir, num_layers): global_weights[key] = tensor.to("cuda:0") elif key.startswith("lm_head"): global_weights[key] = tensor.to("cuda:0") - + for gpu in range(NUM_GPUS): alloc = torch.cuda.memory_allocated(gpu) / 1e9 - print(f" GPU {gpu}: {alloc:.1f}GB") + if alloc > 0: + print(f" GPU {gpu}: {alloc:.1f}GB") return layer_weights, global_weights @@ -313,88 +402,243 @@ def load_all_weights(checkpoint_dir, num_layers): # Single layer forward # ===================================================================== -def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc, - kv_cache, token_id, decode_pos): - """Forward one layer with mHC + MoE + KV cache.""" +def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, + attn_mhc, ffn_mhc, attn_norm, ffn_norm, + kv_cache, token_id, positions): + """Forward one layer with mHC + Attention + FFN. + + Architecture (paper §2): + X_l → mHC.pre_block(attn) → RMSNorm → Attention → F_attn → mHC.post_block → X_mid + X_mid → mHC.pre_block(ffn) → RMSNorm → MoE → F_ffn → mHC.post_block → X_{l+1} + + X_l: (T, n_hc, H) BF16 — mHC residual state + Returns: X_next (T, n_hc, H) BF16 + """ device = X_l.device H = cfg["hidden_size"] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] - rd = cfg["qk_rope_head_dim"] - o_rank = cfg["o_lora_rank"] - o_groups = cfg["o_groups"] + rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64)) + o_rank = cfg.get("output_group_dim", 1024) + o_groups = cfg.get("num_output_groups", 16) n_hc = 4 pre = f"model.layers.{li}.self_attn" T = X_l.shape[0] heads_per_group = n_h // o_groups group_input_dim = heads_per_group * hd - - # ==== mHC pre_block (attention) ==== - x_in, attn_ctx = attn_mhc.pre_block(X_l) - - # ==== Q projection ==== - c_Q = nvfp4_linear(x_in, w[f"{pre}.q_a_proj.weight"], + + # ================================================================== + # ATTENTION SUB-BLOCK + # ================================================================== + + # -- mHC pre_block (attention) -- + x_in, attn_ctx = attn_mhc.pre_block(X_l) # x_in: (T, H) + + # -- RMSNorm (pre-norm before attention) -- + x_normed = attn_norm.forward(x_in) # (T, H) BF16 + + # -- Q projection: q_a (low-rank down) → q_b (low-rank up) -- + c_Q = nvfp4_linear(x_normed, + w[f"{pre}.q_a_proj.weight"], w[f"{pre}.q_a_proj.weight_scale"], - w[f"{pre}.q_a_proj.weight_scale_2"]) - q = nvfp4_linear(c_Q, w[f"{pre}.q_b_proj.weight"], + w[f"{pre}.q_a_proj.weight_scale_2"]) # (T, dc) + q = nvfp4_linear(c_Q, + w[f"{pre}.q_b_proj.weight"], w[f"{pre}.q_b_proj.weight_scale"], - w[f"{pre}.q_b_proj.weight_scale_2"]) - - # ==== KV projection ==== - kv = nvfp4_linear(x_in, w[f"{pre}.kv_proj.weight"], + w[f"{pre}.q_b_proj.weight_scale_2"]) # (T, n_h * hd) + + # -- KV projection (MQA: 1 KV head) -- + kv = nvfp4_linear(x_normed, + w[f"{pre}.kv_proj.weight"], w[f"{pre}.kv_proj.weight_scale"], - w[f"{pre}.kv_proj.weight_scale_2"]) - - # ==== Reshape for attention ==== - q_heads = q.reshape(T, n_h, hd).permute(1, 0, 2) # (n_h, T, hd) - k_new = kv.reshape(T, 1, hd).permute(1, 0, 2) # (1, T, hd) - v_new = k_new.clone() - - # ==== KV cache: append new K,V ==== - kv_cache.append(0, k_new, v_new) - k_full, v_full = kv_cache.get(0) # (1, seq_len, hd) - seq_len = k_full.shape[1] - - # ==== RoPE ==== - # Apply RoPE to Q (at current position) - q_pos = torch.tensor([decode_pos], dtype=torch.long, device=device) - q_heads = apply_rope(q_heads, q_pos, rope_cos, rope_sin, rd) - - # Apply RoPE to K (at each position in the cache) - k_positions = torch.arange(seq_len, dtype=torch.long, device=device) - k_full_3d = k_full.permute(1, 0, 2) # (seq_len, 1, hd) for RoPE - k_full_3d = apply_rope(k_full_3d, k_positions, rope_cos, rope_sin, rd) - k_full = k_full_3d.permute(1, 0, 2) # (1, seq_len, hd) — RoPE'd - - # ==== FMHA ==== + w[f"{pre}.kv_proj.weight_scale_2"]) # (T, hd) — 1 KV head, no split + + # -- Reshape for attention -- + q_heads = q.reshape(T, n_h, hd) # (T, n_h, hd) + kv_new = kv.reshape(T, 1, hd) # (T, 1, hd) — 1 KV head + + # -- Apply RoPE to Q (at current positions) -- + q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd) + + # -- Apply RoPE to KV (at current positions) BEFORE caching -- + # DSV4 convention: RoPE applied to KV before writing to cache. + # K = V in DSV4 MQA (same projection, same RoPE'd tensor). + kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd) + + # -- KV cache: append RoPE'd KV (K=V) -- + k_new = kv_new # (T, 1, hd) — RoPE'd + v_new = kv_new # K = V in DSV4 MQA + kv_cache.append(k_new.permute(1, 0, 2), v_new.permute(1, 0, 2)) # (1, T, hd) + + # -- Get full KV from cache (already RoPE'd) -- + k_full, v_full = kv_cache.get() # (1, seq_len, hd) each — RoPE'd, K=V + + # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) -- from dsv4.kernels.attention.production import dsv4_attention - attn_out = dsv4_attention(q_heads, k_full, v_full) - attn_out = attn_out.permute(1, 0, 2).reshape(T, n_h * hd) - - # ==== Output projection ==== - attn_grouped = attn_out.reshape(T, o_groups, heads_per_group, hd).reshape(T, o_groups, group_input_dim) - oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() - oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) - attn_for_bmm = attn_grouped.permute(1, 0, 2) - grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) - grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) - F_attn = nvfp4_linear(grouped_flat, w[f"{pre}.o_b_proj.weight"], + q_input = q_heads.permute(1, 0, 2) # (n_h, T, hd) + k_input = k_full.permute(1, 0, 2) # (1, seq_len, hd) — already RoPE'd + v_input = v_full.permute(1, 0, 2) # (1, seq_len, hd) — K=V, RoPE'd + attn_out = dsv4_attention(q_input, k_input, v_input) # (n_h, T, hd) + attn_out = attn_out.permute(1, 0, 2) # (T, n_h, hd) + + # -- Inverse RoPE on attention output (paper §2.3.3) -- + attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd) + + # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) -- + # wo_a: grouped linear, (n_h, hd) → (n_groups, o_rank) via BMM + attn_flat = attn_out.reshape(T, n_h * hd) # (T, n_h * hd) + attn_grouped = attn_flat.reshape(T, o_groups, heads_per_group * hd) # (T, groups, group_dim) + oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16() # (n_groups * o_rank, group_input_dim) BF16 + oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim) # (groups, o_rank, group_dim) + attn_for_bmm = attn_grouped.permute(1, 0, 2) # (groups, T, group_dim) + grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2)) # (groups, T, o_rank) + grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) # (T, groups*o_rank) + + F_attn = nvfp4_linear(grouped_flat, + w[f"{pre}.o_b_proj.weight"], w[f"{pre}.o_b_proj.weight_scale"], - w[f"{pre}.o_b_proj.weight_scale_2"]) - - # ==== mHC post_block (attention) ==== - X_l = attn_mhc.post_block(X_l, F_attn, attn_ctx) - - # ==== mHC pre_block (FFN) ==== - x_ffn, ffn_ctx = ffn_mhc.pre_block(X_l) - - # ==== MoE + shared expert ==== - F_ffn = moe_forward(x_ffn, w, li, cfg, token_id) - - # ==== mHC post_block (FFN) ==== - X_l = ffn_mhc.post_block(X_l, F_ffn, ffn_ctx) - - return X_l + w[f"{pre}.o_b_proj.weight_scale_2"]) # (T, H) + + # -- mHC post_block (attention) -- + X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx) # (T, n_hc, H) + + # ================================================================== + # FFN SUB-BLOCK + # ================================================================== + + # -- mHC pre_block (FFN) -- + x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid) # (T, H) + + # -- RMSNorm (pre-norm before FFN) -- + x_ffn_normed = ffn_norm.forward(x_ffn) # (T, H) BF16 + + # -- MoE + shared expert -- + F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device) + + # -- mHC post_block (FFN) -- + X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx) # (T, n_hc, H) + + return X_next + + +# ===================================================================== +# MoE forward — hash + dense routing, SwiGLU with clamping +# ===================================================================== + +def moe_forward(x, w, li, cfg, token_id, device): + """Run routed MoE + shared expert. + + x: (T, H) BF16 — post-RMSNorm FFN input + Returns: (T, H) BF16 + """ + H = cfg["hidden_size"] + n_experts = cfg["n_routed_experts"] + top_k = cfg.get("num_experts_per_tok", 6) + routed_scaling = cfg.get("routed_scaling_factor", 2.5) + swiglu_limit = cfg.get("swiglu_limit", 10.0) + mlp_inter = cfg["moe_intermediate_size"] + is_hash = li < 3 + + # ---- Routing ---- + if is_hash: + tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid" + if tid2eid_key in w: + tid2eid = w[tid2eid_key] + tid = token_id.item() if token_id.numel() == 1 else token_id[0].item() + expert_ids = tid2eid[tid] # (top_k,) + else: + # Fallback: use dense routing even for hash layers + is_hash = False + + if not is_hash: + # Dense routing: sqrt(softplus(X @ W_gate)) + e_bias for selection + gate_w = w[f"model.layers.{li}.mlp.gate.weight"] # (H, n_experts) BF16 + logits = torch.nn.functional.linear(x, gate_w.bfloat16()) # (T, n_experts) + # Activation: sqrt(softplus(logits)) + activated = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6) + # e_bias: learned per-expert bias for SELECTION ONLY (not in weights) + e_bias_key = f"model.layers.{li}.mlp.gate.e_bias" + if e_bias_key in w: + activated = activated + w[e_bias_key].float().unsqueeze(0) + # Top-k + scores, indices = activated.topk(top_k, dim=-1) # (T, top_k) + # Renormalize on UNBIASED activation + # Re-compute unbiased activation for weights + unbiased = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6) + unbiased_scores = torch.gather(unbiased, -1, indices) + expert_weights = unbiased_scores / unbiased_scores.sum(dim=-1, keepdim=True) + # For T=1 decode, squeeze + if x.shape[0] == 1: + expert_ids = indices[0] + expert_weights = expert_weights[0] + else: + # Per-token routing (not yet needed for decode) + raise NotImplementedError("Multi-token MoE routing") + + # ---- Run selected experts ---- + T = x.shape[0] + expert_outputs = [] + for i, eid in enumerate(expert_ids): + eid_int = eid.item() + epre = f"model.layers.{li}.mlp.experts.{eid_int}" + + gate = nvfp4_linear(x, + w[f"{epre}.gate_proj.weight"], + w[f"{epre}.gate_proj.weight_scale"], + w[f"{epre}.gate_proj.weight_scale_2"]) + up = nvfp4_linear(x, + w[f"{epre}.up_proj.weight"], + w[f"{epre}.up_proj.weight_scale"], + w[f"{epre}.up_proj.weight_scale_2"]) + + # SwiGLU with clamping (paper §4.2.3) + silu_out = torch.nn.functional.silu(gate.float()) + if swiglu_limit is not None: + silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit) + up_clamped = up.float().clamp(-swiglu_limit, swiglu_limit) + else: + up_clamped = up.float() + hidden = (silu_out * up_clamped).bfloat16() + + down = nvfp4_linear(hidden, + w[f"{epre}.down_proj.weight"], + w[f"{epre}.down_proj.weight_scale"], + w[f"{epre}.down_proj.weight_scale_2"]) + expert_outputs.append(down) + + # Weighted combine + scaling + routed_out = torch.zeros_like(x) + for i, (out, wt) in enumerate(zip(expert_outputs, expert_weights)): + routed_out = routed_out + (out.float() * wt).bfloat16() + routed_out = (routed_out.float() * routed_scaling).bfloat16() + + # ---- Shared expert ---- + se_pre = f"model.layers.{li}.mlp.shared_experts" + se_gate_key = f"{se_pre}.gate_proj.weight" + if se_gate_key in w: + gate = nvfp4_linear(x, + w[se_gate_key], + w[f"{se_pre}.gate_proj.weight_scale"], + w[f"{se_pre}.gate_proj.weight_scale_2"]) + up = nvfp4_linear(x, + w[f"{se_pre}.up_proj.weight"], + w[f"{se_pre}.up_proj.weight_scale"], + w[f"{se_pre}.up_proj.weight_scale_2"]) + silu_out = torch.nn.functional.silu(gate.float()) + if swiglu_limit is not None: + silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit) + up_clamped = up.float().clamp(-swiglu_limit, swiglu_limit) + else: + up_clamped = up.float() + hidden = (silu_out * up_clamped).bfloat16() + shared_out = nvfp4_linear(hidden, + w[f"{se_pre}.down_proj.weight"], + w[f"{se_pre}.down_proj.weight_scale"], + w[f"{se_pre}.down_proj.weight_scale_2"]) + else: + shared_out = torch.zeros_like(x) + + return routed_out + shared_out # ===================================================================== @@ -404,34 +648,38 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc, def main(): t_start = time.time() print("=" * 70) - print("DSV4 Single-Shot Inference — Full Pipeline (mHC+MoE+KV)") + print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)") + print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA") print("=" * 70) - + with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: cfg = json.load(f) n_layers = cfg["num_hidden_layers"] H = cfg["hidden_size"] n_h = cfg["num_attention_heads"] hd = cfg["head_dim"] - rd = cfg["qk_rope_head_dim"] + rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64)) n_hc = 4 print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}") print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") - + # ==== Phase 1: Load weights ==== print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}") layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers) t_loaded = time.time() print(f"Weight loading: {t_loaded - t_start:.1f}s") - - # ==== Build mHC blocks ==== + + # ==== Build mHC blocks (proper Sinkhorn) ==== print("Building mHC blocks...") attn_mhc_blocks = {} ffn_mhc_blocks = {} + attn_norms = {} + ffn_norms = {} for li in range(n_layers): gpu = li % NUM_GPUS dev = f"cuda:{gpu}" - + + # mHC blocks for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks), (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]: fn = layer_weights[li].get(f"{prefix}.fn") @@ -441,25 +689,53 @@ def main(): mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) mhc.load_from_checkpoint(fn, base, scale) blocks[li] = mhc - + else: + # Fallback: identity mHC (A=1, B=I, C=1) — not ideal but prevents crash + print(f" WARNING: no mHC weights for {prefix}, using identity fallback") + mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev) + n = n_hc + K = n * H + mhc.W_stacked = torch.zeros(n + n*n + n, K, dtype=torch.float32, device=dev) + mhc.S_pre = torch.zeros(1, n, dtype=torch.float32, device=dev) + mhc.S_res = torch.eye(n, dtype=torch.float32, device=dev) + mhc.S_post = torch.ones(n, 1, dtype=torch.float32, device=dev) * 0.5 + mhc.alpha_pre = 0.01 + mhc.alpha_res = 0.01 + mhc.alpha_post = 0.01 + blocks[li] = mhc + + # RMSNorms (pre-norm before each sub-block) + attn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) + an_key = f"model.layers.{li}.input_layernorm.weight" + if an_key in layer_weights[li]: + attn_norm.weight = layer_weights[li][an_key].to(device=dev, dtype=torch.float32) + attn_norms[li] = attn_norm + + ffn_norm = RMSNorm(H, eps=cfg.get('rms_norm_eps', 1e-6), device=dev) + fn_key = f"model.layers.{li}.post_attention_layernorm.weight" + if fn_key in layer_weights[li]: + ffn_norm.weight = layer_weights[li][fn_key].to(device=dev, dtype=torch.float32) + ffn_norms[li] = ffn_norm + print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}") - + # ==== Global weights ==== torch.cuda.set_device(0) embed_w = global_weights.get("model.embed_tokens.weight") embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16()) lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16() final_norm_w = global_weights.get("model.norm.weight") - rope_caches = {g: build_rope_cache(8192, hd, rd, f"cuda:{g}") for g in range(NUM_GPUS)} - - # ==== KV cache (gpu0, moves to target GPU per layer) ==== + rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)} + + # ==== KV caches (one per layer on its GPU) ==== kv_caches = {} for li in range(n_layers): - kv_caches[li] = SimpleKVCache(n_layers=1, head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") - - # ==== Phase 2: Compile ==== + kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}") + + # ==== Phase 2: Compile FMHA ==== print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}") from dsv4.kernels.attention.production import dsv4_attention + torch.cuda.set_device(0) dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0') dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0') try: @@ -469,96 +745,105 @@ def main(): print(f" FMHA error: {e}") t_compiled = time.time() print(f"Compile: {t_compiled - t_loaded:.1f}s") - + # ==== Phase 3: Inference ==== print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda() print(f"Prompt: '{PROMPT}' → {input_ids.tolist()}") - + generated = input_ids[0].tolist() - + # ==== Prefill: process prompt tokens to fill KV cache ==== print(f"Prefilling {len(generated)} prompt tokens...") for prefill_idx, tid_val in enumerate(generated): tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') - emb = embed(tid) - X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() - + positions = torch.tensor([prefill_idx], dtype=torch.long, device='cuda:0') + emb = embed(tid) # (1, H) on gpu0 + X = mHCBlock.init_state(emb, n_hc) # (1, n_hc, H) + for li in range(n_layers): gpu = li % NUM_GPUS target_device = f"cuda:{gpu}" if X.device != torch.device(target_device): X = X.to(target_device) torch.cuda.set_device(gpu) - + attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) + a_norm = attn_norms[li] + f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] - X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, - attn_mhc, ffn_mhc, kv_caches[li], tid, prefill_idx) - + X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, + attn_mhc, ffn_mhc, a_norm, f_norm, + kv_caches[li], tid, positions) + X = X.to('cuda:0') torch.cuda.set_device(0) + print(f" Prefill done ({len(generated)} tokens, {time.time()-t_compiled:.1f}s)") - + # ==== Decode: generate new tokens ==== print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...") - all_tokens = generated.copy() # track full sequence including prompt - + all_tokens = generated.copy() + for step in range(MAX_NEW_TOKENS): t0 = time.time() - # Current token (last in the sequence) tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') - decode_pos = len(all_tokens) - 1 # absolute position - - # Embed → mHC init state + decode_pos = len(all_tokens) - 1 + positions = torch.tensor([decode_pos], dtype=torch.long, device='cuda:0') + emb = embed(tid) # (1, H) on gpu0 - X = emb.unsqueeze(1).expand(-1, n_hc, -1).clone() # (1, n_hc, H) - - # Process layers + X = mHCBlock.init_state(emb, n_hc) # (1, n_hc, H) + for li in range(n_layers): gpu = li % NUM_GPUS target_device = f"cuda:{gpu}" if X.device != torch.device(target_device): X = X.to(target_device) torch.cuda.set_device(gpu) - + attn_mhc = attn_mhc_blocks.get(li) ffn_mhc = ffn_mhc_blocks.get(li) + a_norm = attn_norms[li] + f_norm = ffn_norms[li] rc, rs = rope_caches[gpu] - X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, - attn_mhc, ffn_mhc, kv_caches[li], tid, decode_pos) - - # Back to gpu0 + X = forward_layer(X, layer_weights[li], li, cfg, rc, rs, + attn_mhc, ffn_mhc, a_norm, f_norm, + kv_caches[li], tid, positions) + X = X.to('cuda:0') torch.cuda.set_device(0) - + # Read out stream 0 → RMSNorm → lm_head - x_out = X[:, 0, :] + x_out = X[:, 0, :] # (1, H) if final_norm_w is not None: xf = x_out.float() rms = xf.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() x_out = (xf * rms * final_norm_w.float()).bfloat16() - + logits = torch.nn.functional.linear(x_out, lm_w) next_id = torch.argmax(logits, dim=-1).item() generated.append(next_id) all_tokens.append(next_id) - + tok_str = tokenizer.decode([next_id]) dt = time.time() - t0 has_nan = torch.isnan(logits.float()).any().item() + has_inf = torch.isinf(logits.float()).any().item() lmin, lmax = logits.float().min().item(), logits.float().max().item() - print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan}") - - if has_nan: - print(" NaN — stopping") + x_max = X.abs().max().item() + print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) " + f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} " + f"|X|={x_max:.3f}") + + if has_nan or has_inf: + print(" Numerical issue — stopping") break if next_id == tokenizer.eos_token_id: break - + out = tokenizer.decode(generated, skip_special_tokens=True) total = time.time() - t_start print(f"\n{'='*70}")