860 lines
35 KiB
Python
860 lines
35 KiB
Python
#!/usr/bin/env python3
|
||
"""Single-shot DSV4-Pro inference — Full 61-layer pipeline, 8-GPU.
|
||
|
||
This is a reference implementation that exercises the production kernel
|
||
stack end-to-end. It should be usable as ground truth when integrating
|
||
into vLLM or SGLang.
|
||
|
||
Architecture (paper §2):
|
||
X_l → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
|
||
X_mid → mHC.pre_block → RMSNorm → FFN(MoE) → F_ffn → mHC.post_block → X_{l+1}
|
||
|
||
Components exercised:
|
||
- mHC (Manifold-Constrained Hyper-Connections) — proper Sinkhorn-Knopp
|
||
- Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head)
|
||
- Partial RoPE (last 64 dims, GPT-J interleaved)
|
||
- Production FMHA kernel (6-warp multi-tile, C API + ctypes)
|
||
- Inverse RoPE on attention output (paper §2.3.3)
|
||
- Grouped output projection (wo_a BMM + wo_b NVFP4)
|
||
- Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp)
|
||
- Shared expert (NVFP4 gate/up/down)
|
||
- RMSNorm (pre-norm before each sub-block)
|
||
- KV cache across decode steps
|
||
|
||
Attention type simplification for this single-shot test:
|
||
For short sequences (seq_len ≤ sliding_window=128), ALL attention
|
||
types (CSA/HCA/SWA) reduce to dense attention over the full KV cache.
|
||
CSA's compressed branch and indexer are only needed for long sequences
|
||
where seq_len > sliding_window. HCA is dense over compressed entries,
|
||
but at short sequence lengths, the compressed sequence is trivially
|
||
small. So we use dense MQA attention over the full KV for all layers.
|
||
This is mathematically correct for short sequences and exercises the
|
||
FMHA kernel properly.
|
||
|
||
Usage (on B200):
|
||
source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||
cd /root/dsv4-nvfp4-workspace/kernel
|
||
python3 single_shot_inference.py
|
||
"""
|
||
import os, sys, time, json, math
|
||
import torch
|
||
from pathlib import Path
|
||
|
||
# =====================================================================
|
||
# Configuration
|
||
# =====================================================================
|
||
|
||
# Directory with config.json, tokenizer files and the NVFP4 safetensors shards.
CHECKPOINT_DIR: str = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
# Upper bound on greedily generated tokens after the prompt.
MAX_NEW_TOKENS: int = 10
# Text prompt that is tokenised and fed through the model.
PROMPT: str = "The capital of France is"
# Decoder layers are placed round-robin: layer li lives on cuda:(li % NUM_GPUS).
NUM_GPUS: int = 8
|
||
|
||
# =====================================================================
|
||
# NVFP4 dequantization — matches checkpoint format exactly
|
||
# =====================================================================
|
||
|
||
# Magnitudes of the 8 non-negative FP4 E2M1 codes, indexed by the low 3 bits
# of a nibble (bit 3 is the sign). NVFP4 stores elements as E2M1, whose
# representable magnitudes are exactly {0, 0.5, 1, 1.5, 2, 3, 4, 6}; the
# previous table ([0, 2, 3, 4, 6, 8, 12, 24]) is not E2M1 and corrupted every
# dequantised weight.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
||
|
||
def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Expand a packed NVFP4 weight matrix into a dense BF16 matrix.

    Each byte of ``weight`` holds two 4-bit codes. The low nibble is the
    even column and the high nibble is the odd column. Within a nibble,
    bit 3 is the sign and bits 0-2 index ``FP4_LUT``. Every run of 16
    consecutive elements in a row is multiplied by one entry of
    ``weight_scale``, and ``weight_scale_2`` is multiplied on top.

    Args:
        weight: (out_dim, in_dim // 2) packed codes (uint8 per the
            checkpoint format).
        weight_scale: (out_dim, in_dim // 16) per-block scales (E4M3 per the
            checkpoint format); upcast to float32 here.
        weight_scale_2: second-level scale that must broadcast against
            ``weight_scale``. It is documented as (out_dim, 1) float32, but a
            scalar also broadcasts (TODO confirm the checkpoint's shape).

    Returns:
        (out_dim, in_dim) BF16 tensor on ``weight.device``.
    """
    rows, packed_cols = weight.shape
    # (rows, packed_cols, 2): [low nibble, high nibble] for every byte.
    nibbles = torch.stack(
        [(weight & 0x0F).to(torch.int8), (weight >> 4).to(torch.int8)], dim=-1
    )
    table = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    magnitudes = table[(nibbles & 0x07).long()]
    signs = torch.where((nibbles >> 3).bool(), -1.0, 1.0)
    values = (magnitudes * signs).reshape(rows, packed_cols * 2)

    block_scales = weight_scale.float() * weight_scale_2.float()
    element_scales = block_scales.repeat_interleave(16, dim=1)
    return (values * element_scales).bfloat16()
|
||
|
||
|
||
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Return ``x @ W.T`` where W is the NVFP4 weight dequantised to BF16.

    The dense BF16 weight is rebuilt on every call (no fused FP4 GEMM). This
    is a simple reference path rather than a fast one.
    """
    dense_weight = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
    return torch.nn.functional.linear(x, dense_weight)
|
||
|
||
|
||
# =====================================================================
|
||
# RMSNorm — matches dsv4/layers/norm.py
|
||
# =====================================================================
|
||
|
||
class RMSNorm:
    """Root-mean-square normalisation with a per-channel float32 gain.

    The statistics are computed in float32 and the result is cast to BF16.
    ``weight`` starts as all ones; callers may replace it with checkpoint
    values.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)
        self.eps = eps

    def forward(self, x):
        """Normalise the last dimension of ``x`` (T, H); returns (T, H) BF16."""
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        scaled = x32 * inv_rms * self.weight
        return scaled.to(torch.bfloat16)
|
||
|
||
|
||
# =====================================================================
|
||
# mHC — proper Sinkhorn-Knopp implementation
|
||
# =====================================================================
|
||
|
||
class mHCBlock:
    """Manifold-Constrained Hyper-Connections (paper §2.2), reference version.

    Standalone counterpart of ``dsv4/layers/mhc.py``. It uses plain
    float32/BF16 matmuls and has no DeepGEMM dependency.

    Checkpoint tensors (see ``load_from_checkpoint``):
        fn:    (N_proj, K_proj) fused projection [W_pre; W_res; W_post]
        base:  (N_proj,) static biases [S_pre; S_res; S_post]
        scale: (3,) gating factors [alpha_pre, alpha_res, alpha_post]
    where K_proj = n_hc * hidden_dim and N_proj = n_hc + n_hc**2 + n_hc.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        self.d = hidden_dim
        self.n_hc = n_hc
        self.K_proj = n_hc * hidden_dim             # width of the flattened residual state
        self.N_proj = 2 * n_hc + n_hc * n_hc        # |pre| + |res| + |post|
        self.sinkhorn_iters = sinkhorn_iters
        self.device = device
        self.eps = 1e-6
        # Filled by load_from_checkpoint (or assigned directly by a caller).
        self.W_stacked = None   # (N_proj, K_proj) FP32
        self.S_pre = None       # (1, n_hc) FP32
        self.S_res = None       # (n_hc, n_hc) FP32
        self.S_post = None      # (n_hc, 1) FP32
        self.alpha_pre = None   # python float
        self.alpha_res = None   # python float
        self.alpha_post = None  # python float

    def load_from_checkpoint(self, fn, base, scale):
        """Install the fused projection, biases and gates on ``self.device``.

        fn: (N_proj, K_proj); base: (N_proj,); scale: up to 3 entries
        [alpha_pre, alpha_res, alpha_post]. Any missing gate entry defaults
        to 0.01.
        """
        n = self.n_hc
        res_end = n + n * n
        f32 = {"device": self.device, "dtype": torch.float32}
        self.W_stacked = fn.to(**f32).contiguous()
        self.S_pre = base[:n].reshape(1, n).to(**f32)
        self.S_res = base[n:res_end].reshape(n, n).to(**f32)
        self.S_post = base[res_end:].reshape(n, 1).to(**f32)
        available = scale.numel()
        self.alpha_pre, self.alpha_res, self.alpha_post = (
            scale[i].item() if available > i else 0.01 for i in range(3)
        )

    def _project_and_rms(self, X_flat):
        """Return ``RMSNorm(X_flat) @ W_stacked.T`` as (T, N_proj) BF16.

        The matmul runs on the raw input, and each row is then multiplied by
        sqrt(K_proj / (sum(x^2) + eps)), which is 1/RMS of that row.
        X_flat: (T, K_proj) BF16.
        """
        x = X_flat.float()
        projected = torch.matmul(x, self.W_stacked.t())
        inv_rms = torch.sqrt(self.K_proj / (x.pow(2).sum(dim=-1) + self.eps))
        return (projected * inv_rms.unsqueeze(-1)).bfloat16()

    def _dynamic_params(self, X_l):
        """Compute the per-token mixing parameters from the residual state.

        X_l: (T, n_hc, d) BF16.
        Returns A_l (T, n_hc) BF16 in (0, 1), B_l (T, n_hc, n_hc) FP32
        (Sinkhorn-normalised), and C_l (T, n_hc) BF16 in (0, 2).
        """
        T, n, _ = X_l.shape
        proj = self._project_and_rms(X_l.reshape(T, self.K_proj)).float()
        res_end = n + n * n

        # Gate the dynamic part, then add the static biases (paper eqs. 3-5).
        A_tilde = self.alpha_pre * proj[:, :n] + self.S_pre
        B_tilde = self.alpha_res * proj[:, n:res_end] + self.S_res.reshape(1, -1)
        C_tilde = self.alpha_post * proj[:, res_end:self.N_proj] + self.S_post.reshape(1, -1)

        A_l = torch.sigmoid(A_tilde)
        C_l = torch.sigmoid(C_tilde) * 2.0
        # Positivity via exp, then project toward the doubly stochastic set.
        B_l = self._sinkhorn_knopp(B_tilde.exp().reshape(T, n, n))
        return A_l.bfloat16(), B_l, C_l.bfloat16()

    def _sinkhorn_knopp(self, M, t_max=None):
        """Alternately normalise rows and columns of (T, n, n) positive M.

        Runs ``t_max`` iterations (default ``self.sinkhorn_iters``). Each
        iteration ends with the column step, so columns sum to ~1 exactly and
        rows converge to ~1.
        """
        iterations = self.sinkhorn_iters if t_max is None else t_max
        for _ in range(iterations):
            M = M / (M.sum(dim=-1, keepdim=True) + self.eps)   # rows
            M = M / (M.sum(dim=-2, keepdim=True) + self.eps)   # columns
        return M

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Replicate (T, d) token embeddings into (T, n_hc, d) residual streams."""
        return embeddings.unsqueeze(1).expand(-1, n_hc, -1).clone()

    def pre_block(self, X_l):
        """Mix the residual streams into the sub-layer input.

        X_l: (T, n_hc, d) BF16.
        Returns x_in = A_l @ X_l as (T, d), and ctx = (B_l, C_l) for post_block.
        """
        A_l, B_l, C_l = self._dynamic_params(X_l)
        mixed = torch.bmm(A_l.unsqueeze(1), X_l)   # (T, 1, d)
        return mixed.squeeze(1), (B_l, C_l)

    def post_block(self, X_l, F_out, ctx):
        """Residual update (paper eq. 1): X_{l+1} = B_l @ X_l + C_l ⊗ F_out.

        X_l: (T, n_hc, d) BF16 state before the sub-layer.
        F_out: (T, d) BF16 sub-layer output.
        ctx: (B_l, C_l) as returned by pre_block.
        Returns (T, n_hc, d) BF16.
        """
        B_l, C_l = ctx
        carried = torch.bmm(B_l, X_l.float())               # FP32 stream mixing
        injected = C_l.unsqueeze(-1) * F_out.unsqueeze(1)   # BF16 outer product
        return (carried + injected.float()).bfloat16()
|
||
|
||
|
||
# =====================================================================
|
||
# RoPE — partial, GPT-J interleaved, last rope_dim dims
|
||
# =====================================================================
|
||
|
||
def build_rope_cache(max_pos, rope_dim, device, theta=10000.0):
    """Precompute cosine/sine tables for partial rotary embeddings.

    Pair i (0 <= i < rope_dim // 2) rotates with frequency
    theta ** (-2i / rope_dim). Angles are computed in float32 on the CPU,
    rounded to BF16 and then moved to ``device``.

    Returns:
        (cos_cache, sin_cache), each (max_pos, rope_dim // 2) BF16.
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    steps = torch.arange(max_pos, dtype=torch.float32)
    angle_table = torch.outer(steps, inv_freq)
    cos_cache = torch.cos(angle_table).bfloat16().to(device)
    sin_cache = torch.sin(angle_table).bfloat16().to(device)
    return cos_cache, sin_cache
|
||
|
||
|
||
def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Rotate the trailing ``rope_dim`` channels of every head (GPT-J pairing).

    Channels (2i, 2i+1) inside the rotary slice are rotated by the angle at
    ``cos_cache[pos, i]`` / ``sin_cache[pos, i]``. The leading
    ``hd - rope_dim`` channels are copied unchanged. ``head_dim`` is unused;
    the head size is read from ``x``.

    Args:
        x: (T, n_h, hd) BF16 per-head tensor.
        positions: (T,) int64 indices into the caches.

    Returns:
        A new (T, n_h, hd) tensor; ``x`` itself is not modified.
    """
    _, _, width = x.shape
    start = width - rope_dim
    c = cos_cache[positions].unsqueeze(1)   # (T, 1, rope_dim // 2)
    s = sin_cache[positions].unsqueeze(1)

    rotary = x[:, :, start:]
    first, second = rotary[..., ::2], rotary[..., 1::2]

    result = x.clone()
    tail = result[:, :, start:]   # view into result
    tail[..., ::2] = first * c - second * s
    tail[..., 1::2] = first * s + second * c
    return result
|
||
|
||
|
||
def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Undo partial GPT-J RoPE on the attention output (paper §2.3.3).

    Applies the conjugate rotation (angle -θ) to each (2i, 2i+1) pair of the
    trailing ``rope_dim`` channels, using the angles at ``positions``. The
    leading channels are copied unchanged. Per the paper, this removes the
    positional rotation carried into the output by the RoPE'd values.
    ``head_dim`` is unused; the head size is read from ``o``.

    Args:
        o: (T, n_h, hd) BF16 attention output.
        positions: (T,) int64 query positions.

    Returns:
        A new (T, n_h, hd) tensor; ``o`` itself is not modified.
    """
    _, _, width = o.shape
    offset = width - rope_dim
    cos_t = cos_cache[positions].unsqueeze(1)   # (T, 1, rope_dim // 2)
    sin_t = sin_cache[positions].unsqueeze(1)

    rotated = o[:, :, offset:]
    re, im = rotated[..., ::2], rotated[..., 1::2]

    restored = o.clone()
    window = restored[:, :, offset:]   # view into restored
    window[..., ::2] = re * cos_t + im * sin_t
    window[..., 1::2] = -re * sin_t + im * cos_t
    return restored
|
||
|
||
|
||
# =====================================================================
|
||
# KV cache — BF16 flat, MQA (1 KV head)
|
||
# =====================================================================
|
||
|
||
class SimpleKVCache:
    """Append-only per-layer K/V cache for MQA decode.

    K and V are each held in a preallocated (1, max_seq, head_dim) BF16
    buffer (a single KV head). ``len`` is the number of filled positions.
    """

    def __init__(self, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.len = 0

    def append(self, k_new, v_new):
        """Append T new positions to the cache.

        Args:
            k_new: (1, T, hd) keys.
            v_new: (1, T, hd) values.

        Raises:
            RuntimeError: if K and V disagree in length, or if the append
                would exceed ``max_seq``. Both conditions are checked before
                anything is written, so the cache is never half-updated.
        """
        T = k_new.shape[1]
        if v_new.shape[1] != T:
            raise RuntimeError(
                f"KV cache append: K has {T} positions but V has {v_new.shape[1]}"
            )
        end = self.len + T
        if end > self.max_seq:
            raise RuntimeError(
                f"KV cache capacity exceeded: {self.len} + {T} > max_seq={self.max_seq}"
            )
        self.k[0, self.len:end] = k_new[0]
        self.v[0, self.len:end] = v_new[0]
        self.len = end

    def get(self):
        """Return views of the filled K and V, each (1, len, hd)."""
        return self.k[:, :self.len], self.v[:, :self.len]
|
||
|
||
|
||
# =====================================================================
|
||
# Weight loading — streams safetensors shards, distributes to 8 GPUs
|
||
# =====================================================================
|
||
|
||
def load_all_weights(checkpoint_dir, num_layers):
    """Stream safetensors shards from ``checkpoint_dir`` onto the GPUs.

    Decoder-layer tensors (``model.layers.{li}.*``) go to
    ``cuda:{li % NUM_GPUS}``. Keys starting with ``model.embed_tokens``,
    ``model.norm`` or ``lm_head`` go to ``cuda:0``. All other keys, and
    layers with ``li >= num_layers``, are dropped, and the dropped count is
    reported.

    Each shard is moved to the GPUs and its host copy released before the
    next shard is read. Host memory therefore peaks at about one shard
    instead of the whole checkpoint.

    Args:
        checkpoint_dir: directory holding the shards and, optionally,
            ``model.safetensors.index.json``.
        num_layers: number of decoder layers to keep.

    Returns:
        layer_weights: dict[li] -> {key: tensor on cuda:li%NUM_GPUS}, plus
            the bookkeeping entries "_device" and "_gpu".
        global_weights: {key: tensor on cuda:0}
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})
    if weight_map:
        shard_names = set(weight_map.values())
    else:
        # No index file: fall back to this checkpoint's 95-shard naming.
        shard_names = {f"model-{i:05d}-of-00095.safetensors" for i in range(1, 96)}

    global_prefixes = ("model.embed_tokens", "model.norm", "lm_head")
    layer_weights = {}
    global_weights = {}
    missing = []
    skipped = 0
    n_tensors = 0
    loaded = 0

    print(f"Loading {len(shard_names)} shards...")
    for shard_name in sorted(shard_names):
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue
        shard = load_file(str(shard_path))
        for key, tensor in shard.items():
            n_tensors += 1
            if key.startswith("model.layers."):
                li = int(key.split(".")[2])
                if li >= num_layers:
                    skipped += 1
                    continue
                gpu = li % NUM_GPUS
                device = f"cuda:{gpu}"
                bucket = layer_weights.setdefault(li, {"_device": device, "_gpu": gpu})
                bucket[key] = tensor.to(device)
            elif key.startswith(global_prefixes):
                global_weights[key] = tensor.to("cuda:0")
            else:
                skipped += 1
        del shard  # drop the host copy before reading the next shard
        loaded += 1
        if loaded % 20 == 0:
            print(f"  {loaded}/{len(shard_names)} shards, {n_tensors} tensors")
    print(f"  Done: {n_tensors} tensors from {loaded} shards")
    if missing:
        print(f"  WARNING: {len(missing)} shard(s) not found, e.g. {missing[0]}")
    if skipped:
        print(f"  Skipped {skipped} tensors (beyond {num_layers} layers or unused)")

    for gpu in range(NUM_GPUS):
        alloc = torch.cuda.memory_allocated(gpu) / 1e9
        if alloc > 0:
            print(f"  GPU {gpu}: {alloc:.1f}GB")
    return layer_weights, global_weights
|
||
|
||
|
||
# =====================================================================
|
||
# Single layer forward
|
||
# =====================================================================
|
||
|
||
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                  kv_cache, token_id, positions):
    """Forward one decoder layer: mHC + attention, then mHC + MoE.

    Architecture (paper §2):
      X_l   → mHC.pre_block(attn) → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
      X_mid → mHC.pre_block(ffn)  → RMSNorm → MoE       → F_ffn  → mHC.post_block → X_{l+1}

    Args:
        X_l: (T, n_hc, H) BF16 mHC residual state on this layer's device.
        w: this layer's weight dict.
        rope_cos, rope_sin: RoPE tables on this layer's device.
        kv_cache: this layer's cache; it receives the T new RoPE'd KV rows.
        token_id: current token id(s), used by hash-routed MoE layers.
        positions: (T,) int64 positions. They are moved to the RoPE tables'
            device, because indexing a CUDA tensor with indices on another
            CUDA device raises.

    Returns:
        X_next: (T, n_hc, H) BF16.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    group_input_dim = (n_h // o_groups) * hd
    positions = positions.to(rope_cos.device)

    def proj(x, name):
        """NVFP4 linear using the ``{pre}.{name}`` weight triple."""
        return nvfp4_linear(x,
                            w[f"{pre}.{name}.weight"],
                            w[f"{pre}.{name}.weight_scale"],
                            w[f"{pre}.{name}.weight_scale_2"])

    # ---------------- Attention sub-block ----------------
    x_in, attn_ctx = attn_mhc.pre_block(X_l)   # (T, H)
    x_normed = attn_norm.forward(x_in)         # (T, H) BF16

    q = proj(proj(x_normed, "q_a_proj"), "q_b_proj")   # low-rank Q: (T, n_h * hd)
    kv = proj(x_normed, "kv_proj")                      # MQA, 1 KV head: (T, hd)

    q_heads = apply_rope_partial(q.reshape(T, n_h, hd), positions, rope_cos, rope_sin, hd, rd)
    # DSV4 convention: RoPE is applied to KV before caching, and K = V.
    kv_new = apply_rope_partial(kv.reshape(T, 1, hd), positions, rope_cos, rope_sin, hd, rd)
    kv_rows = kv_new.permute(1, 0, 2)                   # (1, T, hd)
    kv_cache.append(kv_rows, kv_rows)

    # The cache already returns (kv_heads=1, seq_len, hd), which is the layout
    # the kernel takes. Permuting it would give (seq_len, 1, hd); that only
    # coincides with the right layout when seq_len == 1.
    k_full, v_full = kv_cache.get()
    attn_out = dsv4_attention(q_heads.permute(1, 0, 2), k_full, v_full)   # (n_h, T, hd)
    attn_out = attn_out.permute(1, 0, 2)                                    # (T, n_h, hd)

    # Remove the positional rotation carried by the RoPE'd V (paper §2.3.3).
    attn_out = apply_inverse_rope(attn_out, positions, rope_cos, rope_sin, hd, rd)

    # Output projection: grouped wo_a (BMM over head groups), then NVFP4 wo_b.
    attn_grouped = attn_out.reshape(T, o_groups, group_input_dim).permute(1, 0, 2)   # (G, T, gdim)
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped, oa_3d.transpose(1, 2))                      # (G, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = proj(grouped_flat, "o_b_proj")                                           # (T, H)

    X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)   # (T, n_hc, H)

    # ---------------- FFN sub-block ----------------
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)
    x_ffn_normed = ffn_norm.forward(x_ffn)
    F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device)
    return ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)     # (T, n_hc, H)
|
||
|
||
|
||
# =====================================================================
|
||
# MoE forward — hash + dense routing, SwiGLU with clamping
|
||
# =====================================================================
|
||
|
||
def _swiglu_mlp(x, w, prefix, swiglu_limit):
    """Run one NVFP4 SwiGLU MLP (gate/up/down weights under ``prefix``).

    Clamping (paper §4.2.3): when ``swiglu_limit`` is not None, both SiLU(gate)
    and up are clamped to [-limit, limit] before they are multiplied.
    """
    def proj(inp, name):
        return nvfp4_linear(inp,
                            w[f"{prefix}.{name}.weight"],
                            w[f"{prefix}.{name}.weight_scale"],
                            w[f"{prefix}.{name}.weight_scale_2"])

    gate_act = torch.nn.functional.silu(proj(x, "gate_proj").float())
    up = proj(x, "up_proj").float()
    if swiglu_limit is not None:
        gate_act = gate_act.clamp(-swiglu_limit, swiglu_limit)
        up = up.clamp(-swiglu_limit, swiglu_limit)
    return proj((gate_act * up).bfloat16(), "down_proj")


def moe_forward(x, w, li, cfg, token_id, device):
    """Run the routed MoE plus the shared expert for a single token.

    Routing:
      * Layers ``li < 3`` with a ``tid2eid`` table use hash routing: a fixed
        token-to-experts lookup with uniform 1/top_k weights.
      * All other layers use the dense router. Affinity is
        sqrt(softplus(x @ W_gate^T) + 1e-6). Experts are selected by top-k on
        affinity + e_bias (when present). Weights are the unbiased affinities
        of the chosen experts, renormalised to sum to 1.

    Expert outputs are accumulated in float32, scaled by
    ``routed_scaling_factor``, added to the shared-expert output and rounded
    to BF16 once. This avoids a BF16 rounding after every expert.

    Args:
        x: (T, H) BF16 post-RMSNorm FFN input; only T == 1 is supported.
        w: this layer's weight dict.
        li: layer index.
        cfg: model config dict.
        token_id: tensor holding the current token id (hash routing).
        device: unused; kept for call compatibility.

    Returns:
        (T, H) BF16.

    Raises:
        NotImplementedError: if T != 1. Previously the hash path silently
            routed every token by the first token's id.
    """
    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)

    if x.shape[0] != 1:
        raise NotImplementedError("Multi-token MoE routing")

    # ---- Routing ----
    tid2eid_key = f"model.layers.{li}.mlp.gate.tid2eid"
    if li < 3 and tid2eid_key in w:
        tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
        expert_ids = w[tid2eid_key][tid]   # (top_k,)
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
    else:
        # Dense routing (also the fallback for hash layers without a table).
        gate_w = w[f"model.layers.{li}.mlp.gate.weight"]
        logits = torch.nn.functional.linear(x, gate_w.bfloat16()).float()   # (1, n_experts)
        affinity = torch.sqrt(torch.nn.functional.softplus(logits) + 1e-6)
        selection = affinity
        e_bias = w.get(f"model.layers.{li}.mlp.gate.e_bias")
        if e_bias is not None:
            # The bias only affects which experts are selected, not their weights.
            selection = affinity + e_bias.float().unsqueeze(0)
        _, indices = selection.topk(top_k, dim=-1)
        chosen = torch.gather(affinity, -1, indices)
        weights = chosen / chosen.sum(dim=-1, keepdim=True)
        expert_ids, expert_weights = indices[0], weights[0]

    # ---- Routed experts (FP32 accumulation) ----
    routed = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
    for eid, wt in zip(expert_ids.tolist(), expert_weights):
        out = _swiglu_mlp(x, w, f"model.layers.{li}.mlp.experts.{eid}", swiglu_limit)
        routed = routed + out.float() * wt
    total = routed * routed_scaling

    # ---- Shared expert ----
    se_pre = f"model.layers.{li}.mlp.shared_experts"
    if f"{se_pre}.gate_proj.weight" in w:
        total = total + _swiglu_mlp(x, w, se_pre, swiglu_limit).float()

    return total.bfloat16()
|
||
|
||
|
||
# =====================================================================
|
||
# Main
|
||
# =====================================================================
|
||
|
||
def main():
    """Greedy single-shot generation for ``PROMPT`` across NUM_GPUS GPUs.

    Phase 1 loads weights and builds the per-layer mHC blocks, RMSNorms,
    RoPE tables and KV caches. Phase 2 warms up the FMHA kernel on cuda:0.
    Phase 3 prefills every prompt token except the last into the KV caches,
    then decodes up to MAX_NEW_TOKENS tokens. Decode step 0 consumes the last
    prompt token. Generation stops early on NaN/Inf logits or EOS.

    Prefill previously included the last prompt token, and decode step 0
    then appended it a second time. That left a duplicate entry in every
    layer's KV cache.
    """
    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)")
    print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    norm_eps = cfg.get("rms_norm_eps", 1e-6)
    rope_theta = cfg.get("rope_theta", 10000.0)
    n_hc = 4
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights ====
    print(f"\n{'='*70}\nPhase 1: Loading weights\n{'='*70}")
    layer_weights, global_weights = load_all_weights(CHECKPOINT_DIR, n_layers)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Per-layer mHC blocks and RMSNorms ====
    print("Building mHC blocks...")
    attn_mhc_blocks, ffn_mhc_blocks = {}, {}
    attn_norms, ffn_norms = {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        lw = layer_weights[li]

        for prefix, blocks in ((f"model.layers.{li}.attn_hc", attn_mhc_blocks),
                               (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)):
            fn = lw.get(f"{prefix}.fn")
            base = lw.get(f"{prefix}.base")
            scale = lw.get(f"{prefix}.scale")
            mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
            if fn is not None and base is not None and scale is not None:
                mhc.load_from_checkpoint(fn, base, scale)
            else:
                # Fallback: fixed, input-independent mixing (W = 0). This gives
                # A = sigmoid(0) = 0.5 per stream, C = 2*sigmoid(0.5) ~= 1.25 and
                # B = Sinkhorn(exp(I)). It is NOT an identity map; it only keeps
                # the run alive so the missing weights can be diagnosed.
                print(f"  WARNING: no mHC weights for {prefix}, using fixed fallback")
                n = n_hc
                mhc.W_stacked = torch.zeros(2 * n + n * n, n * H, dtype=torch.float32, device=dev)
                mhc.S_pre = torch.zeros(1, n, dtype=torch.float32, device=dev)
                mhc.S_res = torch.eye(n, dtype=torch.float32, device=dev)
                mhc.S_post = torch.full((n, 1), 0.5, dtype=torch.float32, device=dev)
                mhc.alpha_pre = mhc.alpha_res = mhc.alpha_post = 0.01
            blocks[li] = mhc

        # Pre-norms before each sub-block; they keep unit gain if the weight is absent.
        for store, key in ((attn_norms, f"model.layers.{li}.input_layernorm.weight"),
                           (ffn_norms, f"model.layers.{li}.post_attention_layernorm.weight")):
            norm = RMSNorm(H, eps=norm_eps, device=dev)
            if key in lw:
                norm.weight = lw[key].to(device=dev, dtype=torch.float32)
            store[li] = norm
    print(f"  attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights, RoPE tables, KV caches ====
    torch.cuda.set_device(0)
    embed_w = global_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16())
    lm_w = global_weights.get("lm_head.weight", embed_w).bfloat16()
    final_norm = None
    final_norm_w = global_weights.get("model.norm.weight")
    if final_norm_w is not None:
        final_norm = RMSNorm(H, eps=norm_eps, device="cuda:0")
        final_norm.weight = final_norm_w.to(device="cuda:0", dtype=torch.float32)
    rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}", theta=rope_theta)
                   for g in range(NUM_GPUS)}
    kv_caches = {li: SimpleKVCache(head_dim=hd, max_seq=8192, device=f"cuda:{li % NUM_GPUS}")
                 for li in range(n_layers)}

    # ==== Phase 2: Compile FMHA ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    from dsv4.kernels.attention.production import dsv4_attention
    torch.cuda.set_device(0)
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    try:
        _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print(" FMHA: compiled OK")
    except Exception as e:  # best-effort warm-up; a real failure resurfaces in inference
        print(f" FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_loaded:.1f}s")

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
    print(f"Prompt: '{PROMPT}' → {input_ids.tolist()}")
    all_tokens = input_ids[0].tolist()
    if not all_tokens:
        raise ValueError("Prompt tokenised to zero tokens; nothing to condition on")

    def run_layers(tid, positions):
        """Push one token through every decoder layer; return its (1, n_hc, H) state on cuda:0."""
        X = mHCBlock.init_state(embed(tid), n_hc)   # (1, n_hc, H) on cuda:0
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            target = torch.device(f"cuda:{gpu}")
            if X.device != target:
                X = X.to(target)
            torch.cuda.set_device(gpu)
            rc, rs = rope_caches[gpu]
            X = forward_layer(X, layer_weights[li], li, cfg, rc, rs,
                              attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                              attn_norms[li], ffn_norms[li],
                              kv_caches[li], tid, positions.to(target))
        torch.cuda.set_device(0)
        return X.to('cuda:0')

    # ---- Prefill: every prompt token except the last, which decode step 0 consumes ----
    prefill_ids = all_tokens[:-1]
    print(f"Prefilling {len(prefill_ids)} prompt tokens...")
    for pos, tid_val in enumerate(prefill_ids):
        tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([pos], dtype=torch.long, device='cuda:0')
        run_layers(tid, positions)
    print(f"  Prefill done ({len(prefill_ids)} tokens, {time.time()-t_compiled:.1f}s)")

    # ---- Decode ----
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([len(all_tokens) - 1], dtype=torch.long, device='cuda:0')
        X = run_layers(tid, positions)

        # Read out stream 0 → final RMSNorm → lm_head.
        x_out = X[:, 0, :]   # (1, H)
        if final_norm is not None:
            x_out = final_norm.forward(x_out)
        logits = torch.nn.functional.linear(x_out, lm_w)
        next_id = torch.argmax(logits, dim=-1).item()
        all_tokens.append(next_id)

        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        logits_f = logits.float()
        has_nan = torch.isnan(logits_f).any().item()
        has_inf = torch.isinf(logits_f).any().item()
        lmin, lmax = logits_f.min().item(), logits_f.max().item()
        x_max = X.abs().max().item()
        print(f"  Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
              f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
              f"|X|={x_max:.3f}")

        if has_nan or has_inf:
            print("  Numerical issue — stopping")
            break
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(all_tokens, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: load → compile → prefill/decode pipeline.
    main()
|