class mHCBlock:
    """Wrapper around dsv4.layers.mhc.mHCLayer for single-shot inference.

    All computation is delegated to the ``mHCLayer`` instance held in
    ``self._impl``; this class only adapts the checkpoint tensor layout
    (fused ``fn`` / ``base`` / ``scale`` tensors) to the keyword arguments
    of ``mHCLayer.load_weights`` and forwards the block calls.

    Attributes:
        device: device the weights are moved to in ``load_from_checkpoint``.
        n_hc: number of residual streams; fixes the checkpoint split points.
        hidden_dim: hidden size forwarded to ``mHCLayer``.
    """

    # Gating value used when the checkpoint ``scale`` tensor has fewer than
    # three entries. The in-file implementation this wrapper replaced used
    # the same fallback.
    _DEFAULT_ALPHA = 0.01

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        """Build the underlying ``mHCLayer``.

        Args:
            hidden_dim: hidden size, forwarded as ``hidden_dim``.
            n_hc: number of residual streams, forwarded as ``n_hc``.
            sinkhorn_iters: forwarded as ``t_max_sinkhorn``.
            device: device for the layer and for weights loaded later.
        """
        # Imported lazily so that importing this module does not require
        # dsv4 unless an mHCBlock is actually constructed.
        from dsv4.layers.mhc import mHCLayer
        self._impl = mHCLayer(
            hidden_dim=hidden_dim, n_hc=n_hc,
            t_max_sinkhorn=sinkhorn_iters,
            device=device, dtype=torch.bfloat16)
        self.device = device
        self.n_hc = n_hc
        self.hidden_dim = hidden_dim

    def load_from_checkpoint(self, fn, base, scale):
        """Split the fused checkpoint tensors and hand them to ``mHCLayer``.

        With ``n = self.n_hc`` the checkpoint order is [pre, post, res]
        (NOT [pre, res, post]):

        Args:
            fn: fused projection; rows are
                [W_pre (n), W_post (n), W_res (n*n)]. Passed on as FP32.
            base: static biases; entries are
                [S_pre (n), S_post (n), S_res (n*n)]. Reshaped to
                (1, n), (n, 1) and (n, n) respectively and cast to BF16.
            scale: gating alphas [alpha_pre, alpha_post, alpha_res].
                Missing trailing entries fall back to ``_DEFAULT_ALPHA``.

        Rows / entries beyond ``2*n + n*n`` in ``fn`` and ``base`` are
        ignored.
        """
        n = self.n_hc
        dev = self.device

        # Split points shared by `fn` rows and `base` entries. The residual
        # segment is bounded explicitly so trailing entries cannot break the
        # (n, n) reshape below.
        pre_end = n
        post_end = 2 * n
        res_end = post_end + n * n

        def _weight(rows):
            return rows.to(device=dev, dtype=torch.float32).contiguous()

        def _bias(entries, shape):
            # NOTE(review): the in-file implementation this replaced kept the
            # biases in FP32; they are cast to BF16 here — confirm that is
            # what mHCLayer.load_weights expects.
            return entries.reshape(shape).to(
                device=dev, dtype=torch.bfloat16).contiguous()

        W_pre = _weight(fn[:pre_end])
        W_post = _weight(fn[pre_end:post_end])
        W_res = _weight(fn[post_end:res_end])

        S_pre = _bias(base[:pre_end], (1, n))
        S_post = _bias(base[pre_end:post_end], (n, 1))
        S_res = _bias(base[post_end:res_end], (n, n))

        # One host transfer for all alphas instead of one .item() per value;
        # pad with the default when the checkpoint provides fewer than three.
        alphas = scale.flatten().tolist()[:3]
        alphas += [self._DEFAULT_ALPHA] * (3 - len(alphas))
        alpha_pre, alpha_post, alpha_res = alphas

        self._impl.load_weights(
            W_pre=W_pre, W_res=W_res, W_post=W_post,
            S_pre=S_pre, S_res=S_res, S_post=S_post,
            alpha_pre=alpha_pre, alpha_res=alpha_res, alpha_post=alpha_post)

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Return the initial residual state via ``mHCLayer.init_state``.

        NOTE(review): the in-file implementation this replaced broadcast a
        (T, d) embedding to (T, n_hc, d); presumably mHCLayer does the same —
        confirm against dsv4.layers.mhc.
        """
        from dsv4.layers.mhc import mHCLayer
        return mHCLayer.init_state(embeddings, n_hc)

    def pre_block(self, X_l):
        """Forward to ``mHCLayer.pre_block`` and return its result unchanged."""
        return self._impl.pre_block(X_l)

    def post_block(self, X_l, F_out, ctx):
        """Forward to ``mHCLayer.post_block`` and return its result unchanged.

        ``ctx`` is whatever context object ``pre_block`` produced; it is
        passed through untouched.
        """
        return self._impl.post_block(X_l, F_out, ctx)