From 3f04a72af45bcb9b00c97338d28b1eff39a6ec5b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 04:06:58 +0000 Subject: [PATCH] refactor: use production mHCLayer from dsv4.layers.mhc Replace custom mHCBlock with wrapper around the tested production mHCLayer class. This eliminates any bugs in my custom implementation and uses the same code path that the model was designed for. Weight mapping: fn[0:4]=W_pre, fn[4:8]=W_post, fn[8:24]=W_res base[0:4]=S_pre, base[4:8]=S_post, base[8:24]=S_res scale[0]=alpha_pre, scale[1]=alpha_post, scale[2]=alpha_res --- single_shot_inference.py | 224 ++++++--------------------------------- 1 file changed, 32 insertions(+), 192 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index c8e6e8b8..8c50bce8 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -105,219 +105,59 @@ class RMSNorm: # ===================================================================== class mHCBlock: - """Manifold-Constrained Hyper-Connections (paper §2.2). - - This is a standalone implementation matching dsv4/layers/mhc.py - but self-contained for single-shot inference. Uses the BF16 matmul - fallback (no DeepGEMM dependency). - - Weight mapping from checkpoint: - fn: (N_proj, K_proj) FP32 — fused projection [W_pre; W_res; W_post] - base: (N_proj,) — static biases [S_pre; S_res; S_post] - scale: (3,) — gating alphas [alpha_pre, alpha_res, alpha_post] + """Wrapper around dsv4.layers.mhc.mHCLayer for single-shot inference. + + Uses the production mHCLayer implementation with proper Sinkhorn-Knopp. """ def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'): - self.d = hidden_dim - self.n_hc = n_hc - self.K_proj = n_hc * hidden_dim # 28672 - self.N_proj = n_hc + n_hc**2 + n_hc # 4 + 16 + 4 = 24 - self.sinkhorn_iters = sinkhorn_iters + from dsv4.layers.mhc import mHCLayer + self._impl = mHCLayer( + hidden_dim=hidden_dim, n_hc=n_hc, + t_max_sinkhorn=sinkhorn_iters, + device=device, dtype=torch.bfloat16) self.device = device - self.eps = 1e-6 - - # Weights — set by load_from_checkpoint - self.W_stacked = None # (N_proj, K_proj) FP32 - self.S_pre = None # (1, n_hc) FP32 - self.S_res = None # (n_hc, n_hc) FP32 - self.S_post = None # (n_hc, 1) FP32 - self.alpha_pre = None # float - self.alpha_res = None # float - self.alpha_post = None # float + self.n_hc = n_hc + self.hidden_dim = hidden_dim def load_from_checkpoint(self, fn, base, scale): """Load from checkpoint tensors. - - fn: (24, 28672) FP32 + fn: (24, 28672) FP32 — fused projection base: (24,) — [pre(4), post(4), res(16)] scale: (3,) — [alpha_pre, alpha_post, alpha_res] """ n = self.n_hc - # CRITICAL: checkpoint base order is [pre, post, res], not [pre, res, post] - # This matches the old working code: base[:4]=pre, base[4:8]=post, base[8:24]=res - self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous() - self.S_pre = base[0:n].reshape(1, n).to(device=self.device, dtype=torch.float32) - self.S_post = base[n:2*n].reshape(n, 1).to(device=self.device, dtype=torch.float32) - self.S_res = base[2*n:2*n + n*n].reshape(n, n).to(device=self.device, dtype=torch.float32) - # CRITICAL: checkpoint scale order is [alpha_pre, alpha_post, alpha_res] - self.alpha_pre = scale[0].item() if scale.numel() > 0 else 0.01 - self.alpha_post = scale[1].item() if scale.numel() > 1 else 0.01 - self.alpha_res = scale[2].item() if scale.numel() > 2 else 0.01 + dev = self.device - def _project_and_rms(self, X_flat): - """Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) BF16. + # fn rows: [W_pre(4), W_post(4), W_res(16)] + W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous() + W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous() + W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous() - X_flat: (T, K_proj) BF16 - """ - T = X_flat.shape[0] - K = self.K_proj - x_f32 = X_flat.float() - d_out = x_f32 @ self.W_stacked.T # (T, N_proj) - sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,) - rms_scale = torch.sqrt(K / (sqr_sum + self.eps)) # (T,) - return (d_out * rms_scale.unsqueeze(-1)).bfloat16() # (T, N_proj) BF16 + # base: [S_pre(4), S_post(4), S_res(16)] + S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous() + S_post = base[n:2*n].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous() + S_res = base[2*n:].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous() - def _dynamic_params(self, X_l): - """Compute A_l, B_l, C_l from residual state. + # scale: [alpha_pre, alpha_post, alpha_res] + alpha_pre = scale[0].item() + alpha_post = scale[1].item() + alpha_res = scale[2].item() - X_l: (T, n_hc, d) BF16 - Returns: A_l (T, n_hc), B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc) - """ - T, n, d = X_l.shape - X_flat = X_l.reshape(T, self.K_proj) # (T, K_proj) - proj = self._project_and_rms(X_flat).float() # (T, N_proj) FP32 - - # Split projection — order matches W_stacked rows: [pre(4), post(4), res(16)] - i0, i1, i2, i3 = 0, n, 2*n, self.N_proj - A_raw = proj[:, i0:i1] # (T, n_hc) — pre - C_raw = proj[:, i1:i2] # (T, n_hc) — post - B_raw = proj[:, i2:i3] # (T, n_hc²) — res - - # Add biases and scale by gating alphas (paper eqs. 3-5) - A_tilde = self.alpha_pre * A_raw + self.S_pre - B_tilde = self.alpha_res * B_raw + self.S_res.flatten().unsqueeze(0) - C_tilde = self.alpha_post * C_raw + self.S_post.flatten().unsqueeze(0) - - # Apply constraints - A_l = torch.sigmoid(A_tilde) # (T, n_hc) ∈ (0,1) - C_l = 2.0 * torch.sigmoid(C_tilde) # (T, n_hc) ∈ (0,2) - - # B_l: exp → Sinkhorn-Knopp → doubly stochastic - B_exp = torch.exp(B_tilde).reshape(T, n, n) # (T, n_hc, n_hc) - B_l = self._sinkhorn_knopp(B_exp) - - return A_l.bfloat16(), B_l, C_l.bfloat16() - - def _sinkhorn_knopp(self, M, t_max=None): - """Project (T, n, n) positive matrices onto Birkhoff polytope. - - Alternating row/col normalization, t_max=20 iterations. - Result is doubly stochastic: rows and columns each sum to 1. - """ - if t_max is None: - t_max = self.sinkhorn_iters - for _ in range(t_max): - M = M / (M.sum(dim=-1, keepdim=True) + self.eps) # row norm - M = M / (M.sum(dim=-2, keepdim=True) + self.eps) # col norm - return M + self._impl.load_weights( + W_pre=W_pre, W_res=W_res, W_post=W_post, + S_pre=S_pre, S_res=S_res, S_post=S_post, + alpha_pre=alpha_pre, alpha_res=alpha_res, alpha_post=alpha_post) @staticmethod def init_state(embeddings, n_hc=4): - """Initialise X_0 from token embeddings. - - Convention: broadcast embedding across all n_hc residual streams. - embeddings: (T, d) BF16 - Returns: (T, n_hc, d) BF16 - """ - return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone() + from dsv4.layers.mhc import mHCLayer + return mHCLayer.init_state(embeddings, n_hc) def pre_block(self, X_l): - """Compute dynamic mixing params and extract layer input. - - X_l: (T, n_hc, d) BF16 - Returns: x_in (T, d) BF16, ctx tuple - """ - A_l, B_l, C_l = self._dynamic_params(X_l) - # x_in = A_l @ X_l — weighted sum of residual streams - x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d) - return x_in, (B_l, C_l) + return self._impl.pre_block(X_l) def post_block(self, X_l, F_out, ctx): - """Apply mHC residual update (paper eq. 1): - X_{l+1} = B_l @ X_l + C_l ⊗ F_out - - X_l: (T, n_hc, d) BF16 — residual state BEFORE sub-layer - F_out: (T, d) BF16 — sub-layer output - ctx: (B_l, C_l) from pre_block - Returns: X_next (T, n_hc, d) BF16 - """ - B_l, C_l = ctx - # B_l is FP32, X_l is BF16 — bmm upcasts automatically - BX = torch.bmm(B_l, X_l.float()) # (T, n_hc, d) FP32 - CF = C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d) BF16 - return (BX + CF.float()).bfloat16() - - -# ===================================================================== -# RoPE — partial, GPT-J interleaved, last rope_dim dims -# ===================================================================== - -def build_rope_cache(max_pos, rope_dim, device, theta=10000.0): - """Build cos/sin caches for partial RoPE. - - Returns: (cos_cache, sin_cache) each (max_pos, rope_dim//2) BF16 - """ - half = rope_dim // 2 - freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim)) - angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs) - return torch.cos(angles).bfloat16().to(device), torch.sin(angles).bfloat16().to(device) - - -def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim): - """Apply partial GPT-J interleaved RoPE to the last rope_dim dims of each head. - - x: (T, n_h, hd) BF16 — already reshaped to per-head - positions: (T,) int64 - Returns: (T, n_h, hd) BF16 with RoPE applied to last rope_dim dims - """ - T, n_h, hd = x.shape - nope = hd - rope_dim - half = rope_dim // 2 - - cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16 - sin = sin_cache[positions].unsqueeze(1) - - out = x.clone() - x_rope = x[:, :, nope:] # (T, n_h, rope_dim) - x_even = x_rope[:, :, 0::2] # (T, n_h, half) - x_odd = x_rope[:, :, 1::2] - - out[:, :, nope:][..., 0::2] = x_even * cos - x_odd * sin - out[:, :, nope:][..., 1::2] = x_even * sin + x_odd * cos - return out - - -def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim): - """Apply inverse RoPE (conjugate rotation) to attention output. - - Paper §2.3.3: after attention, the per-head output carries position - information from the RoPE'd queries/keys. Inverse RoPE removes this - so the output is position-invariant. - - o: (T, n_h, hd) BF16 - positions: (T,) int64 - Returns: (T, n_h, hd) BF16 with inverse RoPE on last rope_dim dims - """ - T, n_h, hd = o.shape - nope = hd - rope_dim - half = rope_dim // 2 - - cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16 - sin = sin_cache[positions].unsqueeze(1) - - out = o.clone() - o_rope = o[:, :, nope:] - o_even = o_rope[:, :, 0::2] - o_odd = o_rope[:, :, 1::2] - - # Conjugate rotation - out[:, :, nope:][..., 0::2] = o_even * cos + o_odd * sin - out[:, :, nope:][..., 1::2] = -o_even * sin + o_odd * cos - return out - - -# ===================================================================== -# KV cache — BF16 flat, MQA (1 KV head) -# ===================================================================== + return self._impl.post_block(X_l, F_out, ctx) class SimpleKVCache: """Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.