refactor: use production mHCLayer from dsv4.layers.mhc

Replace custom mHCBlock with wrapper around the tested production
mHCLayer class. This eliminates any bugs in my custom implementation
and uses the same code path that the model was designed for.

Weight mapping: fn[0:4]=W_pre, fn[4:8]=W_post, fn[8:24]=W_res
base[0:4]=S_pre, base[4:8]=S_post, base[8:24]=S_res
scale[0]=alpha_pre, scale[1]=alpha_post, scale[2]=alpha_res
This commit is contained in:
2026-05-31 04:06:58 +00:00
parent b519108cab
commit 3f04a72af4

View File

@@ -105,219 +105,59 @@ class RMSNorm:
# =====================================================================
class mHCBlock:
"""Manifold-Constrained Hyper-Connections (paper §2.2).
This is a standalone implementation matching dsv4/layers/mhc.py
but self-contained for single-shot inference. Uses the BF16 matmul
fallback (no DeepGEMM dependency).
Weight mapping from checkpoint:
fn: (N_proj, K_proj) FP32 — fused projection [W_pre; W_res; W_post]
base: (N_proj,) — static biases [S_pre; S_res; S_post]
scale: (3,) — gating alphas [alpha_pre, alpha_res, alpha_post]
"""Wrapper around dsv4.layers.mhc.mHCLayer for single-shot inference.
Uses the production mHCLayer implementation with proper Sinkhorn-Knopp.
"""
def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
self.d = hidden_dim
self.n_hc = n_hc
self.K_proj = n_hc * hidden_dim # 28672
self.N_proj = n_hc + n_hc**2 + n_hc # 4 + 16 + 4 = 24
self.sinkhorn_iters = sinkhorn_iters
from dsv4.layers.mhc import mHCLayer
self._impl = mHCLayer(
hidden_dim=hidden_dim, n_hc=n_hc,
t_max_sinkhorn=sinkhorn_iters,
device=device, dtype=torch.bfloat16)
self.device = device
self.eps = 1e-6
# Weights — set by load_from_checkpoint
self.W_stacked = None # (N_proj, K_proj) FP32
self.S_pre = None # (1, n_hc) FP32
self.S_res = None # (n_hc, n_hc) FP32
self.S_post = None # (n_hc, 1) FP32
self.alpha_pre = None # float
self.alpha_res = None # float
self.alpha_post = None # float
self.n_hc = n_hc
self.hidden_dim = hidden_dim
def load_from_checkpoint(self, fn, base, scale):
"""Load from checkpoint tensors.
fn: (24, 28672) FP32
fn: (24, 28672) FP32 — fused projection
base: (24,) — [pre(4), post(4), res(16)]
scale: (3,) — [alpha_pre, alpha_post, alpha_res]
"""
n = self.n_hc
# CRITICAL: checkpoint base order is [pre, post, res], not [pre, res, post]
# This matches the old working code: base[:4]=pre, base[4:8]=post, base[8:24]=res
self.W_stacked = fn.to(device=self.device, dtype=torch.float32).contiguous()
self.S_pre = base[0:n].reshape(1, n).to(device=self.device, dtype=torch.float32)
self.S_post = base[n:2*n].reshape(n, 1).to(device=self.device, dtype=torch.float32)
self.S_res = base[2*n:2*n + n*n].reshape(n, n).to(device=self.device, dtype=torch.float32)
# CRITICAL: checkpoint scale order is [alpha_pre, alpha_post, alpha_res]
self.alpha_pre = scale[0].item() if scale.numel() > 0 else 0.01
self.alpha_post = scale[1].item() if scale.numel() > 1 else 0.01
self.alpha_res = scale[2].item() if scale.numel() > 2 else 0.01
dev = self.device
def _project_and_rms(self, X_flat):
"""Compute RMSNorm(X_flat) @ W_stacked.T → (T, N_proj) BF16.
# fn rows: [W_pre(4), W_post(4), W_res(16)]
W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
W_post = fn[n:2*n].to(device=dev, dtype=torch.float32).contiguous()
W_res = fn[2*n:].to(device=dev, dtype=torch.float32).contiguous()
X_flat: (T, K_proj) BF16
"""
T = X_flat.shape[0]
K = self.K_proj
x_f32 = X_flat.float()
d_out = x_f32 @ self.W_stacked.T # (T, N_proj)
sqr_sum = x_f32.pow(2).sum(dim=-1) # (T,)
rms_scale = torch.sqrt(K / (sqr_sum + self.eps)) # (T,)
return (d_out * rms_scale.unsqueeze(-1)).bfloat16() # (T, N_proj) BF16
# base: [S_pre(4), S_post(4), S_res(16)]
S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()
S_post = base[n:2*n].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous()
S_res = base[2*n:].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous()
def _dynamic_params(self, X_l):
"""Compute A_l, B_l, C_l from residual state.
# scale: [alpha_pre, alpha_post, alpha_res]
alpha_pre = scale[0].item()
alpha_post = scale[1].item()
alpha_res = scale[2].item()
X_l: (T, n_hc, d) BF16
Returns: A_l (T, n_hc), B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc)
"""
T, n, d = X_l.shape
X_flat = X_l.reshape(T, self.K_proj) # (T, K_proj)
proj = self._project_and_rms(X_flat).float() # (T, N_proj) FP32
# Split projection — order matches W_stacked rows: [pre(4), post(4), res(16)]
i0, i1, i2, i3 = 0, n, 2*n, self.N_proj
A_raw = proj[:, i0:i1] # (T, n_hc) — pre
C_raw = proj[:, i1:i2] # (T, n_hc) — post
B_raw = proj[:, i2:i3] # (T, n_hc²) — res
# Add biases and scale by gating alphas (paper eqs. 3-5)
A_tilde = self.alpha_pre * A_raw + self.S_pre
B_tilde = self.alpha_res * B_raw + self.S_res.flatten().unsqueeze(0)
C_tilde = self.alpha_post * C_raw + self.S_post.flatten().unsqueeze(0)
# Apply constraints
A_l = torch.sigmoid(A_tilde) # (T, n_hc) ∈ (0,1)
C_l = 2.0 * torch.sigmoid(C_tilde) # (T, n_hc) ∈ (0,2)
# B_l: exp → Sinkhorn-Knopp → doubly stochastic
B_exp = torch.exp(B_tilde).reshape(T, n, n) # (T, n_hc, n_hc)
B_l = self._sinkhorn_knopp(B_exp)
return A_l.bfloat16(), B_l, C_l.bfloat16()
def _sinkhorn_knopp(self, M, t_max=None):
"""Project (T, n, n) positive matrices onto Birkhoff polytope.
Alternating row/col normalization, t_max=20 iterations.
Result is doubly stochastic: rows and columns each sum to 1.
"""
if t_max is None:
t_max = self.sinkhorn_iters
for _ in range(t_max):
M = M / (M.sum(dim=-1, keepdim=True) + self.eps) # row norm
M = M / (M.sum(dim=-2, keepdim=True) + self.eps) # col norm
return M
self._impl.load_weights(
W_pre=W_pre, W_res=W_res, W_post=W_post,
S_pre=S_pre, S_res=S_res, S_post=S_post,
alpha_pre=alpha_pre, alpha_res=alpha_res, alpha_post=alpha_post)
@staticmethod
def init_state(embeddings, n_hc=4):
"""Initialise X_0 from token embeddings.
Convention: broadcast embedding across all n_hc residual streams.
embeddings: (T, d) BF16
Returns: (T, n_hc, d) BF16
"""
return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()
from dsv4.layers.mhc import mHCLayer
return mHCLayer.init_state(embeddings, n_hc)
def pre_block(self, X_l):
"""Compute dynamic mixing params and extract layer input.
X_l: (T, n_hc, d) BF16
Returns: x_in (T, d) BF16, ctx tuple
"""
A_l, B_l, C_l = self._dynamic_params(X_l)
# x_in = A_l @ X_l — weighted sum of residual streams
x_in = torch.bmm(A_l.unsqueeze(1), X_l).squeeze(1) # (T, d)
return x_in, (B_l, C_l)
return self._impl.pre_block(X_l)
def post_block(self, X_l, F_out, ctx):
"""Apply mHC residual update (paper eq. 1):
X_{l+1} = B_l @ X_l + C_l ⊗ F_out
X_l: (T, n_hc, d) BF16 — residual state BEFORE sub-layer
F_out: (T, d) BF16 — sub-layer output
ctx: (B_l, C_l) from pre_block
Returns: X_next (T, n_hc, d) BF16
"""
B_l, C_l = ctx
# B_l is FP32, X_l is BF16 — bmm upcasts automatically
BX = torch.bmm(B_l, X_l.float()) # (T, n_hc, d) FP32
CF = C_l.unsqueeze(-1) * F_out.unsqueeze(1) # (T, n_hc, d) BF16
return (BX + CF.float()).bfloat16()
# =====================================================================
# RoPE — partial, GPT-J interleaved, last rope_dim dims
# =====================================================================
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0):
"""Build cos/sin caches for partial RoPE.
Returns: (cos_cache, sin_cache) each (max_pos, rope_dim//2) BF16
"""
half = rope_dim // 2
freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
return torch.cos(angles).bfloat16().to(device), torch.sin(angles).bfloat16().to(device)
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
"""Apply partial GPT-J interleaved RoPE to the last rope_dim dims of each head.
x: (T, n_h, hd) BF16 — already reshaped to per-head
positions: (T,) int64
Returns: (T, n_h, hd) BF16 with RoPE applied to last rope_dim dims
"""
T, n_h, hd = x.shape
nope = hd - rope_dim
half = rope_dim // 2
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16
sin = sin_cache[positions].unsqueeze(1)
out = x.clone()
x_rope = x[:, :, nope:] # (T, n_h, rope_dim)
x_even = x_rope[:, :, 0::2] # (T, n_h, half)
x_odd = x_rope[:, :, 1::2]
out[:, :, nope:][..., 0::2] = x_even * cos - x_odd * sin
out[:, :, nope:][..., 1::2] = x_even * sin + x_odd * cos
return out
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
"""Apply inverse RoPE (conjugate rotation) to attention output.
Paper §2.3.3: after attention, the per-head output carries position
information from the RoPE'd queries/keys. Inverse RoPE removes this
so the output is position-invariant.
o: (T, n_h, hd) BF16
positions: (T,) int64
Returns: (T, n_h, hd) BF16 with inverse RoPE on last rope_dim dims
"""
T, n_h, hd = o.shape
nope = hd - rope_dim
half = rope_dim // 2
cos = cos_cache[positions].unsqueeze(1) # (T, 1, half) BF16
sin = sin_cache[positions].unsqueeze(1)
out = o.clone()
o_rope = o[:, :, nope:]
o_even = o_rope[:, :, 0::2]
o_odd = o_rope[:, :, 1::2]
# Conjugate rotation
out[:, :, nope:][..., 0::2] = o_even * cos + o_odd * sin
out[:, :, nope:][..., 1::2] = -o_even * sin + o_odd * cos
return out
# =====================================================================
# KV cache — BF16 flat, MQA (1 KV head)
# =====================================================================
return self._impl.post_block(X_l, F_out, ctx)
class SimpleKVCache:
"""Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.