Files
nvfp4-megamoe-kernel/single_shot_inference.py

860 lines
35 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.
This is a reference implementation that exercises the production kernel
stack end-to-end. It should be usable as ground truth when integrating
into vLLM or SGLang.
Architecture (paper §2):
X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
X_mid → mHC.pre_block → RMSNorm → FFN(MoE) → F_ffn → mHC.post_block → X_{l+1}
Components exercised:
- mHC (Manifold-Constrained Hyper-Connections) — proper Sinkhorn-Knopp
- Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head)
- Partial RoPE (last 64 dims, GPT-J interleaved)
- Production FMHA kernel (6-warp multi-tile, C API + ctypes)
- Inverse RoPE on attention output (paper §2.3.3)
- Grouped output projection (wo_a BMM + wo_b NVFP4)
- Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp)
- Shared expert (NVFP4 gate/up/down)
- RMSNorm (pre-norm before each sub-block)
- KV cache across decode steps
Attention type simplification for this single-shot test:
For short sequences (seq_len ≤ sliding_window=128), ALL attention
types (CSA/HCA/SWA) reduce to dense attention over the full KV cache.
CSA's compressed branch and indexer are only needed for long sequences
where seq_len > sliding_window. HCA is dense over compressed entries,
but at short sequence lengths, the compressed sequence is trivially
small. So we use dense MQA attention over the full KV for all layers.
This is mathematically correct for short sequences and exercises the
FMHA kernel properly.
Usage (on B200):
source /root/dsv4-nvfp4-workspace/venv/bin/activate
cd /root/dsv4-nvfp4-workspace/kernel
python3 single_shot_inference.py
"""
import os, sys, time, json, math
import torch
from pathlib import Path
# =====================================================================
# Configuration
# =====================================================================
# Directory that main() reads config.json, the tokenizer and the
# safetensors shards from.
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Upper bound on the number of tokens produced by the decode loop in main().
MAX_NEW_TOKENS = 10
# Prompt that main() tokenizes and feeds through the model.
PROMPT = "The capital of France is"
# Number of CUDA devices; layer `li` is placed on device `li % NUM_GPUS`.
NUM_GPUS = 8
# =====================================================================
# NVFP4 dequantization — matches checkpoint format exactly
# =====================================================================
# Magnitude lookup table indexed by the low 3 bits of each 4-bit code
# (bit 3 is treated as the sign by dequant_nvfp4_weight).
# NOTE(review): the usual FP4 E2M1 magnitudes are
# (0, 0.5, 1, 1.5, 2, 3, 4, 6); the values below are not a constant
# multiple of that set — confirm against the checkpoint's quantizer.
FP4_LUT = torch.tensor([0., 2., 3., 4., 6., 8., 12., 24.])
def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Expand a packed NVFP4 weight matrix into a dense BF16 matrix.

    Each byte of ``weight`` holds two 4-bit codes; the low nibble is the
    earlier element and the high nibble the later one. Within a code, bit 3
    is the sign and bits 0-2 index ``FP4_LUT``.

    weight:         (out_dim, in_dim//2) uint8 — 2 FP4 values per byte
    weight_scale:   (out_dim, in_dim//16) E4M3 — per-16-element block scale
    weight_scale_2: (out_dim, 1) float32 — per-row global scale
    Returns: (out_dim, in_dim) BF16
    """
    rows, packed_cols = weight.shape[0], weight.shape[1]
    cols = 2 * packed_cols
    table = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def _decode(nibble):
        # Split a 4-bit code into sign (bit 3) and LUT index (bits 0-2).
        code = nibble.to(torch.int8)
        negative = (code >> 3).bool()
        magnitude = table[(code & 0x07).long()]
        return magnitude * torch.where(negative, -1.0, 1.0)

    lo_vals = _decode(weight & 0x0F)
    hi_vals = _decode(weight >> 4)
    # Interleave (low, high) pairs back into element order along the row.
    dense = torch.stack([lo_vals, hi_vals], dim=-1).reshape(rows, cols)
    # Combined scale per 16-element block, then broadcast to every element.
    block_scale = weight_scale.float() * weight_scale_2.float()
    per_element = block_scale.repeat_interleave(16, dim=1)
    return (dense * per_element).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Bias-free linear layer whose weight is stored as NVFP4.

    The weight is dequantized to a dense BF16 matrix on every call and then
    applied as ``x @ W.T``.
    """
    return torch.nn.functional.linear(
        x,
        dequant_nvfp4_weight(weight, weight_scale, weight_scale_2),
    )
# =====================================================================
# RMSNorm — matches dsv4/layers/norm.py
# =====================================================================
class RMSNorm:
    """Root-mean-square normalisation with a per-channel gain.

    The gain starts as all ones (FP32) on ``device``; callers may overwrite
    ``weight`` with a checkpoint tensor.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.eps = eps
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)

    def forward(self, x):
        """x: (T, H) BF16 → (T, H) BF16 (computed in FP32 internally)."""
        acts = x.float()
        mean_sq = acts.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = mean_sq.add(self.eps).rsqrt()
        normed = acts * inv_rms * self.weight
        return normed.to(torch.bfloat16)
# =====================================================================
# mHC — proper Sinkhorn-Knopp implementation
# =====================================================================
class mHCBlock:
    """Manifold-Constrained Hyper-Connections (paper §2.2).

    Self-contained variant for single-shot inference; all math is done with
    plain torch ops.

    Checkpoint tensors consumed by ``load_from_checkpoint``:
        fn:    (N_proj, K_proj) FP32 — fused projection [W_pre; W_res; W_post]
        base:  (N_proj,)             — static biases    [S_pre; S_res; S_post]
        scale: (3,)                  — gating alphas    [alpha_pre, alpha_res, alpha_post]
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        self.d = hidden_dim
        self.n_hc = n_hc
        self.K_proj = n_hc * hidden_dim        # flattened residual width
        self.N_proj = n_hc + n_hc**2 + n_hc    # pre + res + post outputs
        self.sinkhorn_iters = sinkhorn_iters
        self.device = device
        self.eps = 1e-6
        # Parameters are populated later (load_from_checkpoint or by hand).
        self.W_stacked = None   # (N_proj, K_proj) FP32
        self.S_pre = None       # (1, n_hc) FP32
        self.S_res = None       # (n_hc, n_hc) FP32
        self.S_post = None      # (n_hc, 1) FP32
        self.alpha_pre = None   # float
        self.alpha_res = None   # float
        self.alpha_post = None  # float

    def load_from_checkpoint(self, fn, base, scale):
        """Install projection, biases and gating scalars from checkpoint tensors.

        fn:    (N_proj, K_proj)
        base:  (N_proj,) laid out as [S_pre | S_res | S_post]
        scale: up to 3 elements [alpha_pre, alpha_res, alpha_post]; any
               missing entry falls back to 0.01.
        """
        n = self.n_hc
        placement = dict(device=self.device, dtype=torch.float32)
        res_end = n + n * n
        self.W_stacked = fn.to(**placement).contiguous()
        self.S_pre = base[0:n].reshape(1, n).to(**placement)
        self.S_res = base[n:res_end].reshape(n, n).to(**placement)
        self.S_post = base[res_end:].reshape(n, 1).to(**placement)
        self.alpha_pre = scale[0].item() if scale.numel() > 0 else 0.01
        self.alpha_res = scale[1].item() if scale.numel() > 1 else 0.01
        self.alpha_post = scale[2].item() if scale.numel() > 2 else 0.01

    def _project_and_rms(self, X_flat):
        """Project the flattened residual and apply its RMS scale afterwards.

        X_flat: (T, K_proj) BF16
        Returns: (T, N_proj) BF16
        """
        x32 = X_flat.float()
        projected = x32 @ self.W_stacked.T                   # (T, N_proj)
        sum_sq = x32.pow(2).sum(dim=-1)                      # (T,)
        # sqrt(K / sum(x^2)) is the reciprocal RMS of each row.
        inv_rms = torch.sqrt(self.K_proj / (sum_sq + self.eps))
        return (projected * inv_rms.unsqueeze(-1)).bfloat16()

    def _dynamic_params(self, X_l):
        """Derive the per-token mixing coefficients from the residual state.

        X_l: (T, n_hc, d) BF16
        Returns: A_l (T, n_hc) BF16, B_l (T, n_hc, n_hc) FP32, C_l (T, n_hc) BF16
        """
        T, n, _ = X_l.shape
        proj = self._project_and_rms(X_l.reshape(T, self.K_proj)).float()
        pre_end = n
        res_end = n + n * n
        pre_raw = proj[:, 0:pre_end]
        res_raw = proj[:, pre_end:res_end]
        post_raw = proj[:, res_end:self.N_proj]
        # Gate each raw projection by its alpha and add the static bias.
        pre_logit = self.alpha_pre * pre_raw + self.S_pre
        res_logit = self.alpha_res * res_raw + self.S_res.flatten().unsqueeze(0)
        post_logit = self.alpha_post * post_raw + self.S_post.flatten().unsqueeze(0)
        # Constrain: A in (0, 1), C in (0, 2), B doubly stochastic.
        A_l = torch.sigmoid(pre_logit)
        C_l = 2.0 * torch.sigmoid(post_logit)
        B_l = self._sinkhorn_knopp(torch.exp(res_logit).reshape(T, n, n))
        return A_l.bfloat16(), B_l, C_l.bfloat16()

    def _sinkhorn_knopp(self, M, t_max=None):
        """Alternately normalise rows and columns of (T, n, n) positive matrices.

        Runs ``t_max`` rounds (default ``self.sinkhorn_iters``); each round
        divides by the row sums, then by the column sums, so the result
        approaches a doubly stochastic matrix.
        """
        rounds = self.sinkhorn_iters if t_max is None else t_max
        mat = M
        for _ in range(rounds):
            mat = mat / (mat.sum(dim=-1, keepdim=True) + self.eps)   # rows
            mat = mat / (mat.sum(dim=-2, keepdim=True) + self.eps)   # columns
        return mat

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Create X_0 by copying the token embedding into every residual stream.

        embeddings: (T, d) BF16
        Returns: (T, n_hc, d) BF16
        """
        broadcast = embeddings.unsqueeze(1).expand(-1, n_hc, -1)
        return broadcast.clone()

    def pre_block(self, X_l):
        """Mix the residual streams into a single sub-layer input.

        X_l: (T, n_hc, d) BF16
        Returns: x_in (T, d) BF16 and ctx = (B_l, C_l) for ``post_block``.
        """
        mix_in, mix_res, mix_out = self._dynamic_params(X_l)
        # (T, 1, n_hc) @ (T, n_hc, d): A-weighted sum over streams.
        x_in = torch.bmm(mix_in.unsqueeze(1), X_l).squeeze(1)
        return x_in, (mix_res, mix_out)

    def post_block(self, X_l, F_out, ctx):
        """Residual update (paper eq. 1): X_{l+1} = B_l @ X_l + C_l ⊗ F_out.

        X_l:   (T, n_hc, d) BF16 — residual state BEFORE the sub-layer
        F_out: (T, d) BF16       — sub-layer output
        ctx:   (B_l, C_l) as returned by ``pre_block``
        Returns: X_next (T, n_hc, d) BF16
        """
        mix_res, mix_out = ctx
        # X_l is upcast explicitly so the stream mixing runs in FP32.
        carried = torch.bmm(mix_res, X_l.float())                 # FP32
        injected = mix_out.unsqueeze(-1) * F_out.unsqueeze(1)     # BF16
        return (carried + injected.float()).bfloat16()
# =====================================================================
# RoPE — partial, GPT-J interleaved, last rope_dim dims
# =====================================================================
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0):
    """Precompute the cos/sin tables used by partial RoPE.

    Row ``p`` holds the rotation angles of position ``p`` for each of the
    ``rope_dim // 2`` frequency pairs.

    Returns: (cos_cache, sin_cache), each (max_pos, rope_dim//2) BF16 on ``device``.
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    pos = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)
    cos_cache = torch.cos(angles).bfloat16().to(device)
    sin_cache = torch.sin(angles).bfloat16().to(device)
    return cos_cache, sin_cache
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Rotate the trailing ``rope_dim`` channels of every head (GPT-J interleaved).

    Channels are paired as (even, odd) neighbours; each pair is rotated by
    the angle looked up for the token's position. Leading channels are
    copied through unchanged and the input tensor is not modified.

    x:         (T, n_h, hd) BF16 — already reshaped to per-head
    positions: (T,) int64
    Returns: (T, n_h, hd) BF16
    """
    _, _, width = x.shape
    rope_start = width - rope_dim
    cos_t = cos_cache[positions].unsqueeze(1)   # (T, 1, rope_dim//2)
    sin_t = sin_cache[positions].unsqueeze(1)
    rotary = x[:, :, rope_start:]
    first = rotary[:, :, 0::2]
    second = rotary[:, :, 1::2]
    result = x.clone()
    target = result[:, :, rope_start:]          # view into ``result``
    target[..., 0::2] = first * cos_t - second * sin_t
    target[..., 1::2] = first * sin_t + second * cos_t
    return result
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo RoPE on the attention output (paper §2.3.3).

    Applies the conjugate of the rotation used on queries/keys to the
    trailing ``rope_dim`` channels of every head, pairing (even, odd)
    neighbours. Leading channels are copied through and the input tensor is
    not modified.

    o:         (T, n_h, hd) BF16
    positions: (T,) int64
    Returns: (T, n_h, hd) BF16
    """
    _, _, width = o.shape
    rope_start = width - rope_dim
    cos_t = cos_cache[positions].unsqueeze(1)   # (T, 1, rope_dim//2)
    sin_t = sin_cache[positions].unsqueeze(1)
    rotary = o[:, :, rope_start:]
    first = rotary[:, :, 0::2]
    second = rotary[:, :, 1::2]
    result = o.clone()
    target = result[:, :, rope_start:]          # view into ``result``
    # Same rotation as the forward pass with the sign of sin flipped.
    target[..., 0::2] = first * cos_t + second * sin_t
    target[..., 1::2] = -first * sin_t + second * cos_t
    return result
# =====================================================================
# KV cache — BF16 flat, MQA (1 KV head)
# =====================================================================
class SimpleKVCache:
    """Per-layer KV cache for decode. Stores BF16 K,V accumulated across steps.

    MQA: 1 KV head, so the buffers are (1, max_seq, hd) and ``get`` returns
    the first ``len`` positions of each.
    """

    def __init__(self, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.len = 0   # number of valid positions currently stored

    def append(self, k_new, v_new):
        """Append K,V for T new positions. k_new: (1, T, hd), v_new: (1, T, hd).

        Raises:
            RuntimeError: if the new positions do not fit in ``max_seq``.
                Without this check a one-row write into a full cache
                broadcasts into an empty slice and is silently dropped
                while ``len`` keeps growing.
        """
        T = k_new.shape[1]
        end = self.len + T
        if end > self.max_seq:
            raise RuntimeError(
                f"KV cache overflow: {self.len} cached + {T} new positions "
                f"exceeds max_seq={self.max_seq}"
            )
        self.k[0, self.len:end] = k_new[0]
        self.v[0, self.len:end] = v_new[0]
        self.len = end

    def get(self):
        """Get K,V up to current length. Returns (1, seq_len, hd) each."""
        return self.k[:, :self.len], self.v[:, :self.len]
# =====================================================================
# Weight loading — streams safetensors shards, distributes to 8 GPUs
# =====================================================================
def load_all_weights(checkpoint_dir, num_layers):
    """Load all weights from checkpoint, distribute layers across GPUs.

    Shards are processed one at a time: each tensor is copied to its target
    GPU as soon as its shard is read and the host copy is released before
    the next shard is opened, so peak host memory is about one shard rather
    than the whole checkpoint.

    Args:
        checkpoint_dir: directory containing the ``*.safetensors`` shards
            and, optionally, ``model.safetensors.index.json``. Without a
            usable index the 95-shard naming scheme is assumed.
        num_layers: accepted for interface compatibility; not used here.

    Returns:
        layer_weights: dict[li] → {key: tensor on cuda:li%NUM_GPUS}, plus the
            bookkeeping entries ``"_device"`` and ``"_gpu"``.
        global_weights: {key: tensor on cuda:0} for the embedding, final
            norm and lm_head tensors.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})
    shard_names = set(weight_map.values()) if weight_map else {
        f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)
    }
    print(f"Loading {len(shard_names)} shards...")

    global_prefixes = ("model.embed_tokens", "model.norm", "lm_head")
    layer_weights = {}
    global_weights = {}
    missing = []
    loaded = 0
    n_tensors = 0
    # Sorted order keeps "later shard wins" deterministic for duplicate keys.
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue
        data = load_file(str(shard_path))
        for key, tensor in data.items():
            n_tensors += 1
            if key.startswith("model.layers."):
                li = int(key.split(".")[2])
                target_gpu = li % NUM_GPUS
                target_device = f"cuda:{target_gpu}"
                if li not in layer_weights:
                    layer_weights[li] = {"_device": target_device, "_gpu": target_gpu}
                layer_weights[li][key] = tensor.to(target_device)
            elif key.startswith(global_prefixes):
                global_weights[key] = tensor.to("cuda:0")
        # Drop the host copies before reading the next shard.
        del data
        loaded += 1
        if loaded % 20 == 0:
            print(f" {loaded}/{len(shard_names)} shards, {n_tensors} tensors")
    print(f" Done: {n_tensors} tensors")
    if missing:
        # Best-effort load is kept, but the gap is no longer silent.
        print(f" WARNING: {len(missing)} shard(s) not found and skipped: "
              f"{', '.join(missing[:5])}{' ...' if len(missing) > 5 else ''}")
    for gpu in range(NUM_GPUS):
        alloc = torch.cuda.memory_allocated(gpu) / 1e9
        if alloc > 0:
            print(f" GPU {gpu}: {alloc:.1f}GB")
    return layer_weights, global_weights
# =====================================================================
# Single layer forward
# =====================================================================
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                  kv_cache, token_id, positions):
    """Forward one layer with mHC + Attention + FFN.

    Architecture (paper §2):
        X_l   → mHC.pre_block(attn) → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
        X_mid → mHC.pre_block(ffn)  → RMSNorm → MoE       → F_ffn  → mHC.post_block → X_{l+1}

    Args:
        X_l: (T, n_hc, H) BF16 — mHC residual state.
        w: dict of this layer's tensors keyed by checkpoint name.
        li: layer index (used to build the weight keys).
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables for this layer's device.
        attn_mhc, ffn_mhc: mHC blocks for the two sub-blocks.
        attn_norm, ffn_norm: pre-norms for the two sub-blocks.
        kv_cache: this layer's KV cache; the new K/V are appended to it.
        token_id: current token id tensor, forwarded to ``moe_forward``.
        positions: (T,) int64 absolute positions; may live on any device.

    Returns:
        X_next (T, n_hc, H) BF16.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    heads_per_group = n_h // o_groups
    group_input_dim = heads_per_group * hd

    # Index tensors must be on the same device as the table they index;
    # the caller may have created ``positions`` on a different GPU.
    positions = positions.to(rope_cos.device)

    def proj(inp, name):
        # NVFP4 linear for the sub-module ``{pre}.{name}``.
        base = f"{pre}.{name}"
        return nvfp4_linear(inp,
                            w[f"{base}.weight"],
                            w[f"{base}.weight_scale"],
                            w[f"{base}.weight_scale_2"])

    # ==================================================================
    # ATTENTION SUB-BLOCK
    # ==================================================================
    # -- mHC pre_block (attention) --
    x_in, attn_ctx = attn_mhc.pre_block(X_l)        # x_in: (T, H)
    # -- RMSNorm (pre-norm before attention) --
    x_normed = attn_norm.forward(x_in)              # (T, H) BF16
    # -- Q projection: q_a (low-rank down) → q_b (low-rank up) --
    c_Q = proj(x_normed, "q_a_proj")                # (T, dc)
    q = proj(c_Q, "q_b_proj")                       # (T, n_h * hd)
    # -- KV projection (MQA: 1 KV head, no split) --
    kv = proj(x_normed, "kv_proj")                  # (T, hd)
    # -- Reshape for attention --
    q_heads = q.reshape(T, n_h, hd)                 # (T, n_h, hd)
    kv_new = kv.reshape(T, 1, hd)                   # (T, 1, hd)
    # -- RoPE on Q and on KV (KV is rotated BEFORE it is cached) --
    q_heads = apply_rope_partial(q_heads, positions, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv_new, positions, rope_cos, rope_sin, hd, rd)
    # -- KV cache: K = V in DSV4 MQA, so the same tensor is stored twice --
    kv_for_cache = kv_new.permute(1, 0, 2)          # (1, T, hd)
    kv_cache.append(kv_for_cache, kv_for_cache)
    # -- Full KV from cache: already (1, seq_len, hd), already RoPE'd --
    k_full, v_full = kv_cache.get()
    # -- FMHA: (n_h, T, hd) × (1, seq_len, hd) → (n_h, T, hd) --
    # The cache already yields the (kv_heads=1, seq_len, hd) layout, so K/V
    # are passed through as-is; permuting them would produce
    # (seq_len, 1, hd) as soon as seq_len > 1.
    # NOTE(review): layout taken from the shapes used in this file's
    # warm-up call — confirm against dsv4_attention's contract.
    q_input = q_heads.permute(1, 0, 2)              # (n_h, T, hd)
    attn_out = dsv4_attention(q_input, k_full, v_full)   # (n_h, T, hd)
    attn_out = attn_out.permute(1, 0, 2)            # (T, n_h, hd)
    # -- Inverse RoPE on attention output (paper §2.3.3) --
    attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)
    # -- Output projection: wo_a (grouped BMM) + wo_b (NVFP4) --
    attn_grouped = attn_out.reshape(T, n_h * hd).reshape(T, o_groups, group_input_dim)
    oa_w = w[f"{pre}.o_a_proj.weight"].bfloat16()   # (groups * o_rank, group_dim)
    oa_3d = oa_w.reshape(o_groups, o_rank, group_input_dim)
    attn_for_bmm = attn_grouped.permute(1, 0, 2)    # (groups, T, group_dim)
    grouped_out = torch.bmm(attn_for_bmm, oa_3d.transpose(1, 2))   # (groups, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = proj(grouped_flat, "o_b_proj")         # (T, H)
    # -- mHC post_block (attention) --
    X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)   # (T, n_hc, H)

    # ==================================================================
    # FFN SUB-BLOCK
    # ==================================================================
    # -- mHC pre_block (FFN) --
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)       # (T, H)
    # -- RMSNorm (pre-norm before FFN) --
    x_ffn_normed = ffn_norm.forward(x_ffn)          # (T, H) BF16
    # -- MoE + shared expert --
    F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device)
    # -- mHC post_block (FFN) --
    X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)   # (T, n_hc, H)
    return X_next
# =====================================================================
# MoE forward — hash + dense routing, SwiGLU with clamping
# =====================================================================
def _swiglu_expert(x, w, prefix, swiglu_limit):
    """Run one gate/up/down NVFP4 expert MLP with optionally clamped SwiGLU."""
    gate = nvfp4_linear(x,
                        w[f"{prefix}.gate_proj.weight"],
                        w[f"{prefix}.gate_proj.weight_scale"],
                        w[f"{prefix}.gate_proj.weight_scale_2"])
    up = nvfp4_linear(x,
                      w[f"{prefix}.up_proj.weight"],
                      w[f"{prefix}.up_proj.weight_scale"],
                      w[f"{prefix}.up_proj.weight_scale_2"])
    # SwiGLU with clamping (paper §4.2.3): both factors are limited to
    # [-swiglu_limit, swiglu_limit] before they are multiplied.
    silu_out = torch.nn.functional.silu(gate.float())
    up_f = up.float()
    if swiglu_limit is not None:
        silu_out = silu_out.clamp(-swiglu_limit, swiglu_limit)
        up_f = up_f.clamp(-swiglu_limit, swiglu_limit)
    hidden = (silu_out * up_f).bfloat16()
    return nvfp4_linear(hidden,
                        w[f"{prefix}.down_proj.weight"],
                        w[f"{prefix}.down_proj.weight_scale"],
                        w[f"{prefix}.down_proj.weight_scale_2"])


def moe_forward(x, w, li, cfg, token_id, device):
    """Run routed MoE + shared expert for a single token.

    Layers 0-2 use hash routing (a token-id → expert-id table with uniform
    weights) when the table is present in ``w``; every other case uses
    dense routing with ``sqrt(softplus(logits))`` scores.

    Args:
        x: (1, H) BF16 — post-RMSNorm FFN input.
        w: dict of this layer's tensors keyed by checkpoint name.
        li: layer index.
        cfg: model config dict.
        token_id: tensor holding the current token id (hash routing only).
        device: unused; kept so existing callers keep working.

    Returns:
        (1, H) BF16 — scaled routed output plus shared-expert output.

    Raises:
        NotImplementedError: if ``x`` does not hold exactly one token.
            This now covers the hash path too, which previously routed
            every row with the first token's experts.
    """
    if x.shape[0] != 1:
        raise NotImplementedError("Multi-token MoE routing")

    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    mlp = f"model.layers.{li}.mlp"

    # ---- Routing ----
    tid2eid_key = f"{mlp}.gate.tid2eid"
    if li < 3 and tid2eid_key in w:
        # Hash routing: fixed experts per token id, uniform weights.
        tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
        expert_ids = w[tid2eid_key][tid].tolist()
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
    else:
        # Dense routing (also the fallback for hash layers without a table).
        gate_w = w[f"{mlp}.gate.weight"]                            # (n_experts, H)
        logits = torch.nn.functional.linear(x, gate_w.bfloat16())   # (1, n_experts)
        unbiased = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6)
        # e_bias steers SELECTION ONLY; the mixing weights below are taken
        # from the unbiased activation.
        e_bias_key = f"{mlp}.gate.e_bias"
        if e_bias_key in w:
            selection = unbiased + w[e_bias_key].float().unsqueeze(0)
        else:
            selection = unbiased
        _, indices = selection.topk(top_k, dim=-1)                  # (1, top_k)
        picked = torch.gather(unbiased, -1, indices)
        weights = picked / picked.sum(dim=-1, keepdim=True)
        # One device→host transfer for all expert ids.
        expert_ids = indices[0].tolist()
        expert_weights = weights[0]

    # ---- Run selected experts ----
    expert_outputs = [
        _swiglu_expert(x, w, f"{mlp}.experts.{eid}", swiglu_limit)
        for eid in expert_ids
    ]

    # ---- Weighted combine + scaling ----
    routed_out = torch.zeros_like(x)
    for out, wt in zip(expert_outputs, expert_weights):
        routed_out = routed_out + (out.float() * wt).bfloat16()
    routed_out = (routed_out.float() * routed_scaling).bfloat16()

    # ---- Shared expert ----
    se_pre = f"{mlp}.shared_experts"
    if f"{se_pre}.gate_proj.weight" in w:
        shared_out = _swiglu_expert(x, w, se_pre, swiglu_limit)
    else:
        shared_out = torch.zeros_like(x)
    return routed_out + shared_out
# =====================================================================
# Main
# =====================================================================
def main():
    """Load the checkpoint, prefill the prompt and greedily decode tokens.

    Phases: (1) load and place weights, build per-layer mHC blocks, norms,
    RoPE tables and KV caches; (2) warm up the FMHA kernel; (3) prefill all
    prompt tokens except the last, then decode up to MAX_NEW_TOKENS tokens
    starting from the last prompt token.

    Raises:
        RuntimeError: if the checkpoint has no embedding table.
        ValueError: if the prompt tokenizes to zero tokens.
    """
    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)")
    print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA")
    print("=" * 70)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    norm_eps = cfg.get('rms_norm_eps', 1e-6)
    n_hc = 4
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights ====
    print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
    layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Build mHC blocks (proper Sinkhorn) ====
    print("Building mHC blocks...")
    attn_mhc_blocks = {}
    ffn_mhc_blocks = {}
    attn_norms = {}
    ffn_norms = {}
    for li in range(n_layers):
        gpu = li % NUM_GPUS
        dev = f"cuda:{gpu}"
        # mHC blocks
        for prefix, blocks in [(f"model.layers.{li}.attn_hc", attn_mhc_blocks),
                               (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)]:
            fn = layer_weights[li].get(f"{prefix}.fn")
            base = layer_weights[li].get(f"{prefix}.base")
            scale = layer_weights[li].get(f"{prefix}.scale")
            mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
            if fn is not None and base is not None and scale is not None:
                mhc.load_from_checkpoint(fn, base, scale)
            else:
                # Fallback with a zero projection so the block is
                # input-independent. This is NOT an exact identity: with
                # these biases A = sigmoid(0) = 0.5, C = 2*sigmoid(0.5) and
                # B = Sinkhorn(exp(I)). It only prevents a crash.
                print(f" WARNING: no mHC weights for {prefix}, using identity fallback")
                n = n_hc
                K = n * H
                mhc.W_stacked = torch.zeros(n + n*n + n, K, dtype=torch.float32, device=dev)
                mhc.S_pre = torch.zeros(1, n, dtype=torch.float32, device=dev)
                mhc.S_res = torch.eye(n, dtype=torch.float32, device=dev)
                mhc.S_post = torch.ones(n, 1, dtype=torch.float32, device=dev) * 0.5
                mhc.alpha_pre = 0.01
                mhc.alpha_res = 0.01
                mhc.alpha_post = 0.01
            blocks[li] = mhc
        # RMSNorms (pre-norm before each sub-block)
        attn_norm = RMSNorm(H, eps=norm_eps, device=dev)
        an_key = f"model.layers.{li}.input_layernorm.weight"
        if an_key in layer_weights[li]:
            attn_norm.weight = layer_weights[li][an_key].to(device=dev, dtype=torch.float32)
        attn_norms[li] = attn_norm
        ffn_norm = RMSNorm(H, eps=norm_eps, device=dev)
        fn_key = f"model.layers.{li}.post_attention_layernorm.weight"
        if fn_key in layer_weights[li]:
            ffn_norm.weight = layer_weights[li][fn_key].to(device=dev, dtype=torch.float32)
        ffn_norms[li] = ffn_norm
    print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights ====
    torch.cuda.set_device(0)
    embed_w = global_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise RuntimeError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
    # Fall back to tied embeddings when the checkpoint has no lm_head.
    lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
    final_norm_w = global_weights.get("model.norm.weight")
    rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}") for g in range(NUM_GPUS)}

    # ==== KV caches (one per layer on its GPU) ====
    kv_caches = {}
    for li in range(n_layers):
        kv_caches[li] = SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")

    # ==== Phase 2: Compile FMHA ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    from dsv4.kernels.attention.production import dsv4_attention
    torch.cuda.set_device(0)
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    # Best-effort warm-up: a failure is reported here and will surface
    # again on the first real attention call.
    try:
        _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print(" FMHA: compiled OK")
    except Exception as e:
        print(f" FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_loaded:.1f}s")

    def run_token(token_val, position):
        # Push one token through every layer (appending its K/V to each
        # layer's cache) and return the final residual state on cuda:0.
        tid = torch.tensor([token_val], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([position], dtype=torch.long, device='cuda:0')
        emb = embed(tid)                                # (1, H) on gpu0
        X = mHCBlock.init_state(emb, n_hc)              # (1, n_hc, H)
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            target_device = f"cuda:{gpu}"
            if X.device != torch.device(target_device):
                X = X.to(target_device)
            torch.cuda.set_device(gpu)
            rc, rs = rope_caches[gpu]
            # The RoPE tables are indexed by ``positions``, so it has to be
            # on the same GPU as this layer's tables.
            X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                              attn_mhc_blocks[li], ffn_mhc_blocks[li],
                              attn_norms[li], ffn_norms[li],
                              kv_caches[li], tid, positions.to(target_device))
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        return X

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}'{input_ids.tolist()}")
    generated = input_ids[0].tolist()
    if not generated:
        raise ValueError("Prompt tokenized to zero tokens")

    # ==== Prefill: process prompt tokens to fill KV cache ====
    # The last prompt token is left for decode step 0. Feeding it here as
    # well would append its K/V twice and shift every later cache index.
    prefill_tokens = generated[:-1]
    print(f"Prefilling {len(prefill_tokens)} prompt tokens...")
    for prefill_idx, tid_val in enumerate(prefill_tokens):
        run_token(tid_val, prefill_idx)
    print(f" Prefill done ({len(prefill_tokens)} tokens, {time.time()-t_compiled:.1f}s)")

    # ==== Decode: generate new tokens ====
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    all_tokens = generated.copy()
    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        X = run_token(all_tokens[-1], len(all_tokens) - 1)
        # Read out stream 0 → RMSNorm → lm_head
        x_out = X[:, 0, :]                              # (1, H)
        if final_norm_w is not None:
            xf = x_out.float()
            rms = xf.pow(2).mean(-1, keepdim=True).add(norm_eps).rsqrt()
            x_out = (xf * rms * final_norm_w.float()).bfloat16()
        logits = torch.nn.functional.linear(x_out, lm_w)
        next_id = torch.argmax(logits, dim=-1).item()
        generated.append(next_id)
        all_tokens.append(next_id)
        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        logits_f = logits.float()
        has_nan = torch.isnan(logits_f).any().item()
        has_inf = torch.isinf(logits_f).any().item()
        lmin, lmax = logits_f.min().item(), logits_f.max().item()
        x_max = X.abs().max().item()
        print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
              f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
              f"|X|={x_max:.3f}")
        if has_nan or has_inf:
            print(" Numerical issue — stopping")
            break
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(generated, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")
# Run the pipeline only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()