Files
nvfp4-megamoe-kernel/cutedsl/mhc_inference_layer.py
2026-05-21 05:55:22 +00:00

502 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
mHC (Manifold-Constrained Hyper-Connections) — Inference Layer.
Implements Section 2.2 of the DeepSeek-V4 paper for the forward pass only.
At inference the Sinkhorn-Knopp constraint has already been enforced during
training, but B_l is still *dynamically generated* per-token from the input
residual state. So we still need to:
1. Project the flattened residual → raw A/B/C parameter values.
2. Apply sigmoid (A, C) and Sinkhorn-Knopp 20 iters (B).
3. Mix residual streams.
The only thing that changes vs training is that we skip the loss and gradient
through the Sinkhorn projection — the forward arithmetic is identical.
---------------------------------------------------------------------
V4-Pro reference dimensions (Section 4.2.1)
---------------------------------------------------------------------
d = 7168 hidden dim
n_hc = 4 hyper-connection expansion factor
N_proj = 24 fused output of W_pre(4) + W_res(16) + W_post(4)
K_proj = 4*7168 = 28672 = n_hc * d (flattened residual)
t_max = 20 Sinkhorn iterations
---------------------------------------------------------------------
Kernel dependency
---------------------------------------------------------------------
tf32_hc_prenorm_gemm (DeepGEMM, SM90/SM100)
a: (T, K) BF16 — flattened residual X_flat
b: (N, K) FP32 — stacked weight [W_pre; W_res; W_post]
d: (S, T, N) or (T, N) FP32 — raw projection outputs (before RMS normalisation)
sqr_sum: (S, T) or (T,) FP32 — Σ a² per token (for RMSNorm denominator)
num_splits = S (16 recommended for K=28672)
After the call:
d = d.sum(0) → (T, N)
sqr_sum = sqr_sum.sum(0) → (T,)
rms_scale = sqrt(K / (sqr_sum + eps))
d_norm = d * rms_scale[:, None] — equivalent to RMSNorm(X_flat) @ W_stacked.T
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Optional DeepGEMM backend.  When the import fails, mHCLayer falls back to a
# plain FP32 PyTorch matmul plus a manual squared-sum.
# ---------------------------------------------------------------------------
try:
    import deep_gemm
except ImportError:
    _HAS_DEEP_GEMM = False
else:
    _HAS_DEEP_GEMM = True

# Number of K-splits requested from tf32_hc_prenorm_gemm (numerical stability).
NUM_SPLITS = 16
# Small epsilon added to denominators (RMS scale and Sinkhorn normalisation).
EPS_RMSN = 1e-6
# ---------------------------------------------------------------------------
# Sinkhorn-Knopp projection (T batched 4×4 matrices, 20 iters)
# ---------------------------------------------------------------------------
def sinkhorn_knopp(
    M: torch.Tensor,  # (..., n, n) positive (after exp)
    t_max: int = 20,
    eps: Optional[float] = None,
) -> torch.Tensor:
    """
    Push each (n×n) positive matrix towards the Birkhoff polytope
    (doubly stochastic matrices) via alternating row/col normalisation.

    Each iteration row-normalises first and column-normalises second, i.e.
    M^(t) = T_c( T_r( M^(t-1) ) ).  The paper's eq. (8) writes the composition
    in the opposite order; this function keeps the order the code has always
    used, so the *last* step applied is the column normalisation.

    Args:
        M:     batch of positive square matrices; the last two dims are
               (row, col).  The input tensor is not modified in place.
        t_max: number of row+column normalisation rounds.  With ``t_max <= 0``
               the input is returned unchanged.
        eps:   value added to every row/column sum before dividing, guarding
               against division by zero.  ``None`` (default) uses the
               module-level ``EPS_RMSN`` constant, which is what this function
               used before ``eps`` became a parameter.

    Returns:
        Tensor of the same shape as ``M``.  All ops are standard PyTorch and
        stay on the device of ``M``.
    """
    if eps is None:
        eps = EPS_RMSN
    for _ in range(t_max):
        M = M / (M.sum(dim=-1, keepdim=True) + eps)  # T_r: rows sum to ~1
        M = M / (M.sum(dim=-2, keepdim=True) + eps)  # T_c: cols sum to ~1
    return M
# ---------------------------------------------------------------------------
# Context carried between pre_block and post_block
# ---------------------------------------------------------------------------
@dataclass
class mHCContext:
    """Per-token mixing state produced by ``pre_block`` and consumed by ``post_block``."""

    # Doubly stochastic residual transform, shape (T, n_hc, n_hc).
    B_l: torch.Tensor
    # Output mapping, shape (T, n_hc); post_block unsqueezes it before use.
    C_l: torch.Tensor
# ---------------------------------------------------------------------------
# mHC layer
# ---------------------------------------------------------------------------
class mHCLayer:
    """
    Wrap one transformer sub-layer (attention *or* MoE) with the mHC
    residual update.

    Typical call pattern per layer::

        x_in, ctx = mhc.pre_block(X_l)
        F_out     = transformer_sublayer(x_in)        # (T, d)
        X_next    = mhc.post_block(X_l, F_out, ctx)

    where X_l has shape (T, n_hc, d) — the expanded residual state.
    The first call at layer 0 should use X_0 initialised via `init_state`.
    """

    def __init__(
        self,
        hidden_dim: int = 7168,
        n_hc: int = 4,
        t_max_sinkhorn: int = 20,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Allocate parameter storage for one mHC wrapper.

        Args:
            hidden_dim:     model hidden size d.
            n_hc:           number of residual streams (expansion factor).
            t_max_sinkhorn: Sinkhorn-Knopp iterations applied to B_l.
            device:         device on which every buffer is allocated.
            dtype:          activation dtype used for S_* biases, A_l / C_l
                            and all returned activations.

        The W_* / S_* buffers are created with ``torch.empty`` and therefore
        hold arbitrary values until ``load_weights`` is called.
        """
        self.d = hidden_dim
        self.n_hc = n_hc
        self.K_proj = n_hc * hidden_dim           # 28672 for V4-Pro
        self.N_proj = n_hc + n_hc * n_hc + n_hc   # 4 + 16 + 4 = 24
        self.t_max = t_max_sinkhorn
        self.device = device
        self.dtype = dtype

        # ── Learnable weights (set via load_weights) ──────────────────
        # Kept as separate FP32 tensors; fused lazily by _build_stacked.
        self.W_pre = self._buf(n_hc, self.K_proj, dtype=torch.float32)          # (4, K)
        self.W_res = self._buf(n_hc * n_hc, self.K_proj, dtype=torch.float32)   # (16, K)
        self.W_post = self._buf(n_hc, self.K_proj, dtype=torch.float32)         # (4, K)

        # Static biases (eq. 3-5, S^pre / S^res / S^post)
        self.S_pre = self._buf(1, n_hc)       # (1, 4)
        self.S_res = self._buf(n_hc, n_hc)    # (4, 4)
        self.S_post = self._buf(n_hc, 1)      # (4, 1)

        # Gating scalars (α); zero until loaded from the checkpoint.
        self.alpha_pre = torch.zeros(1, device=device, dtype=torch.float32)
        self.alpha_res = torch.zeros(1, device=device, dtype=torch.float32)
        self.alpha_post = torch.zeros(1, device=device, dtype=torch.float32)

        # Split-K output buffers for DeepGEMM, grown on demand by _ensure_buffers.
        self._d_split: Optional[torch.Tensor] = None         # (NUM_SPLITS, max_T, N_proj) FP32
        self._sqr_sum_split: Optional[torch.Tensor] = None   # (NUM_SPLITS, max_T) FP32
        self._max_T = 0

        # Cached fused weight (N_proj, K_proj) FP32; None means "rebuild".
        self._W_stacked: Optional[torch.Tensor] = None

    # ── Construction helpers ──────────────────────────────────────────
    def _buf(self, *shape, dtype=None):
        """Return an uninitialised tensor of ``shape`` on ``self.device``.

        ``dtype`` defaults to the layer's activation dtype.
        """
        dt = self.dtype if dtype is None else dtype
        return torch.empty(*shape, dtype=dt, device=self.device)

    def load_weights(
        self,
        W_pre: torch.Tensor,    # (n_hc, K)   FP32
        W_res: torch.Tensor,    # (n_hc², K)  FP32
        W_post: torch.Tensor,   # (n_hc, K)   FP32
        S_pre: torch.Tensor,    # (1, n_hc)
        S_res: torch.Tensor,    # (n_hc, n_hc)
        S_post: torch.Tensor,   # (n_hc, 1)
        alpha_pre: float,
        alpha_res: float,
        alpha_post: float,
    ):
        """
        Load all mHC parameters from the checkpoint.

        The W tensors are stored as FP32 (the prenorm GEMM consumes BF16
        activations × FP32 weights).  The S biases are cast to the layer
        dtype.  Each alpha may be a Python float or a tensor; it is stored as
        an FP32 tensor on the layer device.

        Raises:
            ValueError: if a W tensor does not have the expected
                (rows, K_proj) shape or an S tensor does not have the expected
                number of elements.  Validation happens before any attribute
                is modified, so a failed call leaves the layer untouched.
        """
        n, K = self.n_hc, self.K_proj

        # A wrong row count would still concatenate fine and silently shift
        # the A/B/C column slices, so check shapes up front.
        for name, tensor, shape in (
            ("W_pre", W_pre, (n, K)),
            ("W_res", W_res, (n * n, K)),
            ("W_post", W_post, (n, K)),
        ):
            if tuple(tensor.shape) != shape:
                raise ValueError(
                    f"{name} has shape {tuple(tensor.shape)}, expected {shape}"
                )
        for name, tensor, numel in (
            ("S_pre", S_pre, n),
            ("S_res", S_res, n * n),
            ("S_post", S_post, n),
        ):
            if tensor.numel() != numel:
                raise ValueError(
                    f"{name} has {tensor.numel()} elements, expected {numel}"
                )

        def _f32(t):
            return t.to(device=self.device, dtype=torch.float32).contiguous()

        def _cvt(t):
            return t.to(device=self.device, dtype=self.dtype).contiguous()

        def _scalar(a):
            # as_tensor accepts both floats and tensors without the
            # copy-construct warning torch.tensor(tensor) would emit.
            return torch.as_tensor(a, dtype=torch.float32, device=self.device)

        self.W_pre = _f32(W_pre)
        self.W_res = _f32(W_res)
        self.W_post = _f32(W_post)
        self.S_pre = _cvt(S_pre)
        self.S_res = _cvt(S_res)
        self.S_post = _cvt(S_post)
        self.alpha_pre = _scalar(alpha_pre)
        self.alpha_res = _scalar(alpha_res)
        self.alpha_post = _scalar(alpha_post)
        self._W_stacked = None  # invalidate the fused-weight cache

    def _build_stacked(self):
        """Fuse W_pre / W_res / W_post into one (N_proj, K_proj) FP32 tensor."""
        stacked = torch.cat([self.W_pre, self.W_res, self.W_post], dim=0)
        # Must be K-major (contiguous along K) for DeepGEMM.
        self._W_stacked = stacked.contiguous()

    def _ensure_buffers(self, T: int):
        """Make sure the split buffers exist and hold at least T tokens.

        Buffers are only ever grown, so steady-state calls do not allocate.
        """
        # The `is not None` guard covers T == 0 on a fresh layer, where
        # `T <= self._max_T` alone would leave the buffers unallocated.
        if self._d_split is not None and T <= self._max_T:
            return
        self._d_split = torch.empty(
            NUM_SPLITS, T, self.N_proj, dtype=torch.float32, device=self.device
        )
        self._sqr_sum_split = torch.empty(
            NUM_SPLITS, T, dtype=torch.float32, device=self.device
        )
        self._max_T = T

    # ── Forward ──────────────────────────────────────────────────────
    def _project_and_rms(self, X_flat: torch.Tensor) -> torch.Tensor:
        """
        Compute the RMS-scaled projection of X_flat onto W_stacked.

        Uses tf32_hc_prenorm_gemm when DeepGEMM is available (fused GEMM +
        squared-sum accumulation); otherwise falls back to an FP32 matmul and
        a manual squared sum.

        Args:
            X_flat: (T, K_proj) flattened residual state.

        Returns:
            (T, N_proj) FP32 tensor.  The result is deliberately *not* rounded
            to the activation dtype: the caller needs FP32 for the
            sigmoid / exp / Sinkhorn math.
        """
        T = X_flat.shape[0]
        if self._W_stacked is None:
            self._build_stacked()

        if _HAS_DEEP_GEMM:
            self._ensure_buffers(T)
            # Views into the pre-allocated buffers, no copy.
            # NOTE(review): when T < _max_T these views are not contiguous;
            # confirm the kernel accepts strided output tensors.
            d_s = self._d_split[:, :T, :]
            ss_s = self._sqr_sum_split[:, :T]
            # a: (T, K) BF16   b: (N, K) FP32  →  d_s: (S, T, N), ss_s: (S, T)
            # Both d and sqr_sum are OUTPUT tensors (written by the kernel).
            deep_gemm.tf32_hc_prenorm_gemm(
                X_flat.contiguous(),   # a
                self._W_stacked,       # b       (N, K) FP32
                d_s,                   # d       (S, T, N)
                ss_s,                  # sqr_sum (S, T)
                num_splits=NUM_SPLITS,
            )
            d_out = d_s.sum(dim=0)      # (T, N)
            sqr_sum = ss_s.sum(dim=0)   # (T,)
        else:
            # Fallback: FP32 matmul + manual squared sum.
            x_f32 = X_flat.float()
            d_out = x_f32 @ self._W_stacked.T     # (T, N)
            sqr_sum = x_f32.pow(2).sum(dim=-1)    # (T,)

        # mean(x²) = sqr_sum / K  →  scale = sqrt(K / sqr_sum).
        # NOTE(review): eps is added to the *sum*, not the mean, so the
        # effective RMSNorm epsilon is EPS_RMSN / K — confirm this matches
        # the training-side formula.
        rms_scale = torch.sqrt(self.K_proj / (sqr_sum + EPS_RMSN))   # (T,)
        return d_out * rms_scale.unsqueeze(-1)                       # (T, N) FP32

    def _dynamic_params(
        self, X_l: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute per-token A_l, B_l, C_l from the current residual state.

        Args:
            X_l: (T, n_hc, d) residual state.

        Returns:
            A_l: (T, n_hc)        sigmoid-constrained input mapping, layer dtype
            B_l: (T, n_hc, n_hc)  Sinkhorn-normalised residual transform, FP32
            C_l: (T, n_hc)        2*sigmoid-constrained output mapping, layer dtype
        """
        T, n, d = X_l.shape
        assert n == self.n_hc and d == self.d, (
            f"X_l has shape {tuple(X_l.shape)}, "
            f"expected (T, {self.n_hc}, {self.d})"
        )
        n_hc = self.n_hc

        # Flatten the streams: (T, n_hc*d)
        X_flat = X_l.reshape(T, self.K_proj).to(self.dtype)
        # Fused RMSNorm projection, kept in FP32 end to end: (T, N_proj)
        proj = self._project_and_rms(X_flat)

        # Column layout of the fused projection: [A | B | C]
        i_a = n_hc
        i_b = i_a + n_hc * n_hc
        A_raw = proj[:, :i_a]              # (T, n_hc)
        B_raw = proj[:, i_a:i_b]           # (T, n_hc²)
        C_raw = proj[:, i_b:self.N_proj]   # (T, n_hc)

        # Static biases as (1, ·) FP32 rows.  All three are flattened the same
        # way so a bias stored as a column or 1-D vector cannot mis-broadcast
        # against the token dimension.
        S_pre = self.S_pre.float().flatten().unsqueeze(0)     # (1, n_hc)
        S_res = self.S_res.float().flatten().unsqueeze(0)     # (1, n_hc²)
        S_post = self.S_post.float().flatten().unsqueeze(0)   # (1, n_hc)

        # Scale by the gating factors and add the biases (eq. 3-5)
        A_tilde = self.alpha_pre * A_raw + S_pre     # (T, n_hc)
        B_tilde = self.alpha_res * B_raw + S_res     # (T, n_hc²)
        C_tilde = self.alpha_post * C_raw + S_post   # (T, n_hc)

        # Apply constraints (paper eqs. 6-8)
        A_l = torch.sigmoid(A_tilde)          # (T, n_hc)
        C_l = 2.0 * torch.sigmoid(C_tilde)    # (T, n_hc)
        # B_l: exp → Sinkhorn-Knopp → (approximately) doubly stochastic
        B_exp = torch.exp(B_tilde).reshape(T, n_hc, n_hc)
        B_l = sinkhorn_knopp(B_exp, t_max=self.t_max)   # (T, n_hc, n_hc)

        # Keep B_l in FP32 — the (T,4,4) bmm precision matters more than memory.
        # A_l and C_l are cast to dtype for the input/output mixing multiplies.
        return A_l.to(self.dtype), B_l, C_l.to(self.dtype)

    # ----------------------------------------------------------------
    # Public API: pre_block / post_block
    # ----------------------------------------------------------------
    def pre_block(
        self,
        X_l: torch.Tensor,  # (T, n_hc, d) BF16
    ) -> Tuple[torch.Tensor, "mHCContext"]:
        """
        Compute dynamic mixing params and extract the layer input.

        Args:
            X_l: (T, n_hc, d) residual state.

        Returns:
            x_in: (T, d) in the layer dtype — the input for the sub-layer
            ctx:  mHCContext — {B_l, C_l} to be passed to post_block
        """
        A_l, B_l, C_l = self._dynamic_params(X_l)
        # Per-token weighted sum of the streams:
        #   (T, 1, n_hc) bmm (T, n_hc, d) = (T, 1, d)  →  squeeze to (T, d)
        # bmm needs matching dtypes; .to() is a no-op when X_l already is
        # in the layer dtype.
        x_in = torch.bmm(A_l.unsqueeze(1), X_l.to(A_l.dtype)).squeeze(1)
        return x_in, mHCContext(B_l=B_l, C_l=C_l)

    def post_block(
        self,
        X_l: torch.Tensor,     # (T, n_hc, d) BF16 — residual state BEFORE sub-layer
        F_out: torch.Tensor,   # (T, d)       BF16 — sub-layer output
        ctx: "mHCContext",
    ) -> torch.Tensor:
        """
        Apply the mHC residual update (eq. 1):

            X_{l+1} = B_l @ X_l  +  C_l ⊗ F_out

        Args:
            X_l:   (T, n_hc, d) residual state that was given to pre_block.
            F_out: (T, d) sub-layer output.
            ctx:   context returned by pre_block for the same X_l.

        Returns:
            X_next: (T, n_hc, d) in the layer dtype.
        """
        # torch.bmm does not promote dtypes, so both operands are upcast
        # explicitly.  Both terms are formed in FP32 and the sum is rounded to
        # the layer dtype exactly once.
        BX = torch.bmm(ctx.B_l.float(), X_l.float())                      # (T, n_hc, d)
        CF = ctx.C_l.float().unsqueeze(-1) * F_out.float().unsqueeze(1)   # (T, n_hc, d)
        return (BX + CF).to(self.dtype)

    # ----------------------------------------------------------------
    # Utility
    # ----------------------------------------------------------------
    @staticmethod
    def init_state(
        embeddings: torch.Tensor,  # (T, d) BF16 — token embeddings
        n_hc: int = 4,
    ) -> torch.Tensor:
        """
        Initialise X_0 for the first layer.

        The embedding is broadcast across all n_hc residual streams as the
        simplest valid initialisation; the result is cloned so it owns its
        memory.

        Returns: (T, n_hc, d), same dtype as ``embeddings``.
        """
        return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()

    @staticmethod
    def read_out(X_L: torch.Tensor) -> torch.Tensor:
        """
        Extract the final hidden state from the last residual state.

        Convention used here: stream 0 is the primary output stream.

        Returns: (T, d) view of stream 0 (no copy).
        """
        return X_L[:, 0, :]
# ---------------------------------------------------------------------------
# Quick smoke test
# ---------------------------------------------------------------------------
def _run_smoke_test() -> None:
    """Run mHCLayer end to end with random weights and print/assert sanity checks.

    Checks performed:
      * shapes through a 2-layer pre_block / post_block stack,
      * B_l row and column sums are ~1,
      * A_l lies in (0, 1) and C_l in (0, 2),
      * per-token (T=1) pre_block matches batched pre_block.
    """
    torch.manual_seed(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    D, N_HC = 7168, 4
    K = N_HC * D  # 28672

    mhc = mHCLayer(hidden_dim=D, n_hc=N_HC, device=device, dtype=dtype)
    # Random weights matching the expected shapes
    mhc.load_weights(
        W_pre=torch.randn(N_HC, K, dtype=torch.float32),
        W_res=torch.randn(N_HC ** 2, K, dtype=torch.float32),
        W_post=torch.randn(N_HC, K, dtype=torch.float32),
        S_pre=torch.zeros(1, N_HC, dtype=dtype),
        S_res=torch.eye(N_HC, dtype=dtype),  # identity bias on the residual logits
        S_post=torch.zeros(N_HC, 1, dtype=dtype),
        alpha_pre=0.01,
        alpha_res=0.01,
        alpha_post=0.01,
    )
    T = 4  # 4 tokens

    # ── Forward pass ────────────────────────────────────────────────
    embeddings = torch.randn(T, D, dtype=dtype, device=device)
    X = mHCLayer.init_state(embeddings, n_hc=N_HC)
    print(f"X_0: {X.shape} (T={T}, n_hc={N_HC}, d={D})")

    # Simulate a 2-layer stack
    for layer_idx in range(2):
        x_in, ctx = mhc.pre_block(X)
        print(f"\nLayer {layer_idx}:")
        print(f" x_in (to sub-layer): {x_in.shape}")
        print(f" B_l: {ctx.B_l.shape}")
        print(f" C_l: {ctx.C_l.shape}")
        # Dummy sub-layer: identity (for testing the mHC mechanics)
        F_out = x_in
        X = mhc.post_block(X, F_out, ctx)
        print(f" X_next: {X.shape}")
    hidden = mHCLayer.read_out(X)
    print(f"\nFinal hidden: {hidden.shape}")

    # ── B_l is doubly stochastic check ──────────────────────────────
    print("\n=== Doubly stochastic check ===")
    B = ctx.B_l               # (T, 4, 4) — FP32 from Sinkhorn, last layer's context
    row_sums = B.sum(dim=-1)  # (T, 4) — should all be ~1
    col_sums = B.sum(dim=-2)  # (T, 4) — should all be ~1
    print(f" row sum range: [{row_sums.min():.6f}, {row_sums.max():.6f}] (want ≈ 1.0)")
    print(f" col sum range: [{col_sums.min():.6f}, {col_sums.max():.6f}] (want ≈ 1.0)")
    assert (row_sums - 1).abs().max() < 1e-3, "B_l rows do not sum to 1"
    assert (col_sums - 1).abs().max() < 1e-3, "B_l cols do not sum to 1"
    print(" PASSED")

    # ── A_l and C_l are bounded ──────────────────────────────────────
    # (Re-run dynamic params to expose A_l for checking)
    A_l, _, C_l = mhc._dynamic_params(X)
    print("\n=== A_l ∈ (0,1) check ===")
    print(f" A_l range: [{A_l.min():.4f}, {A_l.max():.4f}] (want ∈ (0,1))")
    assert A_l.min() > 0 and A_l.max() < 1, "A_l out of sigmoid range"
    print(" PASSED")
    print("\n=== C_l ∈ (0,2) check ===")
    print(f" C_l range: [{C_l.min():.4f}, {C_l.max():.4f}] (want ∈ (0,2))")
    assert C_l.min() > 0 and C_l.max() < 2, "C_l out of 2*sigmoid range"
    print(" PASSED")

    # ── Consistency: S_res = identity ────────────────────────────────
    # With S_res = I and alpha_res ≈ 0, B_tilde ≈ I, so Sinkhorn runs on
    # roughly exp(I).  That matrix is diagonally weighted, not uniform, so
    # the only property asserted is double stochasticity (checked above).
    print("\n=== S_res=I, alpha_res≈0 → B_l stays doubly stochastic ===")
    print(" Already verified via doubly stochastic check above.")

    # ── Equivalence: T=1 decode vs T=N prefill ──────────────────────
    print("\n=== Token-by-token decode == batch prefill ===")
    T_big = 8
    h_big = torch.randn(T_big, D, dtype=dtype, device=device)
    X_batch = mHCLayer.init_state(h_big, n_hc=N_HC)
    # Batch
    x_in_batch, _ = mhc.pre_block(X_batch)
    # Token by token: each slice keeps the leading dim, shape (1, n_hc, d)
    x_in_seq = torch.cat(
        [mhc.pre_block(X_batch[t:t + 1])[0] for t in range(T_big)], dim=0
    )  # (T_big, d)
    diff = (x_in_batch - x_in_seq).abs().max().item()
    print(f" max |batch - sequential| on x_in: {diff:.6f}")
    assert diff < 1e-2, f"Mismatch too large: {diff}"
    print(" PASSED")

    print("\nAll checks done.")
    if not _HAS_DEEP_GEMM:
        # The fallback path multiplies FP32 activations by FP32 weights.
        print("\n(deep_gemm not available — used FP32 matmul fallback)")


if __name__ == "__main__":
    _run_smoke_test()