# =============================================================================
# Patch file 1 of 2: cutedsl/csa_hca_compressor.py
# (new file, mode 100644, index 00000000..956fb5a0)
# =============================================================================
"""
DeepSeek-V4 CSA/HCA compressor kernels for CUTLASS CuTe DSL / Blackwell.

This is a production-oriented fusion boundary for the PyTorch reference
implementation of the compressor (originally supplied as `Pasted markdown.md`):

  1. Run the projection stage as one or two packed Blackwell GEMMs:
       CSA main:     H @ [W_a_KV | W_a_Z | W_b_KV | W_b_Z]         -> (N, 4*C)
       CSA indexer:  H @ [W_I_a_KV | W_I_a_Z | W_I_b_KV | W_I_b_Z] -> (N, 4*C_I)
       HCA:          H @ [W_KV | W_Z]                              -> (N, 2*C)

     Use your tcgen05 / NVFP4 blockscaled GEMM here. Keep the packed projection
     outputs in BF16/FP32 as you prefer; the compressor reads them as tensor
     elements and accumulates the softmax reduction in FP32.

  2. These native CuTe DSL kernels fuse:
       bias add + column-wise softmax over token positions + weighted C sum
       + partial RoPE for KV output.

The full projection+compression single-kernel tcgen05 variant is possible, but it
needs your exact NVFP4 weight/scale layout and preferred tile shape. This file is
therefore the safe fusion seam: the expensive D x C math stays in your Blackwell
GEMM, while the small-position softmax/reduction/rope path avoids PyTorch ops and
extra materialization after projection.

Target dimensions used by DeepSeek-V4-Pro reference:
  D=7168, C=512, C_I=128, CSA_M=4, HCA_M=128, NOPE=448, ROPE=64.

Assumptions:
  * Tensors are contiguous row-major from PyTorch/DLPack.
  * Projection buffers are laid out as described above.
  * One sequence at a time. State/tail management remains on the caller side.
  * For CSA continuation across calls, provide external previous-block B-side
    projections for the first committed block. For fresh prefill, set
    has_external_prev=False.

NOTE: This file was written without access to a CUTLASS CuTe DSL installation
and has not been compiled. It follows the CuTe DSL @jit/@kernel launch style,
but it may need small API-name edits if you are pinned to a specific
CUTLASS 4.x commit.
"""

from __future__ import annotations

import torch
import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda


# -----------------------------------------------------------------------------
# Small helpers
# -----------------------------------------------------------------------------

LOG2_E = 1.44269504088896340736

# Most negative finite FP32 value; seed for the running max in the softmax.
_NEG_FLT_MAX = -3.4028234663852886e38


def _ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) for positive integers using integer arithmetic."""
    return (a + b - 1) // b


@cute.jit
def _expf(x: cutlass.Float32) -> cutlass.Float32:
    """Return exp(x), computed as exp2(x * log2(e))."""
    # CuTe DSL exposes exp2 in cute.math; exp(x) = exp2(x * log2(e)).
    return cute.math.exp2(x * cutlass.Float32(LOG2_E))


@cute.jit
def _read_csa_current_packed(
    proj: cute.Tensor,
    token_idx: cutlass.Int64,
    col: cutlass.Int32,
    kind: cutlass.Constexpr,  # 0 Ca, 1 Za, 2 Cb, 3 Zb
    OUT: cutlass.Constexpr,
) -> cutlass.Float32:
    """Read element `col` of sub-matrix `kind` from a packed CSA projection row.

    The row is treated as four side-by-side groups of OUT columns, so the
    element lives at column `kind * OUT + col`. The value is converted to FP32.
    """
    return proj[token_idx, kind * OUT + col].to(cutlass.Float32)


@cute.jit
def _read_hca_packed(
    proj: cute.Tensor,
    token_idx: cutlass.Int64,
    col: cutlass.Int32,
    kind: cutlass.Constexpr,  # 0 C, 1 Z
    C: cutlass.Constexpr,
) -> cutlass.Float32:
    """Read element `col` of sub-matrix `kind` from a packed HCA projection row.

    The row is treated as two side-by-side groups of C columns, so the element
    lives at column `kind * C + col`. The value is converted to FP32.
    """
    return proj[token_idx, kind * C + col].to(cutlass.Float32)


# -----------------------------------------------------------------------------
# CSA fused compressor from packed projections
# -----------------------------------------------------------------------------

@cute.jit
def _csa_raw_reduce_one_col(
    proj: cute.Tensor,          # (N_tokens, 4*OUT): Ca, Za, Cb, Zb
    prev_b_proj: cute.Tensor,   # (M, 2*OUT): Cb_prev, Zb_prev for first block, may be dummy
    B_a: cute.Tensor,           # (M, OUT)
    B_b: cute.Tensor,           # (M, OUT)
    block_i: cutlass.Int64,
    col: cutlass.Int32,
    start_token_in_proj: cutlass.Int64,
    has_external_prev: cutlass.Constexpr,
    M: cutlass.Constexpr,
    OUT: cutlass.Constexpr,
) -> cutlass.Float32:
    """Column-wise CSA softmax+weighted-sum for either main KV or indexer.

    This implements exactly:
        softmax([Z_a_cur + B_a ; Z_b_prev + B_b], dim=position)
        sum S_a*C_a + sum S_b*C_b
    with the block-0 no-prev case reducing over M current positions only.

    The reduction is done in two passes over the same 2*M (or M) logits:
    pass 1 finds the maximum logit, pass 2 accumulates exp(logit - max) and
    the exp-weighted values, and the function returns acc / denom in FP32.

    For block_i > 0 the "previous" side is read from `proj` one block earlier;
    for block_i == 0 it is read from `prev_b_proj`, and only when
    `has_external_prev` is set.
    """
    # First pass: max logit.
    max_logit = cutlass.Float32(_NEG_FLT_MAX)

    # Current/current-a side always exists.
    for p in cutlass.range_constexpr(M):
        tok = start_token_in_proj + block_i * M + p
        za = _read_csa_current_packed(proj, tok, col, 1, OUT) + B_a[p, col].to(cutlass.Float32)
        max_logit = cute.math.fmax(max_logit, za)

    # Previous/b side exists if this is not the first fresh block.
    # The compile-time flag is tested first so that, when it is True, the
    # expression is decided without evaluating the runtime comparison.
    # NOTE(review): confirm how the pinned CuTe DSL version lowers `or` on a
    # runtime boolean.
    use_prev = has_external_prev or (block_i > 0)
    if use_prev:
        for p in cutlass.range_constexpr(M):
            # Bound before the runtime branch so the name exists on both arms.
            # NOTE(review): CuTe DSL dynamic `if` is believed to reject names
            # first bound inside a branch and read after it - confirm.
            zb = cutlass.Float32(0.0)
            if block_i > 0:
                tok_prev = start_token_in_proj + (block_i - 1) * M + p
                zb = _read_csa_current_packed(proj, tok_prev, col, 3, OUT)
            else:
                # External previous block is packed as [Cb_prev | Zb_prev]
                zb = prev_b_proj[p, OUT + col].to(cutlass.Float32)
            zb = zb + B_b[p, col].to(cutlass.Float32)
            max_logit = cute.math.fmax(max_logit, zb)

    # Second pass: exp denominator and weighted value.
    denom = cutlass.Float32(0.0)
    acc = cutlass.Float32(0.0)

    for p in cutlass.range_constexpr(M):
        tok = start_token_in_proj + block_i * M + p
        za = _read_csa_current_packed(proj, tok, col, 1, OUT) + B_a[p, col].to(cutlass.Float32)
        ca = _read_csa_current_packed(proj, tok, col, 0, OUT)
        e = _expf(za - max_logit)
        denom += e
        acc += e * ca

    if use_prev:
        for p in cutlass.range_constexpr(M):
            # Same pre-binding as in the first pass (see note there).
            cb = cutlass.Float32(0.0)
            zb = cutlass.Float32(0.0)
            if block_i > 0:
                tok_prev = start_token_in_proj + (block_i - 1) * M + p
                cb = _read_csa_current_packed(proj, tok_prev, col, 2, OUT)
                zb = _read_csa_current_packed(proj, tok_prev, col, 3, OUT)
            else:
                cb = prev_b_proj[p, col].to(cutlass.Float32)
                zb = prev_b_proj[p, OUT + col].to(cutlass.Float32)
            zb = zb + B_b[p, col].to(cutlass.Float32)
            e = _expf(zb - max_logit)
            denom += e
            acc += e * cb

    return acc / denom


# CSA compression kernel.
#
# Thread/grid mapping (as read from the index computations below):
#   block_idx.x -> compressed block index (block_i)
#   block_idx.y -> column tile; col = block_idx.y * COLS_PER_CTA + thread_idx.x
#   block_idx.z -> 0: main KV output (with RoPE on columns >= NOPE)
#                  otherwise: indexer output (no RoPE)
#
# For RoPE columns, only lanes whose (col - NOPE) is even do work: each such
# lane reduces columns col and col + 1 and writes both rotated outputs.  The
# rotation angle is looked up at the absolute position of the block's last
# token, start_abs_pos + block_i * M + (M - 1); cos is read from the first
# ROPE // 2 columns of cos_sin_cache and sin from the next ROPE // 2.
@cute.kernel
def csa_compress_projected_kernel(
    proj_main: cute.Tensor,          # (N_tokens, 4*C)
    proj_indexer: cute.Tensor,       # (N_tokens, 4*C_I)
    prev_main_b: cute.Tensor,        # (M, 2*C), Cb_prev|Zb_prev; dummy if no ext prev
    prev_indexer_b: cute.Tensor,     # (M, 2*C_I), Cb_prev|Zb_prev; dummy if no ext prev
    B_a: cute.Tensor,                # (M, C)
    B_b: cute.Tensor,                # (M, C)
    B_I_a: cute.Tensor,              # (M, C_I)
    B_I_b: cute.Tensor,              # (M, C_I)
    cos_sin_cache: cute.Tensor,      # (max_pos, ROPE), cos first half, sin second half
    kv_out: cute.Tensor,             # (n_blocks, C)
    indexer_out: cute.Tensor,        # (n_blocks, C_I)
    n_blocks: cutlass.Int64,
    start_token_in_proj: cutlass.Int64,
    start_abs_pos: cutlass.Int64,
    has_external_prev: cutlass.Constexpr,
    M: cutlass.Constexpr,
    C: cutlass.Constexpr,
    C_I: cutlass.Constexpr,
    NOPE: cutlass.Constexpr,
    ROPE: cutlass.Constexpr,
    COLS_PER_CTA: cutlass.Constexpr,
):
    tx, _, _ = cute.arch.thread_idx()
    bx, by, bz = cute.arch.block_idx()

    block_i = bx.to(cutlass.Int64)
    base_col = by * COLS_PER_CTA
    col = base_col + tx

    # bz == 0: main KV output with RoPE. bz == 1: indexer output, no RoPE.
    if bz == 0:
        if block_i < n_blocks and col < C:
            # RoPE dims need even/odd pair. Let even lane compute/store both.
            if col < NOPE:
                val = _csa_raw_reduce_one_col(
                    proj_main, prev_main_b, B_a, B_b,
                    block_i, col, start_token_in_proj, has_external_prev, M, C,
                )
                kv_out[block_i, col] = val.to(kv_out.element_type)
            else:
                rope_col = col - NOPE
                if (rope_col % 2) == 0:
                    # Compute pair and rotate by block end position.
                    x0 = _csa_raw_reduce_one_col(
                        proj_main, prev_main_b, B_a, B_b,
                        block_i, col, start_token_in_proj, has_external_prev, M, C,
                    )
                    x1 = _csa_raw_reduce_one_col(
                        proj_main, prev_main_b, B_a, B_b,
                        block_i, col + 1, start_token_in_proj, has_external_prev, M, C,
                    )
                    block_end_pos = start_abs_pos + block_i * M + (M - 1)
                    half_idx = rope_col // 2
                    cosv = cos_sin_cache[block_end_pos, half_idx].to(cutlass.Float32)
                    sinv = cos_sin_cache[block_end_pos, half_idx + ROPE // 2].to(cutlass.Float32)
                    kv_out[block_i, col] = (x0 * cosv - x1 * sinv).to(kv_out.element_type)
                    kv_out[block_i, col + 1] = (x0 * sinv + x1 * cosv).to(kv_out.element_type)
    else:
        if block_i < n_blocks and col < C_I:
            val_i = _csa_raw_reduce_one_col(
                proj_indexer, prev_indexer_b, B_I_a, B_I_b,
                block_i, col, start_token_in_proj, has_external_prev, M, C_I,
            )
            indexer_out[block_i, col] = val_i.to(indexer_out.element_type)


@cute.jit
def launch_csa_compress_projected(
    proj_main: cute.Tensor,
    proj_indexer: cute.Tensor,
    prev_main_b: cute.Tensor,
    prev_indexer_b: cute.Tensor,
    B_a: cute.Tensor,
    B_b: cute.Tensor,
    B_I_a: cute.Tensor,
    B_I_b: cute.Tensor,
    cos_sin_cache: cute.Tensor,
    kv_out: cute.Tensor,
    indexer_out: cute.Tensor,
    n_blocks: int,
    start_token_in_proj: int,
    start_abs_pos: int,
    has_external_prev: cutlass.Constexpr,
    stream: cuda.CUstream,
    M: cutlass.Constexpr = 4,
    C: cutlass.Constexpr = 512,
    C_I: cutlass.Constexpr = 128,
    NOPE: cutlass.Constexpr = 448,
    ROPE: cutlass.Constexpr = 64,
    COLS_PER_CTA: cutlass.Constexpr = 128,
):
    """Launch `csa_compress_projected_kernel` on `stream`.

    Grid is (n_blocks, column tiles, 2): z = 0 handles the main KV output and
    z = 1 the indexer output. Each CTA has COLS_PER_CTA threads, one per column.
    """
    # One grid_y is shared by both z-slices, so it must cover the wider of the
    # two outputs; the kernel masks col < C (main) and col < C_I (indexer).
    # Sizing it from C alone would silently drop indexer columns if C_I > C.
    grid_y = _ceil_div(max(C, C_I), COLS_PER_CTA)
    csa_compress_projected_kernel(
        proj_main,
        proj_indexer,
        prev_main_b,
        prev_indexer_b,
        B_a,
        B_b,
        B_I_a,
        B_I_b,
        cos_sin_cache,
        kv_out,
        indexer_out,
        n_blocks,
        start_token_in_proj,
        start_abs_pos,
        has_external_prev,
        M,
        C,
        C_I,
        NOPE,
        ROPE,
        COLS_PER_CTA,
    ).launch(
        grid=[n_blocks, grid_y, 2],
        block=[COLS_PER_CTA, 1, 1],
        stream=stream,
    )


# -----------------------------------------------------------------------------
# HCA fused compressor from packed projections
# -----------------------------------------------------------------------------

@cute.jit
def _hca_raw_reduce_one_col(
    proj: cute.Tensor,          # (N_tokens, 2*C): C, Z
    B: cute.Tensor,             # (M, C)
    block_i: cutlass.Int64,
    col: cutlass.Int32,
    start_token_in_proj: cutlass.Int64,
    M: cutlass.Constexpr,
    C: cutlass.Constexpr,
) -> cutlass.Float32:
    """Column-wise HCA softmax+weighted-sum over the M tokens of one block.

    Logits are Z + B for the block's M tokens; values are the matching C
    entries. Pass 1 finds the max logit, pass 2 accumulates exp(logit - max)
    and the exp-weighted values; the result acc / denom is returned in FP32.
    """
    max_logit = cutlass.Float32(_NEG_FLT_MAX)

    for p in cutlass.range_constexpr(M):
        tok = start_token_in_proj + block_i * M + p
        z = _read_hca_packed(proj, tok, col, 1, C) + B[p, col].to(cutlass.Float32)
        max_logit = cute.math.fmax(max_logit, z)

    denom = cutlass.Float32(0.0)
    acc = cutlass.Float32(0.0)

    for p in cutlass.range_constexpr(M):
        tok = start_token_in_proj + block_i * M + p
        z = _read_hca_packed(proj, tok, col, 1, C) + B[p, col].to(cutlass.Float32)
        c = _read_hca_packed(proj, tok, col, 0, C)
        e = _expf(z - max_logit)
        denom += e
        acc += e * c

    return acc / denom


# HCA compression kernel.
#
# Thread/grid mapping (as read from the index computations below):
#   block_idx.x -> compressed block index (block_i)
#   block_idx.y -> column tile; col = block_idx.y * COLS_PER_CTA + thread_idx.x
#
# Columns below NOPE are written as-is.  For columns >= NOPE, lanes with an even
# (col - NOPE) reduce columns col and col + 1 and write the rotated pair, using
# cos/sin looked up at position start_abs_pos + block_i * M + (M - 1).
@cute.kernel
def hca_compress_projected_kernel(
    proj: cute.Tensor,              # (N_tokens, 2*C)
    B: cute.Tensor,                 # (M, C)
    cos_sin_cache: cute.Tensor,     # (max_pos, ROPE)
    kv_out: cute.Tensor,            # (n_blocks, C)
    n_blocks: cutlass.Int64,
    start_token_in_proj: cutlass.Int64,
    start_abs_pos: cutlass.Int64,
    M: cutlass.Constexpr,
    C: cutlass.Constexpr,
    NOPE: cutlass.Constexpr,
    ROPE: cutlass.Constexpr,
    COLS_PER_CTA: cutlass.Constexpr,
):
    tx, _, _ = cute.arch.thread_idx()
    bx, by, _ = cute.arch.block_idx()

    block_i = bx.to(cutlass.Int64)
    col = by * COLS_PER_CTA + tx

    if block_i < n_blocks and col < C:
        if col < NOPE:
            val = _hca_raw_reduce_one_col(proj, B, block_i, col, start_token_in_proj, M, C)
            kv_out[block_i, col] = val.to(kv_out.element_type)
        else:
            rope_col = col - NOPE
            if (rope_col % 2) == 0:
                x0 = _hca_raw_reduce_one_col(proj, B, block_i, col, start_token_in_proj, M, C)
                x1 = _hca_raw_reduce_one_col(proj, B, block_i, col + 1, start_token_in_proj, M, C)
                block_end_pos = start_abs_pos + block_i * M + (M - 1)
                half_idx = rope_col // 2
                cosv = cos_sin_cache[block_end_pos, half_idx].to(cutlass.Float32)
                sinv = cos_sin_cache[block_end_pos, half_idx + ROPE // 2].to(cutlass.Float32)
                kv_out[block_i, col] = (x0 * cosv - x1 * sinv).to(kv_out.element_type)
                kv_out[block_i, col + 1] = (x0 * sinv + x1 * cosv).to(kv_out.element_type)


@cute.jit
def launch_hca_compress_projected(
    proj: cute.Tensor,
    B: cute.Tensor,
    cos_sin_cache: cute.Tensor,
    kv_out: cute.Tensor,
    n_blocks: int,
    start_token_in_proj: int,
    start_abs_pos: int,
    stream: cuda.CUstream,
    M: cutlass.Constexpr = 128,
    C: cutlass.Constexpr = 512,
    NOPE: cutlass.Constexpr = 448,
    ROPE: cutlass.Constexpr = 64,
    COLS_PER_CTA: cutlass.Constexpr = 128,
):
    """Launch `hca_compress_projected_kernel` on `stream`.

    Grid is (n_blocks, ceil(C / COLS_PER_CTA), 1) with COLS_PER_CTA threads per
    CTA, one per output column.
    """
    grid_y = _ceil_div(C, COLS_PER_CTA)
    hca_compress_projected_kernel(
        proj,
        B,
        cos_sin_cache,
        kv_out,
        n_blocks,
        start_token_in_proj,
        start_abs_pos,
        M,
        C,
        NOPE,
        ROPE,
        COLS_PER_CTA,
    ).launch(
        grid=[n_blocks, grid_y, 1],
        block=[COLS_PER_CTA, 1, 1],
        stream=stream,
    )


# -----------------------------------------------------------------------------
# PyTorch-side packing helpers
# -----------------------------------------------------------------------------


def pack_csa_main_weights(W_a_KV, W_a_Z, W_b_KV, W_b_Z):
    """Return W packed as [W_a_KV | W_a_Z | W_b_KV | W_b_Z].

    The four tensors are concatenated along dim 1 and returned contiguous.
    """
    return torch.cat([W_a_KV, W_a_Z, W_b_KV, W_b_Z], dim=1).contiguous()


def pack_csa_indexer_weights(W_I_a_KV, W_I_a_Z, W_I_b_KV, W_I_b_Z):
    """Return W_I packed as [W_I_a_KV | W_I_a_Z | W_I_b_KV | W_I_b_Z].

    The four tensors are concatenated along dim 1 and returned contiguous.
    """
    return torch.cat([W_I_a_KV, W_I_a_Z, W_I_b_KV, W_I_b_Z], dim=1).contiguous()


def pack_hca_weights(W_KV, W_Z):
    """Return W packed as [W_KV | W_Z].

    The two tensors are concatenated along dim 1 and returned contiguous.
    """
    return torch.cat([W_KV, W_Z], dim=1).contiguous()


def make_dummy_prev_b(M: int, OUT: int, *, device, dtype):
    """Dummy external previous-block projection for fresh prefill.

    Returns an (M, 2*OUT) tensor. It is zero-filled rather than uninitialised
    so that an accidental read (e.g. a mismatched has_external_prev flag)
    yields finite, deterministic values instead of arbitrary memory contents.
    """
    return torch.zeros((M, 2 * OUT), device=device, dtype=dtype)


# =============================================================================
# Patch file 2 of 2: cutedsl/mhc_inference_layer.py
# (new file, mode 100644, index 00000000..06e6cb72)
# =============================================================================
"""
mHC (Manifold-Constrained Hyper-Connections) — Inference Layer.

Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only.

At inference the Sinkhorn-Knopp constraint has already been enforced during
training, but B_l is still *dynamically generated* per-token from the input
residual state.  So we still need to:
  1. Project the flattened residual → raw A/B/C parameter values.
  2. Apply sigmoid (A, C) and Sinkhorn-Knopp 20 iters (B).
  3. Mix residual streams.

The only thing that changes vs training is that we skip the loss and gradient
through the Sinkhorn projection — the forward arithmetic is identical.

---------------------------------------------------------------------
V4-Pro reference dimensions (Section 4.2.1)
---------------------------------------------------------------------
  d        = 7168     hidden dim
  n_hc     = 4        hyper-connection expansion factor
  N_proj   = 24       fused output of  W_pre(4) + W_res(16) + W_post(4)
  K_proj   = 4*7168 = 28672  = n_hc * d  (flattened residual)
  t_max    = 20       Sinkhorn iterations

---------------------------------------------------------------------
Kernel dependency
---------------------------------------------------------------------
tf32_hc_prenorm_gemm  (DeepGEMM, SM90/SM100)
  a:        (T, K)   BF16  — flattened residual X_flat
  b:        (N, K)   FP32  — stacked weight [W_pre; W_res; W_post]
  d:        (S, T, N) or (T, N)  FP32  — raw projection outputs (pre-normalised)
  sqr_sum:  (S, T)   or (T,)    FP32  — Σ a² per token (for RMSNorm denominator)
  num_splits = S  (16 recommended for K=28672)

After the call:
  d       = d.sum(0)         → (T, N)
  sqr_sum = sqr_sum.sum(0)   → (T,)
  rms_scale = sqrt(K / (sqr_sum + eps))
  d_norm  = d * rms_scale[:,None]   — equivalent to RMSNorm(X_flat) @ W_stacked
"""

# NOTE: the second file repeats `from __future__ import annotations` at this
# point.  It is already in effect from the top of this listing, and Python only
# accepts it at the very start of a file, so it is not repeated here.  Restore
# it as the first statement if this section is split back into its own module.

import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Try importing DeepGEMM; fall back to a plain FP32 PyTorch matmul if unavailable.
# ---------------------------------------------------------------------------
try:
    import deep_gemm
    _HAS_DEEP_GEMM = True
except ImportError:
    _HAS_DEEP_GEMM = False


NUM_SPLITS = 16   # K-split count for tf32_hc_prenorm_gemm numerical stability
EPS_RMSN = 1e-6   # epsilon used in the RMS scale and in the Sinkhorn divisions


# ---------------------------------------------------------------------------
# Sinkhorn-Knopp projection  (T batched n×n matrices, t_max iters)
# ---------------------------------------------------------------------------

def sinkhorn_knopp(
    M: torch.Tensor,          # (T, n, n)  positive (after exp)
    t_max: int = 20,
) -> torch.Tensor:
    """
    Project each (n×n) positive matrix towards the Birkhoff polytope
    (doubly stochastic matrices) via alternating col/row normalisation.

    Paper eq. (8):  M^(t) = T_r( T_c( M^(t-1) ) )
    where T_r = row-normalise, T_c = col-normalise.  The column step is
    therefore applied first and the row step last in every iteration, so the
    returned rows sum to 1 up to the epsilon in the divisor.

    For n=4 and t_max=20 this is ~160 tiny operations — no kernel needed.
    All ops are standard PyTorch and stay on the input's device.

    Args:
        M: (..., n, n) tensor of positive entries.
        t_max: number of (col, row) normalisation rounds; 0 returns M as is.

    Returns:
        Tensor of the same shape as M; the input is not modified in place.
    """
    for _ in range(t_max):
        M = M / (M.sum(dim=-2, keepdim=True) + EPS_RMSN)   # T_c (col)
        M = M / (M.sum(dim=-1, keepdim=True) + EPS_RMSN)   # T_r (row)
    return M


# ---------------------------------------------------------------------------
# Context carried between pre_block and post_block
# ---------------------------------------------------------------------------

@dataclass
class mHCContext:
    """Holds the per-token mixing matrices computed in pre_block."""
    B_l: torch.Tensor    # (T, n_hc, n_hc)  doubly stochastic residual transform
    C_l: torch.Tensor    # (T, n_hc)        output mapping (before unsqueeze)


# ---------------------------------------------------------------------------
# mHC layer
# ---------------------------------------------------------------------------

class mHCLayer:
    """
    Wraps one transformer sub-layer (attention *or* MoE) with the mHC
    residual update.

    Typical call pattern per layer:

        x_in, ctx = mhc.pre_block(X_l)
        F_out     = transformer_sublayer(x_in)         # (T, d)
        X_next    = mhc.post_block(X_l, F_out, ctx)

    where X_l has shape (T, n_hc, d) — the expanded residual state.
    The first call at layer 0 should use X_0 initialised via `init_state`.
    """

    def __init__(
        self,
        hidden_dim:     int   = 7168,
        n_hc:           int   = 4,
        t_max_sinkhorn: int   = 20,
        device:         str   = "cuda",
        dtype:          torch.dtype = torch.bfloat16,
    ):
        """Allocate zero-initialised parameters; call `load_weights` to fill them.

        Args:
            hidden_dim: per-stream hidden size d.
            n_hc: number of residual streams.
            t_max_sinkhorn: Sinkhorn-Knopp iterations used for B_l.
            device: device on which parameters and scratch buffers live.
            dtype: dtype for the static biases and for A_l / C_l / outputs.
        """
        self.d       = hidden_dim
        self.n_hc    = n_hc
        self.K_proj  = n_hc * hidden_dim        # 28672 for V4-Pro
        self.N_proj  = n_hc + n_hc * n_hc + n_hc   # 4 + 16 + 4 = 24
        self.t_max   = t_max_sinkhorn
        self.device  = device
        self.dtype   = dtype

        # ── Learnable weights (set via load_weights) ──────────────────
        # Stacked projection: b shape = (N_proj, K_proj) in FP32
        # Stored as separate tensors and fused lazily by _build_stacked.
        self.W_pre  = self._buf(n_hc,        self.K_proj, dtype=torch.float32)  # (4,  K)
        self.W_res  = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32)  # (16, K)
        self.W_post = self._buf(n_hc,        self.K_proj, dtype=torch.float32)  # (4,  K)

        # Static biases (eq. 3-5, S^pre / S^res / S^post)
        self.S_pre  = self._buf(1, n_hc)          # (1, 4)
        self.S_res  = self._buf(n_hc, n_hc)       # (4, 4)
        self.S_post = self._buf(n_hc, 1)          # (4, 1)

        # Learnable gating scalars (α), initialised small during training
        # At inference these are just scalars loaded from the checkpoint.
        self.alpha_pre  = torch.zeros(1, device=device, dtype=torch.float32)
        self.alpha_res  = torch.zeros(1, device=device, dtype=torch.float32)
        self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32)

        # Pre-allocated split buffers (set in _ensure_buffers)
        self._d_split       = None   # (NUM_SPLITS, max_T, N_proj) FP32
        self._sqr_sum_split = None   # (NUM_SPLITS, max_T)         FP32
        self._max_T = 0

        # Fused stacked weight (built once in _build_stacked)
        self._W_stacked = None       # (N_proj, K_proj) FP32

    # ── Construction helpers ──────────────────────────────────────────

    def _buf(self, *shape, dtype=None):
        """Return a zero-filled tensor of `shape` on the layer's device.

        Zeros (rather than uninitialised memory) keep a layer that is used
        before `load_weights` finite and deterministic.
        """
        dt = dtype or self.dtype
        return torch.zeros(*shape, dtype=dt, device=self.device)

    def load_weights(
        self,
        W_pre:      torch.Tensor,   # (n_hc,    K)  FP32
        W_res:      torch.Tensor,   # (n_hc²,   K)  FP32
        W_post:     torch.Tensor,   # (n_hc,    K)  FP32
        S_pre:      torch.Tensor,   # (1,    n_hc)
        S_res:      torch.Tensor,   # (n_hc, n_hc)
        S_post:     torch.Tensor,   # (n_hc, 1)
        alpha_pre:  float,
        alpha_res:  float,
        alpha_post: float,
    ):
        """
        Load all mHC parameters from the checkpoint.

        The W tensors are stored as FP32 (the prenorm GEMM takes a BF16 input
        and an FP32 weight).  The S tensors are cast to the layer dtype and
        the alphas to FP32 scalars.  Everything is moved to the layer device.

        Raises:
            ValueError: if a W tensor does not have its documented shape, or
                an S tensor does not have the expected number of elements.
        """
        n, K = self.n_hc, self.K_proj
        for name, tensor, shape in (
            ("W_pre", W_pre, (n, K)),
            ("W_res", W_res, (n * n, K)),
            ("W_post", W_post, (n, K)),
        ):
            if tuple(tensor.shape) != shape:
                raise ValueError(
                    f"{name} must have shape {shape}, got {tuple(tensor.shape)}"
                )
        # The S tensors are flattened before use, so only the element count
        # has to match.
        for name, tensor, numel in (
            ("S_pre", S_pre, n),
            ("S_res", S_res, n * n),
            ("S_post", S_post, n),
        ):
            if tensor.numel() != numel:
                raise ValueError(
                    f"{name} must have {numel} elements, got {tensor.numel()}"
                )

        def _f32(t):
            return t.to(device=self.device, dtype=torch.float32).contiguous()

        def _cvt(t):
            return t.to(device=self.device, dtype=self.dtype).contiguous()

        def _scalar(a):
            # as_tensor accepts both Python floats and checkpoint tensors.
            return torch.as_tensor(a, dtype=torch.float32, device=self.device)

        self.W_pre  = _f32(W_pre)
        self.W_res  = _f32(W_res)
        self.W_post = _f32(W_post)
        self.S_pre  = _cvt(S_pre)
        self.S_res  = _cvt(S_res)
        self.S_post = _cvt(S_post)
        self.alpha_pre  = _scalar(alpha_pre)
        self.alpha_res  = _scalar(alpha_res)
        self.alpha_post = _scalar(alpha_post)
        self._W_stacked = None   # invalidate cache

    def _build_stacked(self):
        """Fuse W_pre / W_res / W_post into one (N_proj, K_proj) FP32 tensor."""
        self._W_stacked = torch.cat([self.W_pre, self.W_res, self.W_post], dim=0)
        # Must be K-major (contiguous along K) for DeepGEMM
        self._W_stacked = self._W_stacked.contiguous()

    def _ensure_buffers(self, T: int):
        """Grow the split scratch buffers to hold T tokens (avoids hot-path alloc).

        Buffers are only ever enlarged; a call with T <= the current capacity
        is a no-op.
        """
        if T <= self._max_T:
            return
        self._d_split = torch.empty(
            NUM_SPLITS, T, self.N_proj, dtype=torch.float32, device=self.device
        )
        self._sqr_sum_split = torch.empty(
            NUM_SPLITS, T, dtype=torch.float32, device=self.device
        )
        self._max_T = T

    # ── Forward ──────────────────────────────────────────────────────

    def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor:
        """
        Compute (X_flat @ W_stacked.T) * sqrt(K / (Σx² + eps))  →  (T, N_proj) FP32.

        Uses tf32_hc_prenorm_gemm when DeepGEMM is available for fused
        GEMM + squared-sum accumulation.  Otherwise falls back to an FP32
        PyTorch matmul plus an explicit squared sum.

        The result is returned in FP32: the caller feeds it through exp /
        Sinkhorn, so rounding it to the layer dtype first would discard
        precision for no benefit.

        X_flat: (T, K_proj)  BF16
        """
        T = X_flat.shape[0]
        K = self.K_proj

        if self._W_stacked is None:
            self._build_stacked()

        if _HAS_DEEP_GEMM:
            self._ensure_buffers(T)

            d_s  = self._d_split[:, :T, :]       # view, no copy
            ss_s = self._sqr_sum_split[:, :T]

            # a: (T, K) BF16   b: (N, K) FP32  →  d_s: (S, T, N), ss_s: (S, T)
            # Both d and sqr_sum are OUTPUT tensors (written by the kernel).
            deep_gemm.tf32_hc_prenorm_gemm(
                X_flat.contiguous(),    # a
                self._W_stacked,        # b  (N, K) FP32
                d_s,                    # d  (S, T, N)
                ss_s,                   # sqr_sum (S, T)
                num_splits=NUM_SPLITS,
            )

            d_out   = d_s.sum(dim=0)             # (T, N)
            sqr_sum = ss_s.sum(dim=0)            # (T,)

        else:
            # Fallback: FP32 matmul + manual squared sum
            x_f32   = X_flat.float()
            d_out   = x_f32 @ self._W_stacked.T         # (T, N)
            sqr_sum = x_f32.pow(2).sum(dim=-1)          # (T,)

        # RMSNorm scale: multiply raw GEMM output by rsqrt(mean(x²))
        # mean(x²) = sqr_sum / K  →  scale = sqrt(K / sqr_sum)
        rms_scale = torch.sqrt(K / (sqr_sum + EPS_RMSN))    # (T,)
        return d_out * rms_scale.unsqueeze(-1)               # (T, N) FP32

    def _dynamic_params(
        self, X_l: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute per-token A_l, B_l, C_l from the current residual state.

        X_l: (T, n_hc, d)

        Returns:
            A_l:  (T, n_hc)        sigmoid-constrained input mapping (layer dtype)
            B_l:  (T, n_hc, n_hc)  Sinkhorn-normalised residual transform (FP32)
            C_l:  (T, n_hc)        2*sigmoid-constrained output mapping (layer dtype)

        Raises:
            ValueError: if X_l is not (T, n_hc, d) for this layer's n_hc and d.
        """
        if X_l.dim() != 3 or X_l.shape[1] != self.n_hc or X_l.shape[2] != self.d:
            raise ValueError(
                f"X_l must have shape (T, {self.n_hc}, {self.d}), "
                f"got {tuple(X_l.shape)}"
            )
        T = X_l.shape[0]

        # Flatten: (T, n_hc*d)
        X_flat = X_l.reshape(T, self.K_proj).to(self.dtype)

        # Fused RMSNorm projection: (T, N_proj), kept in FP32 for precision
        proj = self._project_and_rms(X_flat).float()

        # Split into raw A / B / C
        i0, i1, i2, i3 = 0, self.n_hc, self.n_hc + self.n_hc**2, self.N_proj
        A_raw = proj[:, i0:i1]               # (T, n_hc)
        B_raw = proj[:, i1:i2]               # (T, n_hc²)
        C_raw = proj[:, i2:i3]               # (T, n_hc)

        # Add static biases and scale by learned gating factors (eq. 3-5)
        S_pre  = self.S_pre.float()          # (1, n_hc)
        S_res  = self.S_res.float()          # (n_hc, n_hc)
        S_post = self.S_post.float()         # (n_hc, 1)

        A_tilde = self.alpha_pre  * A_raw + S_pre                          # (T, n_hc)
        B_tilde = self.alpha_res  * B_raw + S_res.flatten().unsqueeze(0)   # (T, n_hc²)
        C_tilde = self.alpha_post * C_raw + S_post.flatten().unsqueeze(0)  # (T, n_hc)

        # Apply constraints (paper eqs. 6-8)
        A_l = torch.sigmoid(A_tilde)                    # (T, n_hc)
        C_l = 2.0 * torch.sigmoid(C_tilde)              # (T, n_hc)

        # B_l: exp → Sinkhorn-Knopp → doubly stochastic
        B_exp = torch.exp(B_tilde).reshape(T, self.n_hc, self.n_hc)
        B_l   = sinkhorn_knopp(B_exp, t_max=self.t_max)  # (T, n_hc, n_hc)

        # Keep B_l in FP32 — the (T,4,4) bmm precision matters more than memory.
        # A_l and C_l are cast to dtype for the input/output mixing multiplies.
        return A_l.to(self.dtype), B_l, C_l.to(self.dtype)

    # ----------------------------------------------------------------
    # Public API: pre_block / post_block
    # ----------------------------------------------------------------

    def pre_block(
        self,
        X_l: torch.Tensor,             # (T, n_hc, d)  BF16
    ) -> Tuple[torch.Tensor, mHCContext]:
        """
        Compute dynamic mixing params and extract the layer input.

        Returns:
            x_in:  (T, d) in the layer dtype — the input to pass to the sub-layer
            ctx:   mHCContext — {B_l, C_l} to be passed to post_block
        """
        A_l, B_l, C_l = self._dynamic_params(X_l)

        # Layer input: x_in = A_l @ X_l  (per token, weighted sum of streams)
        # A_l: (T, n_hc)   X_l: (T, n_hc, d)
        # → (T, 1, n_hc) bmm (T, n_hc, d) = (T, 1, d) → squeeze
        # bmm needs both operands in one dtype; the cast is a no-op when X_l
        # already uses the layer dtype.
        x_in = torch.bmm(A_l.unsqueeze(1), X_l.to(A_l.dtype)).squeeze(1)   # (T, d)

        return x_in, mHCContext(B_l=B_l, C_l=C_l)

    def post_block(
        self,
        X_l:   torch.Tensor,           # (T, n_hc, d)  BF16 — residual state BEFORE sub-layer
        F_out: torch.Tensor,           # (T, d)         BF16 — sub-layer output
        ctx:   mHCContext,
    ) -> torch.Tensor:
        """
        Apply the mHC residual update (eq. 1):
            X_{l+1} = B_l @ X_l  +  C_l ⊗ F_out

        Returns:
            X_next: (T, n_hc, d) in the layer dtype
        """
        # B_l is FP32; bmm does not mix dtypes, so X_l is upcast explicitly and
        # the sum is accumulated in FP32 before the final cast.
        BX = torch.bmm(ctx.B_l, X_l.float())
        CF = ctx.C_l.unsqueeze(-1) * F_out.unsqueeze(1)   # (T, n_hc, d)
        return (BX + CF.float()).to(self.dtype)           # (T, n_hc, d)

    # ----------------------------------------------------------------
    # Utility
    # ----------------------------------------------------------------

    @staticmethod
    def init_state(
        embeddings: torch.Tensor,      # (T, d)  BF16 — token embeddings
        n_hc: int = 4,
    ) -> torch.Tensor:
        """
        Initialise X_0 for the first layer.

        The paper figure shows the embedding feeding into the first
        Residual Mixing.  We broadcast the embedding across all n_hc
        residual streams as the simplest valid initialisation.

        Returns: (T, n_hc, d) in the dtype of `embeddings`, as a fresh copy
        """
        return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()

    @staticmethod
    def read_out(X_L: torch.Tensor) -> torch.Tensor:
        """
        Extract the final hidden state from the last residual state.

        Convention: stream 0 is the primary output stream (standard choice
        for HC models — the first stream carries the main residual).

        Returns: (T, d) view of stream 0
        """
        return X_L[:, 0, :]


# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------

def _smoke_test() -> None:
    """Run shape, range and batch-vs-sequential checks on random weights."""
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype  = torch.bfloat16

    D, N_HC = 7168, 4
    K = N_HC * D   # 28672

    mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype)

    # Random weights matching the expected shapes
    mhc.load_weights(
        W_pre  = torch.randn(N_HC,       K, dtype=torch.float32),
        W_res  = torch.randn(N_HC**2,    K, dtype=torch.float32),
        W_post = torch.randn(N_HC,       K, dtype=torch.float32),
        S_pre  = torch.zeros(1, N_HC,    dtype=dtype),
        S_res  = torch.eye(N_HC,         dtype=dtype),    # identity: pure residual
        S_post = torch.zeros(N_HC, 1,    dtype=dtype),
        alpha_pre  = 0.01,
        alpha_res  = 0.01,
        alpha_post = 0.01,
    )

    T = 4   # 4 tokens

    # ── Forward pass ────────────────────────────────────────────────
    embeddings = torch.randn(T, D, dtype=dtype, device=device)
    X = mHCLayer.init_state(embeddings, n_hc=N_HC)
    print(f"X_0:  {X.shape}  (T={T}, n_hc={N_HC}, d={D})")

    # Simulate a 2-layer stack
    for layer_idx in range(2):
        x_in, ctx = mhc.pre_block(X)
        print(f"\nLayer {layer_idx}:")
        print(f"  x_in (to sub-layer): {x_in.shape}")
        print(f"  B_l:  {ctx.B_l.shape}")
        print(f"  C_l:  {ctx.C_l.shape}")

        # Dummy sub-layer: identity (for testing the mHC mechanics)
        F_out = x_in

        X = mhc.post_block(X, F_out, ctx)
        print(f"  X_next: {X.shape}")

    hidden = mHCLayer.read_out(X)
    print(f"\nFinal hidden: {hidden.shape}")

    # ── B_l is doubly stochastic check ──────────────────────────────
    print("\n=== Doubly stochastic check ===")
    B = ctx.B_l   # (T, 4, 4) — FP32 from Sinkhorn
    row_sums = B.sum(dim=-1)   # (T, 4) — should all be ~1
    col_sums = B.sum(dim=-2)   # (T, 4) — should all be ~1
    print(f"  row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}]  (want ≈ 1.0)")
    print(f"  col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}]  (want ≈ 1.0)")
    assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1"
    assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1"
    print("  PASSED")

    # ── A_l and C_l are bounded ──────────────────────────────────────
    # (Re-run dynamic params to expose A_l for checking)
    A_l, B_l2, C_l = mhc._dynamic_params(X)
    print(f"\n=== A_l ∈ (0,1) check ===")
    print(f"  A_l range: [{A_l.min():.4f}, {A_l.max():.4f}]  (want ∈ (0,1))")
    assert A_l.min() > 0 and A_l.max() < 1, "A_l out of sigmoid range"
    print("  PASSED")
    print(f"\n=== C_l ∈ (0,2) check ===")
    print(f"  C_l range: [{C_l.min():.4f}, {C_l.max():.4f}]  (want ∈ (0,2))")
    assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range"
    print("  PASSED")

    # ── Consistency: S_res = identity, small alpha_res ──────────────
    # B_tilde is I plus a small per-token perturbation; exp() makes it
    # positive and Sinkhorn normalises it.  No closed form is asserted here —
    # the doubly-stochastic property was already checked above.
    print("\n=== S_res=I, small alpha_res → B_l stays doubly stochastic ===")
    print("  Already verified via doubly stochastic check above.")

    # ── Equivalence: T=1 decode vs T=N prefill ──────────────────────
    print("\n=== Token-by-token decode == batch prefill ===")
    T_big = 8
    h_big = torch.randn(T_big, D, dtype=dtype, device=device)
    X_batch = mHCLayer.init_state(h_big, n_hc=N_HC)

    # Batch
    x_in_batch, ctx_batch = mhc.pre_block(X_batch)

    # Token by token
    x_in_tokens = []
    for t in range(T_big):
        X_t = X_batch[t:t+1]          # (1, n_hc, d)
        x_in_t, _ = mhc.pre_block(X_t)
        x_in_tokens.append(x_in_t)
    x_in_seq = torch.cat(x_in_tokens, dim=0)  # (T_big, d)

    diff = (x_in_batch - x_in_seq).abs().max().item()
    print(f"  max |batch - sequential| on x_in: {diff:.6f}")
    assert diff < 1e-2, f"Mismatch too large: {diff}"
    print("  PASSED")

    print("\nAll checks done.")
    if not _HAS_DEEP_GEMM:
        print("\n(deep_gemm not available — used FP32 matmul fallback)")


if __name__ == "__main__":
    _smoke_test()