""" mHC (Manifold-Constrained Hyper-Connections) — Inference Layer. Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only. At inference the Sinkhorn-Knopp constraint has already been enforced during training, but B_l is still *dynamically generated* per-token from the input residual state. So we still need to: 1. Project the flattened residual → raw A/B/C parameter values. 2. Apply sigmoid (A, C) and Sinkhorn-Knopp 20 iters (B). 3. Mix residual streams. The only thing that changes vs training is that we skip the loss and gradient through the Sinkhorn projection — the forward arithmetic is identical. --------------------------------------------------------------------- V4-Pro reference dimensions (Section 4.2.1) --------------------------------------------------------------------- d = 7168 hidden dim n_hc = 4 hyper-connection expansion factor N_proj = 24 fused output of W_pre(4) + W_res(16) + W_post(4) K_proj = 4*7168 = 28672 = n_hc * d (flattened residual) t_max = 20 Sinkhorn iterations --------------------------------------------------------------------- Kernel dependency --------------------------------------------------------------------- tf32_hc_prenorm_gemm (DeepGEMM, SM90/SM100) a: (T, K) BF16 — flattened residual X_flat b: (N, K) FP32 — stacked weight [W_pre; W_res; W_post] d: (S, T, N) or (T, N) FP32 — raw projection outputs (pre-normalised) sqr_sum: (S, T) or (T,) FP32 — Σ a² per token (for RMSNorm denominator) num_splits = S (16 recommended for K=28672) After the call: d = d.sum(0) → (T, N) sqr_sum = sqr_sum.sum(0) → (T,) rms_scale = sqrt(K / (sqr_sum + eps)) d_norm = d * rms_scale[:,None] — equivalent to RMSNorm(X_flat) @ W_stacked """ from __future__ import annotations import math from dataclasses import dataclass from typing import Optional, Tuple import torch import torch.nn.functional as F # --------------------------------------------------------------------------- # Try importing DeepGEMM; fall back to plain BF16 matmul if unavailable. # --------------------------------------------------------------------------- try: import deep_gemm _HAS_DEEP_GEMM = True except ImportError: _HAS_DEEP_GEMM = False NUM_SPLITS = 16 # K-split count for tf32_hc_prenorm_gemm numerical stability EPS_RMSN = 1e-6 # --------------------------------------------------------------------------- # Sinkhorn-Knopp projection (T batched 4×4 matrices, 20 iters) # --------------------------------------------------------------------------- def sinkhorn_knopp( M: torch.Tensor, # (T, n, n) positive (after exp) t_max: int = 20, ) -> torch.Tensor: """ Project each (n×n) positive matrix onto the Birkhoff polytope (doubly stochastic matrices) via alternating row/col normalisation. Paper eq. (8): M^(t) = T_r( T_c( M^(t-1) ) ) where T_r = row-normalise, T_c = col-normalise. For n=4 and t_max=20 this is ~160 tiny operations — no kernel needed. All ops stay on GPU via standard PyTorch. """ for _ in range(t_max): M = M / (M.sum(dim=-1, keepdim=True) + EPS_RMSN) # T_r (row) M = M / (M.sum(dim=-2, keepdim=True) + EPS_RMSN) # T_c (col) return M # --------------------------------------------------------------------------- # Context carried between pre_block and post_block # --------------------------------------------------------------------------- @dataclass class mHCContext: """Holds the per-token mixing matrices computed in pre_block.""" B_l: torch.Tensor # (T, n_hc, n_hc) doubly stochastic residual transform C_l: torch.Tensor # (T, n_hc) output mapping (before unsqueeze) # --------------------------------------------------------------------------- # mHC layer # --------------------------------------------------------------------------- class mHCLayer: """ Wraps one transformer sub-layer (attention *or* MoE) with the mHC residual update. Typical call pattern per layer: x_in, ctx = mhc.pre_block(X_l) F_out = transformer_sublayer(x_in) # (T, d) X_next = mhc.post_block(X_l, F_out, ctx) where X_l has shape (T, n_hc, d) — the expanded residual state. The first call at layer 0 should use X_0 initialised via `init_state`. """ def __init__( self, hidden_dim: int = 7168, n_hc: int = 4, t_max_sinkhorn: int = 20, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): self.d = hidden_dim self.n_hc = n_hc self.K_proj = n_hc * hidden_dim # 28672 for V4-Pro self.N_proj = n_hc + n_hc * n_hc + n_hc # 4 + 16 + 4 = 24 self.t_max = t_max_sinkhorn self.device = device self.dtype = dtype # ── Learnable weights (set via load_weights) ────────────────── # Stacked projection: b shape = (N_proj, K_proj) in FP32 # Stored as separate tensors, fused in forward if DeepGEMM available. self.W_pre = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K) self.W_res = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32) # (16, K) self.W_post = self._buf(n_hc, self.K_proj, dtype=torch.float32) # (4, K) # Static biases (eq. 3-5, S^pre / S^res / S^post) self.S_pre = self._buf(1, n_hc) # (1, 4) self.S_res = self._buf(n_hc, n_hc) # (4, 4) self.S_post = self._buf(n_hc, 1) # (4, 1) # Learnable gating scalars (α), initialised small during training # At inference these are just scalars loaded from the checkpoint. self.alpha_pre = torch.zeros(1, device=device, dtype=torch.float32) self.alpha_res = torch.zeros(1, device=device, dtype=torch.float32) self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32) # Pre-allocated split buffers (set in _ensure_buffers) self._d_split = None # (NUM_SPLITS, max_T, N_proj) FP32 self._sqr_sum_split = None # (NUM_SPLITS, max_T) FP32 self._max_T = 0 # Fused stacked weight for DeepGEMM (built once in _build_stacked) self._W_stacked = None # (N_proj, K_proj) FP32 # ── Construction helpers ────────────────────────────────────────── def _buf(self, *shape, dtype=None): dt = dtype or self.dtype return torch.empty(*shape, dtype=dt, device=self.device) def load_weights( self, W_pre: torch.Tensor, # (n_hc, K) FP32 W_res: torch.Tensor, # (n_hc², K) FP32 W_post: torch.Tensor, # (n_hc, K) FP32 S_pre: torch.Tensor, # (1, n_hc) S_res: torch.Tensor, # (n_hc, n_hc) S_post: torch.Tensor, # (n_hc, 1) alpha_pre: float, alpha_res: float, alpha_post: float, ): """ Load all mHC parameters from the checkpoint. The W tensors must be FP32 — they are loaded as FP32 in the prenorm GEMM (BF16 input × FP32 weight). Everything else can be BF16 in the checkpoint and will be cast here. """ def _f32(t): return t.to(device=self.device, dtype=torch.float32).contiguous() def _cvt(t): return t.to(device=self.device, dtype=self.dtype).contiguous() self.W_pre = _f32(W_pre) self.W_res = _f32(W_res) self.W_post = _f32(W_post) self.S_pre = _cvt(S_pre) self.S_res = _cvt(S_res) self.S_post = _cvt(S_post) self.alpha_pre = torch.tensor(alpha_pre, dtype=torch.float32, device=self.device) self.alpha_res = torch.tensor(alpha_res, dtype=torch.float32, device=self.device) self.alpha_post = torch.tensor(alpha_post, dtype=torch.float32, device=self.device) self._W_stacked = None # invalidate cache def _build_stacked(self): """Fuse W_pre / W_res / W_post into one (N_proj, K_proj) FP32 tensor.""" self._W_stacked = torch.cat([self.W_pre, self.W_res, self.W_post], dim=0) # Must be K-major (contiguous along K) for DeepGEMM self._W_stacked = self._W_stacked.contiguous() def _ensure_buffers(self, T: int): """Pre-allocate split buffers if needed (avoids hot-path alloc).""" if T <= self._max_T: return self._d_split = torch.empty( NUM_SPLITS, T, self.N_proj, dtype=torch.float32, device=self.device ) self._sqr_sum_split = torch.empty( NUM_SPLITS, T, dtype=torch.float32, device=self.device ) self._max_T = T # ── Forward ────────────────────────────────────────────────────── def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor: """ Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) FP32. Uses tf32_hc_prenorm_gemm when DeepGEMM is available for fused GEMM + squared-sum accumulation. Falls back to plain BF16 matmul. X_flat: (T, K_proj) BF16 """ T = X_flat.shape[0] K = self.K_proj if _HAS_DEEP_GEMM: if self._W_stacked is None: self._build_stacked() self._ensure_buffers(T) d_s = self._d_split[:, :T, :] # view, no copy ss_s = self._sqr_sum_split[:, :T] # a: (T, K) BF16 b: (N, K) FP32 → d_s: (S, T, N), ss_s: (S, T) # Both d and sqr_sum are OUTPUT tensors (written by the kernel). deep_gemm.tf32_hc_prenorm_gemm( X_flat.contiguous(), # a self._W_stacked, # b (N, K) FP32 d_s, # d (S, T, N) ss_s, # sqr_sum (S, T) num_splits=NUM_SPLITS, ) d_out = d_s.sum(dim=0) # (T, N) sqr_sum = ss_s.sum(dim=0) # (T,) else: # Fallback: BF16 matmul + manual squared sum if self._W_stacked is None: self._build_stacked() x_f32 = X_flat.float() d_out = x_f32 @ self._W_stacked.T # (T, N) sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,) # RMSNorm scale: multiply raw GEMM output by rsqrt(mean(x²)) # mean(x²) = sqr_sum / K → scale = sqrt(K / sqr_sum) rms_scale = torch.sqrt(K / (sqr_sum + EPS_RMSN)) # (T,) return (d_out * rms_scale.unsqueeze(-1)).to(self.dtype) # (T, N) in BF16 def _dynamic_params( self, X_l: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute per-token A_l, B_l, C_l from the current residual state. X_l: (T, n_hc, d) Returns: A_l: (T, n_hc) sigmoid-constrained input mapping B_l: (T, n_hc, n_hc) doubly-stochastic residual transform C_l: (T, n_hc) 2*sigmoid-constrained output mapping """ T, n, d = X_l.shape assert n == self.n_hc and d == self.d # Flatten: (T, n_hc*d) X_flat = X_l.reshape(T, self.K_proj).to(self.dtype) # Fused RMSNorm projection: (T, N_proj) proj = self._project_and_rms(X_flat).float() # keep FP32 for precision # Split into raw A / B / C i0, i1, i2, i3 = 0, self.n_hc, self.n_hc + self.n_hc**2, self.N_proj A_raw = proj[:, i0:i1] # (T, n_hc) B_raw = proj[:, i1:i2] # (T, n_hc²) C_raw = proj[:, i2:i3] # (T, n_hc) # Add static biases and scale by learned gating factors (eq. 3-5) S_pre = self.S_pre.float() # (1, n_hc) S_res = self.S_res.float() # (n_hc, n_hc) S_post = self.S_post.float() # (n_hc, 1) A_tilde = self.alpha_pre * A_raw + S_pre # (T, n_hc) B_tilde = self.alpha_res * B_raw + S_res.flatten().unsqueeze(0) # (T, n_hc²) C_tilde = self.alpha_post * C_raw + S_post.flatten().unsqueeze(0) # (T, n_hc) # Apply constraints (paper eqs. 6-8) A_l = torch.sigmoid(A_tilde) # (T, n_hc) C_l = 2.0 * torch.sigmoid(C_tilde) # (T, n_hc) # B_l: exp → Sinkhorn-Knopp → doubly stochastic B_exp = torch.exp(B_tilde).reshape(T, self.n_hc, self.n_hc) B_l = sinkhorn_knopp(B_exp, t_max=self.t_max) # (T, n_hc, n_hc) # Keep B_l in FP32 — the (T,4,4) bmm precision matters more than memory. # A_l and C_l are cast to dtype for the input/output mixing multiplies. return A_l.to(self.dtype), B_l, C_l.to(self.dtype) # ---------------------------------------------------------------- # Public API: pre_block / post_block # ---------------------------------------------------------------- def pre_block( self, X_l: torch.Tensor, # (T, n_hc, d) BF16 ) -> Tuple[torch.Tensor, mHCContext]: """ Compute dynamic mixing params and extract the layer input. Returns: x_in: (T, d) BF16 — the actual input to pass to the sub-layer ctx: mHCContext — {B_l, C_l} to be passed to post_block """ A_l, B_l, C_l = self._dynamic_params(X_l) # Layer input: x_in = A_l @ X_l (per token, weighted sum of streams) # A_l: (T, n_hc) X_l: (T, n_hc, d) # → (T, 1, n_hc) bmm (T, n_hc, d) = (T, 1, d) → squeeze x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d) return x_in, mHCContext(B_l=B_l, C_l=C_l) def post_block( self, X_l: torch.Tensor, # (T, n_hc, d) BF16 — residual state BEFORE sub-layer F_out: torch.Tensor, # (T, d) BF16 — sub-layer output ctx: mHCContext, ) -> torch.Tensor: """ Apply the mHC residual update (eq. 1): X_{l+1} = B_l @ X_l + C_l ⊗ F_out Returns: X_next: (T, n_hc, d) BF16 """ # B_l is FP32, X_l is BF16 — bmm upcasts automatically in PyTorch. BX = torch.bmm(ctx.B_l, X_l.float()) CF = ctx.C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d) return (BX + CF.float()).to(self.dtype) # (T, n_hc, d) # ---------------------------------------------------------------- # Utility # ---------------------------------------------------------------- @staticmethod def init_state( embeddings: torch.Tensor, # (T, d) BF16 — token embeddings n_hc: int = 4, ) -> torch.Tensor: """ Initialise X_0 for the first layer. The paper figure shows the embedding feeding into the first Residual Mixing. We broadcast the embedding across all n_hc residual streams as the simplest valid initialisation. Returns: (T, n_hc, d) BF16 """ return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone() @staticmethod def read_out(X_L: torch.Tensor) -> torch.Tensor: """ Extract the final hidden state from the last residual state. Convention: stream 0 is the primary output stream (standard choice for HC models — the first stream carries the main residual). Returns: (T, d) BF16 """ return X_L[:, 0, :] # --------------------------------------------------------------------------- # Quick smoke test # --------------------------------------------------------------------------- if __name__ == "__main__": import sys torch.manual_seed(0) device = "cuda" if torch.cuda.is_available() else "cpu" dtype = torch.bfloat16 D, N_HC = 7168, 4 K = N_HC * D # 28672 N_PROJ = N_HC + N_HC ** 2 + N_HC # 24 mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype) # Random weights matching the expected shapes mhc.load_weights( W_pre = torch.randn(N_HC, K, dtype=torch.float32), W_res = torch.randn(N_HC**2, K, dtype=torch.float32), W_post = torch.randn(N_HC, K, dtype=torch.float32), S_pre = torch.zeros(1, N_HC, dtype=dtype), S_res = torch.eye(N_HC, dtype=dtype), # identity: pure residual S_post = torch.zeros(N_HC, 1, dtype=dtype), alpha_pre = 0.01, alpha_res = 0.01, alpha_post = 0.01, ) T = 4 # 4 tokens # ── Forward pass ──────────────────────────────────────────────── embeddings = torch.randn(T, D, dtype=dtype, device=device) X = mHCLayer.init_state(embeddings, n_hc=N_HC) print(f"X_0: {X.shape} (T={T}, n_hc={N_HC}, d={D})") # Simulate a 2-layer stack for layer_idx in range(2): x_in, ctx = mhc.pre_block(X) print(f"\nLayer {layer_idx}:") print(f" x_in (to sub-layer): {x_in.shape}") print(f" B_l: {ctx.B_l.shape}") print(f" C_l: {ctx.C_l.shape}") # Dummy sub-layer: identity (for testing the mHC mechanics) F_out = x_in X = mhc.post_block(X, F_out, ctx) print(f" X_next: {X.shape}") hidden = mHCLayer.read_out(X) print(f"\nFinal hidden: {hidden.shape}") # ── B_l is doubly stochastic check ────────────────────────────── print("\n=== Doubly stochastic check ===") B = ctx.B_l # (T, 4, 4) — FP32 from Sinkhorn row_sums = B.sum(dim=-1) # (T, 4) — should all be ~1 col_sums = B.sum(dim=-2) # (T, 4) — should all be ~1 print(f" row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}] (want ≈ 1.0)") print(f" col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}] (want ≈ 1.0)") assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1" assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1" print(" PASSED") # ── A_l and C_l are bounded ────────────────────────────────────── # (Re-run dynamic params to expose A_l for checking) A_l, B_l2, C_l = mhc._dynamic_params(X) print(f"\n=== A_l ∈ (0,1) check ===") print(f" A_l range: [{A_l.min():.4f}, {A_l.max():.4f}] (want ∈ (0,1))") assert A_l.min() > 0 and A_l.max() < 1, "A_l out of sigmoid range" print(" PASSED") print(f"\n=== C_l ∈ (0,2) check ===") print(f" C_l range: [{C_l.min():.4f}, {C_l.max():.4f}] (want ∈ (0,2))") assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range" print(" PASSED") # ── Consistency: S_res = identity → B_l ≈ doubly-stochastic I ─── print("\n=== S_res=I, alpha_res≈0 → B_l ≈ uniform matrix ===") # With S_res = I and alpha_res ≈ 0: # B_tilde ≈ I → exp(I) → Sinkhorn of exp(I) # exp(I) is diag-dominant; after Sinkhorn it converges to a doubly stochastic matrix. # We just check doubly-stochastic property is preserved (already checked above). print(" Already verified via doubly stochastic check above.") # ── Equivalence: T=1 decode vs T=N prefill ────────────────────── print("\n=== Token-by-token decode == batch prefill ===") T_big = 8 h_big = torch.randn(T_big, D, dtype=dtype, device=device) X_batch = mHCLayer.init_state(h_big, n_hc=N_HC) # Batch x_in_batch, ctx_batch = mhc.pre_block(X_batch) # Token by token x_in_tokens = [] for t in range(T_big): X_t = X_batch[t:t+1] # (1, n_hc, d) x_in_t, _ = mhc.pre_block(X_t) x_in_tokens.append(x_in_t) x_in_seq = torch.cat(x_in_tokens, dim=0) # (T_big, d) diff = (x_in_batch - x_in_seq).abs().max().item() print(f" max |batch - sequential| on x_in: {diff:.6f}") assert diff < 1e-2, f"Mismatch too large: {diff}" print(" PASSED") print("\nAll checks done.") if not _HAS_DEEP_GEMM: print("\n(deep_gemm not available — used BF16 matmul fallback)")