806 lines
34 KiB
Python
806 lines
34 KiB
Python
#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — full 61-layer pipeline on 8 GPUs.

This is a reference implementation that exercises the production kernel
stack end-to-end. It is intended to serve as ground truth when integrating
into vLLM or SGLang.

Architecture (paper §2):
    X_l   → mHC.pre_block → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
    X_mid → mHC.pre_block → RMSNorm → FFN (MoE) → F_ffn  → mHC.post_block → X_{l+1}

Components exercised:
  - mHC (Manifold-Constrained Hyper-Connections) with proper Sinkhorn-Knopp
  - Low-rank Q projection (q_a → q_b) + KV projection (MQA, 1 KV head)
  - Partial RoPE (last 64 dims, GPT-J interleaved)
  - Production FMHA kernel (6-warp multi-tile, C API + ctypes)
  - Inverse RoPE on the attention output (paper §2.3.3)
  - Grouped output projection (wo_a BMM + wo_b NVFP4)
  - Routed MoE (384 experts, top-6, hash + dense routing, SwiGLU clamp)
  - Shared expert (NVFP4 gate/up/down)
  - RMSNorm (pre-norm before each sub-block)
  - KV cache across decode steps

Attention-type simplification for this single-shot test:
  For short sequences (seq_len ≤ sliding_window = 128), all attention types
  (CSA/HCA/SWA) reduce to dense attention over the full KV cache. CSA's
  compressed branch and indexer are only needed when seq_len > sliding_window.
  HCA is dense over compressed entries, and at short sequence lengths the
  compressed sequence is trivially small. We therefore use dense MQA attention
  over the full KV cache for every layer. This is mathematically correct for
  short sequences and still exercises the FMHA kernel properly.

Usage (on B200):
    source /root/dsv4-nvfp4-workspace/venv/bin/activate
    cd /root/dsv4-nvfp4-workspace/kernel
    python3 single_shot_inference.py
"""
import os, sys, time, json, math
|
||
import torch
|
||
from pathlib import Path
|
||
|
||
# =====================================================================
# Configuration
# =====================================================================
# Every setting can be overridden from the environment, so the script can
# target another checkpoint, prompt, or host without editing this file.
# The defaults below are the values this reference run was built around.

CHECKPOINT_DIR = os.environ.get(
    "DSV4_CHECKPOINT_DIR", "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
MAX_NEW_TOKENS = int(os.environ.get("DSV4_MAX_NEW_TOKENS", "10"))
PROMPT = os.environ.get("DSV4_PROMPT", "The capital of France is")
# Layer li is placed on cuda:{li % NUM_GPUS}.
NUM_GPUS = int(os.environ.get("DSV4_NUM_GPUS", "8"))
# =====================================================================
# NVFP4 dequantization — matches checkpoint format exactly
# =====================================================================

FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # E2M1 magnitudes


def dequant_nvfp4_weight(weight, weight_scale, weight_scale_2):
    """Dequantize a packed NVFP4 weight to BF16.

    Each byte of ``weight`` holds two 4-bit E2M1 codes. The low nibble is
    the element at the even column and the high nibble the element at the
    following odd column. Within a code, bit 3 is the sign and bits 0-2
    index ``FP4_LUT``.

    Args:
        weight: (out_dim, in_dim // 2) uint8, 2 FP4 values per byte.
        weight_scale: (out_dim, in_dim // 16) block scale (E4M3 in the
            checkpoint), one per 16 consecutive input elements.
        weight_scale_2: global scale, broadcastable against
            ``weight_scale`` (e.g. (out_dim, 1) or a scalar).

    Returns:
        (out_dim, in_dim) BF16 tensor:
        code_value * weight_scale * weight_scale_2.

    Raises:
        ValueError: if ``weight_scale`` does not have exactly one column per
            16 unpacked elements.
    """
    out_dim, in_packed = weight.shape
    in_features = in_packed * 2
    if weight_scale.shape[1] * 16 != in_features:
        raise ValueError(
            f"weight_scale has {weight_scale.shape[1]} columns, expected "
            f"{in_features // 16} for {in_features} unpacked input features")

    magnitudes = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    # 16-entry table indexed directly by the 4-bit code. Codes 8-15 are the
    # negated magnitudes, so code 8 decodes to -0.0 exactly as the
    # sign/magnitude form does.
    signed_lut = torch.cat([magnitudes, -magnitudes])

    # The trailing stack axis interleaves (low, high) so that the reshape
    # puts each byte's low nibble before its high nibble.
    codes = torch.stack([weight & 0x0F, (weight >> 4) & 0x0F], dim=-1).long()
    values = signed_lut[codes].reshape(out_dim, in_features)

    scale = weight_scale.float() * weight_scale_2.float()
    return (values * scale.repeat_interleave(16, dim=1)).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2):
    """Apply a bias-free linear layer whose weight is stored as packed NVFP4.

    The weight is expanded to BF16 on every call (nothing is cached). It is
    then used in an ordinary ``torch.nn.functional.linear``, i.e. ``x @ W.T``.
    """
    linear = torch.nn.functional.linear
    dense = dequant_nvfp4_weight(weight, weight_scale, weight_scale_2)
    return linear(x, dense)
# =====================================================================
# RMSNorm — matches dsv4/layers/norm.py
# =====================================================================

class RMSNorm:
    """Root-mean-square norm with a per-channel FP32 gain.

    The statistics and scaling are computed in float32 and the result is
    returned as BF16. ``weight`` starts at ones. Callers may replace it with
    a loaded tensor.
    """

    def __init__(self, hidden_size, eps=1e-6, device='cuda:0'):
        self.weight = torch.ones(hidden_size, dtype=torch.float32, device=device)
        self.eps = eps

    def forward(self, x):
        """Normalize over the last dim: x (T, H) BF16 → (T, H) BF16."""
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normalized = x32 * inv_rms
        return (normalized * self.weight).to(torch.bfloat16)
# =====================================================================
# mHC — proper Sinkhorn-Knopp implementation
# =====================================================================

class mHCBlock:
    """Thin adapter around ``dsv4.layers.mhc.mHCLayer`` for single-shot inference.

    The numerics, including Sinkhorn-Knopp, live in the production
    ``mHCLayer``. This class only maps checkpoint tensors onto its
    ``load_weights`` keyword arguments and forwards pre/post-block calls.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, sinkhorn_iters=20, device='cuda:0'):
        from dsv4.layers.mhc import mHCLayer
        self._impl = mHCLayer(
            hidden_dim=hidden_dim, n_hc=n_hc,
            t_max_sinkhorn=sinkhorn_iters,
            device=device, dtype=torch.bfloat16)
        self.device = device
        self.n_hc = n_hc
        self.hidden_dim = hidden_dim

    def load_from_checkpoint(self, fn, base, scale):
        """Load the fused checkpoint tensors into the wrapped mHCLayer.

        Args:
            fn: (2*n_hc + n_hc**2, K) fused projection, e.g. (24, 28672) for
                n_hc=4. Rows are [W_pre (n_hc), W_post (n_hc), W_res (n_hc**2)].
            base: 2*n_hc + n_hc**2 entries in any shape:
                [S_pre (n_hc), S_post (n_hc), S_res (n_hc**2, row-major)].
            scale: 3 entries [alpha_pre, alpha_post, alpha_res].

        Raises:
            ValueError: if a tensor has the wrong row/entry count. Without
                this check the fused layout would be sliced silently and
                incorrectly (e.g. ``fn[2*n:]`` takes every remaining row).
        """
        n = self.n_hc
        dev = self.device
        n_rows = 2 * n + n * n

        if fn.dim() != 2 or fn.shape[0] != n_rows:
            raise ValueError(
                f"mHC fn must be 2-D with {n_rows} rows, got shape {tuple(fn.shape)}")
        if base.numel() != n_rows:
            raise ValueError(
                f"mHC base must have {n_rows} entries, got {base.numel()}")
        if scale.numel() != 3:
            raise ValueError(
                f"mHC scale must have 3 entries, got {scale.numel()}")
        base = base.reshape(-1)
        scale = scale.reshape(-1)

        # fn rows: [W_pre(n), W_post(n), W_res(n*n)], passed on as FP32.
        W_pre = fn[0:n].to(device=dev, dtype=torch.float32).contiguous()
        W_post = fn[n:2 * n].to(device=dev, dtype=torch.float32).contiguous()
        W_res = fn[2 * n:].to(device=dev, dtype=torch.float32).contiguous()

        # base: [S_pre(n), S_post(n), S_res(n*n)], passed on as BF16 in
        # row (1, n), column (n, 1) and matrix (n, n) layouts.
        S_pre = base[0:n].reshape(1, n).to(device=dev, dtype=torch.bfloat16).contiguous()
        S_post = base[n:2 * n].reshape(n, 1).to(device=dev, dtype=torch.bfloat16).contiguous()
        S_res = base[2 * n:].reshape(n, n).to(device=dev, dtype=torch.bfloat16).contiguous()

        # scale: [alpha_pre, alpha_post, alpha_res] as Python floats.
        alpha_pre = scale[0].item()
        alpha_post = scale[1].item()
        alpha_res = scale[2].item()

        self._impl.load_weights(
            W_pre=W_pre, W_res=W_res, W_post=W_post,
            S_pre=S_pre, S_res=S_res, S_post=S_post,
            alpha_pre=alpha_pre, alpha_res=alpha_res, alpha_post=alpha_post)

    @staticmethod
    def init_state(embeddings, n_hc=4):
        """Build the initial mHC residual state from embeddings (delegates to mHCLayer)."""
        from dsv4.layers.mhc import mHCLayer
        return mHCLayer.init_state(embeddings, n_hc)

    def pre_block(self, X_l):
        """Forward to mHCLayer.pre_block; returns (block_input, ctx)."""
        return self._impl.pre_block(X_l)

    def post_block(self, X_l, F_out, ctx):
        """Forward to mHCLayer.post_block, combining X_l with the sub-block output."""
        return self._impl.post_block(X_l, F_out, ctx)
# =====================================================================
# RoPE — partial, GPT-J interleaved, last rope_dim dims
# =====================================================================

def build_rope_cache(max_pos, rope_dim, device, theta=10000.0, dtype=torch.bfloat16):
    """Precompute cos/sin tables for partial RoPE.

    For position p and pair index i, the angle is
    p * theta ** (-2i / rope_dim), computed in float32 and then cast.

    Args:
        max_pos: number of positions (rows) in each table.
        rope_dim: number of rotated dims; must be even and non-negative.
        device: target device for the tables.
        theta: RoPE base frequency.
        dtype: table dtype (default BF16, as before).

    Returns:
        (cos_cache, sin_cache), each (max_pos, rope_dim // 2) in ``dtype``.

    Raises:
        ValueError: if ``rope_dim`` is odd or negative.
    """
    if rope_dim < 0 or rope_dim % 2:
        raise ValueError(f"rope_dim must be a non-negative even integer, got {rope_dim}")
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    inv_freq = 1.0 / (theta ** exponents)
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    return angles.cos().to(dtype).to(device), angles.sin().to(dtype).to(device)
def _rotate_rope_pairs(x, positions, cos_cache, sin_cache, rope_dim, inverse):
    """Rotate the interleaved (even, odd) pairs in the last ``rope_dim`` dims of ``x``.

    x: (T, n_h, hd). cos_cache / sin_cache: (max_pos, rope_dim // 2), indexed
    by ``positions`` (T,). The leading hd - rope_dim ("nope") dims are copied
    unchanged. The rotation is computed in float32 and cast back to x.dtype,
    which avoids the extra BF16 rounding of every product and sum.
    ``inverse=True`` rotates by -θ, undoing the forward rotation.
    """
    hd = x.shape[-1]
    if rope_dim < 0 or rope_dim % 2 or rope_dim > hd:
        raise ValueError(f"rope_dim must be even and in [0, {hd}], got {rope_dim}")
    nope = hd - rope_dim
    cos = cos_cache[positions].unsqueeze(1).float()  # (T, 1, rope_dim // 2)
    sin = sin_cache[positions].unsqueeze(1).float()
    if inverse:
        sin = -sin  # conjugate rotation
    rope_part = x[..., nope:].float()
    even, odd = rope_part[..., 0::2], rope_part[..., 1::2]
    # Stack on a trailing axis and flatten to restore the GPT-J interleaving.
    rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
    out = x.clone()
    out[..., nope:] = rotated.flatten(-2).to(x.dtype)
    return out


def apply_rope_partial(x, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply partial GPT-J interleaved RoPE to the last rope_dim dims of each head.

    x: (T, n_h, hd). ``head_dim`` is accepted for interface compatibility;
    the head size is taken from ``x``. Returns a new tensor; ``x`` is not
    modified.
    """
    return _rotate_rope_pairs(x, positions, cos_cache, sin_cache, rope_dim, inverse=False)


def apply_inverse_rope(o, positions, cos_cache, sin_cache, head_dim, rope_dim):
    """Apply inverse RoPE (conjugate rotation) to the attention output.

    o: (T, n_h, hd). Exactly undoes ``apply_rope_partial`` at the same
    positions (up to float rounding). ``head_dim`` is accepted for interface
    compatibility.
    """
    return _rotate_rope_pairs(o, positions, cos_cache, sin_cache, rope_dim, inverse=True)
class SimpleKVCache:
    """Per-layer, append-only K/V cache for decode.

    MQA: one KV head, so each buffer is a preallocated (1, max_seq, hd) BF16
    tensor on ``device``. ``len`` is the number of filled positions.
    """

    def __init__(self, head_dim, max_seq=8192, device='cuda:0'):
        self.hd = head_dim
        self.max_seq = max_seq
        self.device = device
        self.k = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.v = torch.zeros(1, max_seq, head_dim, dtype=torch.bfloat16, device=device)
        self.len = 0

    def append(self, k_new, v_new):
        """Append T positions. k_new, v_new: (1, T, hd).

        Raises:
            ValueError: if k_new and v_new have different shapes.
            RuntimeError: if the append would exceed ``max_seq``. The cache
                is left unchanged.
        """
        if k_new.shape != v_new.shape:
            raise ValueError(
                f"k_new shape {tuple(k_new.shape)} != v_new shape {tuple(v_new.shape)}")
        T = k_new.shape[1]
        end = self.len + T
        if end > self.max_seq:
            raise RuntimeError(
                f"KV cache overflow: {self.len} + {T} > max_seq={self.max_seq}")
        self.k[0, self.len:end] = k_new[0]
        self.v[0, self.len:end] = v_new[0]
        self.len = end

    def get(self):
        """Return views of the filled prefix: (1, len, hd) each for K and V."""
        return self.k[:, :self.len], self.v[:, :self.len]
# =====================================================================
# Weight loading — streams safetensors shards, distributes to 8 GPUs
# =====================================================================

def load_weights_to_cpu(checkpoint_dir):
    """Load all weights from a safetensors checkpoint into CPU memory.

    Weights stay on CPU; they are moved to GPU per layer, on demand, during
    inference. This avoids OOM from ~285K GPU allocations and allows
    streaming.

    Shards come from ``model.safetensors.index.json`` when it exists and has
    a ``weight_map``. Otherwise every ``*.safetensors`` file in the directory
    is loaded. Shards listed in the index but missing on disk are skipped
    with a warning.

    Returns:
        all_weights: dict[key] → tensor on CPU.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    weight_map = {}
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})

    if weight_map:
        shard_names = sorted(set(weight_map.values()))
    else:
        # No usable index: load whatever shards the directory actually holds.
        shard_names = sorted(p.name for p in cdir.glob("*.safetensors"))

    print(f"Loading {len(shard_names)} shards to CPU...")
    all_weights = {}
    missing = []
    loaded = 0
    for shard_name in shard_names:
        shard_path = cdir / shard_name
        if not shard_path.exists():
            missing.append(shard_name)
            continue
        all_weights.update(load_file(str(shard_path)))
        loaded += 1
        if loaded % 20 == 0:
            print(f" {loaded}/{len(shard_names)} shards, {len(all_weights)} tensors")
    if missing:
        print(f" WARNING: {len(missing)} indexed shard(s) missing, e.g. {missing[0]}")
    print(f" Done: {len(all_weights)} tensors on CPU")
    return all_weights
def get_layer_weights(all_weights, li, device):
    """Copy every tensor belonging to layer ``li`` onto ``device``.

    A key is selected when it starts with ``model.layers.{li}.``. The
    trailing dot keeps layer 1 from also matching layers 10-19. The result
    is a new dict with the same keys, in the same order as ``all_weights``.
    """
    layer_prefix = f"model.layers.{li}."
    return {
        name: tensor.to(device=device, non_blocking=True)
        for name, tensor in all_weights.items()
        if name.startswith(layer_prefix)
    }
# =====================================================================
# Single layer forward
# =====================================================================

def _rms_norm_rows(x, weight, eps):
    """RMS-normalize the last dim of ``x`` in FP32, scale by ``weight``, return BF16."""
    x_f = x.float()
    inv_rms = x_f.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
    return (x_f * inv_rms * weight.float()).bfloat16()


def _nvfp4_proj(x, w, key):
    """NVFP4 linear using the ``{key}.weight`` / ``.weight_scale`` / ``.weight_scale_2`` triple from ``w``."""
    return nvfp4_linear(x, w[f"{key}.weight"], w[f"{key}.weight_scale"],
                        w[f"{key}.weight_scale_2"])


def _sdpa_reference_attention(q_input, k_full, v_full, sinks, scale):
    """Dense MQA attention via PyTorch SDPA, for cross-checking the FMHA kernel.

    Args:
        q_input: (n_h, T, hd) queries.
        k_full, v_full: (1, seq_len, hd) cached keys/values shared by all heads.
        sinks: per-head attention-sink logits (n_h,) or None. A sink is
            modelled as one extra all-zero key/value whose pre-softmax logit
            equals the sink value (paper §2.3.3).
        scale: softmax scale applied to q·k.

    The T queries are taken to be the last T cached positions. When T > 1,
    each query is masked causally so it cannot attend to later slots.

    Returns:
        (n_h, T, hd) attention output.
    """
    n_h, T, hd = q_input.shape
    seq_len = k_full.shape[1]
    dev = q_input.device
    k = k_full.expand(n_h, -1, -1)
    v = v_full.expand(n_h, -1, -1)

    attn_mask = None
    if T > 1:
        q_slot = torch.arange(T, device=dev).unsqueeze(1) + (seq_len - T)
        k_slot = torch.arange(seq_len, device=dev).unsqueeze(0)
        attn_mask = torch.zeros(n_h, T, seq_len, dtype=q_input.dtype, device=dev)
        attn_mask = attn_mask.masked_fill(k_slot > q_slot, float("-inf"))

    if sinks is not None:
        zero_kv = torch.zeros(n_h, 1, hd, dtype=k.dtype, device=k.device)
        k = torch.cat([k, zero_kv], dim=1)
        v = torch.cat([v, zero_kv], dim=1)
        if attn_mask is None:
            attn_mask = torch.zeros(n_h, T, seq_len, dtype=q_input.dtype, device=dev)
        # One vectorized column instead of a per-head Python loop.
        sink_col = sinks.to(device=dev, dtype=attn_mask.dtype).reshape(n_h, 1, 1).expand(n_h, T, 1)
        attn_mask = torch.cat([attn_mask, sink_col], dim=-1)

    return torch.nn.functional.scaled_dot_product_attention(
        q_input, k.contiguous(), v.contiguous(), attn_mask=attn_mask, scale=scale)


def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm, ffn_norm,
                  kv_cache, token_id, positions, use_sdpa=False):
    """Forward one layer with mHC + Attention + FFN.

    Architecture (paper §2):
        X_l   → mHC.pre_block(attn) → RMSNorm → Attention → F_attn → mHC.post_block → X_mid
        X_mid → mHC.pre_block(ffn)  → RMSNorm → MoE       → F_ffn  → mHC.post_block → X_{l+1}

    Args:
        X_l: (T, n_hc, H) BF16 mHC residual state.
        w: this layer's weights (keys prefixed ``model.layers.{li}.``).
        li: layer index.
        cfg: model config dict.
        rope_cos, rope_sin: RoPE tables on X_l's device, indexed by position.
        attn_mhc, ffn_mhc: mHC wrappers for the two sub-blocks.
        attn_norm, ffn_norm: pre-norms for the two sub-blocks.
        kv_cache: this layer's cache. The new RoPE'd KV entries are appended.
        token_id: current token id(s), forwarded to MoE hash routing.
        positions: absolute positions of the T tokens.
        use_sdpa: if True, use the PyTorch SDPA reference path (which honours
            attention sinks) instead of the production FMHA kernel.

    Returns:
        X_next: (T, n_hc, H) BF16.
    """
    device = X_l.device
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    o_rank = cfg.get("output_group_dim", 1024)
    o_groups = cfg.get("num_output_groups", 16)
    # Same rms_norm_eps source as the block RMSNorms (previously hard-coded 1e-6).
    eps = cfg.get("rms_norm_eps", 1e-6)
    pre = f"model.layers.{li}.self_attn"
    T = X_l.shape[0]
    group_input_dim = (n_h // o_groups) * hd

    # ================= ATTENTION SUB-BLOCK =================
    x_in, attn_ctx = attn_mhc.pre_block(X_l)  # x_in: (T, H)
    x_normed = attn_norm.forward(x_in)  # (T, H) BF16

    # Q: q_a (low-rank down) → q_a_norm (if present) → q_b (low-rank up).
    c_Q = _nvfp4_proj(x_normed, w, f"{pre}.q_a_proj")  # (T, dc)
    q_norm_w = w.get(f"{pre}.q_a_norm.weight")
    if q_norm_w is not None:
        c_Q = _rms_norm_rows(c_Q, q_norm_w, eps)
    q = _nvfp4_proj(c_Q, w, f"{pre}.q_b_proj")  # (T, n_h * hd)

    # KV (MQA: a single head) followed by kv_norm (if present).
    kv = _nvfp4_proj(x_normed, w, f"{pre}.kv_proj")  # (T, hd)
    kv_norm_w = w.get(f"{pre}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = _rms_norm_rows(kv, kv_norm_w, eps)

    # RoPE on Q, and on KV BEFORE caching (DSV4 convention). K = V: the same
    # RoPE'd tensor is cached as both keys and values.
    positions_dev = positions.to(device)
    q_heads = apply_rope_partial(q.reshape(T, n_h, hd), positions_dev, rope_cos, rope_sin, hd, rd)
    kv_new = apply_rope_partial(kv.reshape(T, 1, hd), positions_dev, rope_cos, rope_sin, hd, rd)
    kv_step = kv_new.permute(1, 0, 2)  # (1, T, hd)
    kv_cache.append(kv_step, kv_step)
    k_full, v_full = kv_cache.get()  # (1, seq_len, hd) each

    q_input = q_heads.permute(1, 0, 2)  # (n_h, T, hd)
    if use_sdpa:
        attn_out = _sdpa_reference_attention(
            q_input, k_full, v_full, w.get(f"{pre}.sinks"), 1.0 / math.sqrt(hd))
    else:
        # Production FMHA. NOTE: no sink logits are passed here, so the
        # {pre}.sinks weights only take effect on the SDPA path.
        from dsv4.kernels.attention.production import dsv4_attention
        attn_out = dsv4_attention(q_input, k_full, v_full)
    attn_out = attn_out.permute(1, 0, 2)  # (T, n_h, hd)

    # Inverse RoPE on the attention output (paper §2.3.3).
    attn_out = apply_inverse_rope(attn_out, positions_dev, rope_cos, rope_sin, hd, rd)

    # Output projection: wo_a as a grouped BMM, then wo_b (NVFP4).
    attn_grouped = attn_out.reshape(T, o_groups, group_input_dim).permute(1, 0, 2)  # (groups, T, group_dim)
    # NOTE(review): assumes o_a_proj is stored unquantized (BF16/FP32) — TODO confirm.
    oa_3d = w[f"{pre}.o_a_proj.weight"].bfloat16().reshape(o_groups, o_rank, group_input_dim)
    grouped_out = torch.bmm(attn_grouped, oa_3d.transpose(1, 2))  # (groups, T, o_rank)
    grouped_flat = grouped_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
    F_attn = _nvfp4_proj(grouped_flat, w, f"{pre}.o_b_proj")  # (T, H)

    X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)  # (T, n_hc, H)

    # ================= FFN SUB-BLOCK =================
    x_ffn, ffn_ctx = ffn_mhc.pre_block(X_mid)  # (T, H)
    x_ffn_normed = ffn_norm.forward(x_ffn)  # (T, H) BF16
    F_ffn = moe_forward(x_ffn_normed, w, li, cfg, token_id, device)
    return ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)  # (T, n_hc, H)
# =====================================================================
# MoE forward — hash + dense routing, SwiGLU with clamping
# =====================================================================

def _swiglu_expert(x, w, prefix, swiglu_limit):
    """Run one NVFP4 SwiGLU MLP stored under ``prefix``: (T, H) BF16 → (T, H) BF16.

    Clamping (paper §4.2.3): when ``swiglu_limit`` is not None, silu(gate)
    and up are each clamped to [-swiglu_limit, swiglu_limit] in FP32 before
    they are multiplied.
    """
    def proj(name, inp):
        key = f"{prefix}.{name}"
        return nvfp4_linear(inp, w[f"{key}.weight"], w[f"{key}.weight_scale"],
                            w[f"{key}.weight_scale_2"])

    gate_act = torch.nn.functional.silu(proj("gate_proj", x).float())
    up = proj("up_proj", x).float()
    if swiglu_limit is not None:
        gate_act = gate_act.clamp(-swiglu_limit, swiglu_limit)
        up = up.clamp(-swiglu_limit, swiglu_limit)
    return proj("down_proj", (gate_act * up).bfloat16())


def moe_forward(x, w, li, cfg, token_id, device):
    """Run routed MoE + shared expert for a single token.

    Args:
        x: (1, H) BF16 post-RMSNorm FFN input. Only one token is supported.
        w: this layer's weights.
        li: layer index. Layers li < 3 use hash routing when a
            ``mlp.gate.tid2eid`` table is present; otherwise dense routing.
        cfg: model config dict.
        token_id: tensor holding the current token id (used by hash routing).
        device: unused; kept for interface compatibility.

    Returns:
        (1, H) BF16: routed_scaling * Σ weight_e · expert_e(x) + shared(x).
        Accumulation is done in FP32 and rounded to BF16 once.

    Raises:
        NotImplementedError: if x holds more than one token. Hash routing
            previously used token 0 for every row, which was silently wrong.
    """
    if x.shape[0] != 1:
        raise NotImplementedError("Multi-token MoE routing")

    top_k = cfg.get("num_experts_per_tok", 6)
    routed_scaling = cfg.get("routed_scaling_factor", 2.5)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    mlp = f"model.layers.{li}.mlp"

    # ---- Routing ----
    tid2eid = w.get(f"{mlp}.gate.tid2eid") if li < 3 else None
    if tid2eid is not None:
        # Hash routing: a fixed token → experts table with uniform weights.
        expert_ids = tid2eid[token_id.reshape(-1)[0].item()]
        expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
    else:
        # Dense routing (also the fallback for hash layers without a table).
        gate_w = w[f"{mlp}.gate.weight"]  # (n_experts, H), F.linear layout
        logits = torch.nn.functional.linear(x, gate_w.bfloat16())  # (1, n_experts)
        # Affinity: sqrt(softplus(logits) + 1e-6), in FP32.
        affinity = torch.sqrt(torch.nn.functional.softplus(logits.float()) + 1e-6)
        # e_bias only affects WHICH experts are selected, never their weights.
        e_bias = w.get(f"{mlp}.gate.e_bias")
        selection = affinity if e_bias is None else affinity + e_bias.float().unsqueeze(0)
        _, indices = selection.topk(top_k, dim=-1)  # (1, top_k)
        chosen = torch.gather(affinity, -1, indices)
        expert_ids = indices[0]
        expert_weights = (chosen / chosen.sum(dim=-1, keepdim=True))[0]

    # ---- Selected experts: accumulate in FP32, round to BF16 once ----
    routed = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
    for eid, weight in zip(expert_ids.tolist(), expert_weights):
        expert_out = _swiglu_expert(x, w, f"{mlp}.experts.{eid}", swiglu_limit)
        routed += expert_out.float() * weight.float()
    result = routed * routed_scaling

    # ---- Shared expert (optional) ----
    se_prefix = f"{mlp}.shared_experts"
    if f"{se_prefix}.gate_proj.weight" in w:
        result = result + _swiglu_expert(x, w, se_prefix, swiglu_limit).float()
    return result.bfloat16()
# =====================================================================
# Main
# =====================================================================

def _fallback_mhc_tensors(n_hc, hidden):
    """Build (fn, base, scale) CPU tensors for a layer with no mHC weights.

    The values are the intended fallback: zero projections, S_pre = 0,
    S_post = 0.5, S_res = I, and every alpha = 0.01. They are fed through
    ``mHCBlock.load_from_checkpoint`` so the wrapped layer actually receives
    them. Assigning attributes on the wrapper has no effect, because
    pre_block/post_block delegate to the wrapped ``_impl``.
    """
    n = n_hc
    fn = torch.zeros(2 * n + n * n, n * hidden, dtype=torch.float32)
    base = torch.cat([torch.zeros(n), torch.full((n,), 0.5), torch.eye(n).reshape(-1)])
    scale = torch.tensor([0.01, 0.01, 0.01])
    return fn, base, scale


def main():
    """Load the checkpoint, build per-layer state, prefill the prompt, then decode greedily.

    Prefill feeds every prompt token except the last, one at a time, to fill
    each layer's KV cache. The first decode step feeds the last prompt
    token, so every position is written to the caches exactly once.
    """
    t_start = time.time()
    print("=" * 70)
    print("DSV4 Single-Shot Inference — Full Pipeline (mHC+Attn+MoE)")
    print(" Proper Sinkhorn mHC, RMSNorm, inverse RoPE, production FMHA")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", cfg.get("rope_dim", 64))
    eps = cfg.get('rms_norm_eps', 1e-6)
    n_hc = 4
    max_seq = 8192  # capacity of every RoPE table and KV cache
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ==== Phase 1: Load weights to CPU ====
    print(f"\n{'='*70}\nPhase 1: Loading weights to CPU\n{'='*70}")
    all_weights = load_weights_to_cpu(CHECKPOINT_DIR)
    t_loaded = time.time()
    print(f"Weight loading: {t_loaded - t_start:.1f}s")

    # ==== Build mHC blocks + RMSNorms (small weights, kept on each layer's GPU) ====
    print("Building mHC blocks and RMSNorms...")
    attn_mhc_blocks, ffn_mhc_blocks, attn_norms, ffn_norms = {}, {}, {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"

        # mHC blocks (fn is (24, 28672) FP32 ≈ 2.6 MB each).
        for prefix, blocks in ((f"model.layers.{li}.attn_hc", attn_mhc_blocks),
                               (f"model.layers.{li}.ffn_hc", ffn_mhc_blocks)):
            keys = (f"{prefix}.fn", f"{prefix}.base", f"{prefix}.scale")
            mhc = mHCBlock(hidden_dim=H, n_hc=n_hc, device=dev)
            if all(k in all_weights for k in keys):
                mhc.load_from_checkpoint(*(all_weights[k] for k in keys))
            else:
                print(f" WARNING: no mHC weights for {prefix}, using identity fallback")
                mhc.load_from_checkpoint(*_fallback_mhc_tensors(n_hc, H))
            blocks[li] = mhc

        # Pre-norms for the attention and FFN sub-blocks.
        for store, norm_key in ((attn_norms, f"model.layers.{li}.input_layernorm.weight"),
                                (ffn_norms, f"model.layers.{li}.post_attention_layernorm.weight")):
            norm = RMSNorm(H, eps=eps, device=dev)
            if norm_key in all_weights:
                norm.weight = all_weights[norm_key].to(device=dev, dtype=torch.float32)
            store[li] = norm
    print(f" attn mHC: {len(attn_mhc_blocks)}, ffn mHC: {len(ffn_mhc_blocks)}")

    # ==== Global weights (small, kept on gpu0) ====
    torch.cuda.set_device(0)
    embed_w = all_weights.get("model.embed_tokens.weight")
    if embed_w is None:
        raise KeyError("model.embed_tokens.weight not found in checkpoint")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
    lm_w = all_weights.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
    final_norm_w = all_weights.get("model.norm.weight")
    if final_norm_w is not None:
        final_norm_w = final_norm_w.to('cuda:0')

    rope_theta = cfg.get("rope_theta", 10000.0)
    if cfg.get("rope_scaling"):
        print(" WARNING: config rope_scaling is not applied; using plain RoPE with rope_theta")
    rope_caches = {g: build_rope_cache(max_seq, rd, f"cuda:{g}", theta=rope_theta)
                   for g in range(NUM_GPUS)}

    # ==== KV caches (one per layer, on that layer's GPU) ====
    kv_caches = {li: SimpleKVCache(head_dim=hd, max_seq=max_seq, device=f"cuda:{li % NUM_GPUS}")
                 for li in range(n_layers)}

    # ==== Phase 2: Compile FMHA ====
    print(f"\n{'='*70}\nPhase 2: JIT compiling\n{'='*70}")
    from dsv4.kernels.attention.production import dsv4_attention
    torch.cuda.set_device(0)
    dummy_q = torch.randn(n_h, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    dummy_k = torch.randn(1, 1, hd, dtype=torch.bfloat16, device='cuda:0')
    try:
        _ = dsv4_attention(dummy_q, dummy_k, dummy_k.clone())
        print(" FMHA: compiled OK")
    except Exception as e:  # best-effort warm-up; the real run reports any failure
        print(f" FMHA error: {e}")
    t_compiled = time.time()
    print(f"Compile: {t_compiled - t_loaded:.1f}s")

    # ==== Phase 3: Inference ====
    print(f"\n{'='*70}\nPhase 3: Inference\n{'='*70}")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    input_ids = tokenizer.encode(PROMPT, return_tensors="pt")
    print(f"Prompt: '{PROMPT}' → {input_ids.tolist()}")

    tokens = input_ids[0].tolist()  # prompt + generated ids
    if not tokens:
        raise ValueError("Prompt encoded to zero tokens")
    if len(tokens) - 1 + MAX_NEW_TOKENS > max_seq:
        raise ValueError(f"Prompt + MAX_NEW_TOKENS exceeds KV cache capacity ({max_seq})")

    def run_layers(tid_val, pos):
        """Run one token at position ``pos`` through all layers; return the final mHC state on cuda:0."""
        tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
        positions = torch.tensor([pos], dtype=torch.long, device='cuda:0')
        X = mHCBlock.init_state(embed(tid), n_hc)  # (1, n_hc, H)
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            dev = f"cuda:{gpu}"
            if X.device != torch.device(dev):
                X = X.to(dev)
            torch.cuda.set_device(gpu)
            # Stream this layer's weights CPU → GPU (not all at once).
            w = get_layer_weights(all_weights, li, dev)
            rc, rs = rope_caches[gpu]
            X = forward_layer(X, w, li, cfg, rc, rs,
                              attn_mhc_blocks.get(li), ffn_mhc_blocks.get(li),
                              attn_norms[li], ffn_norms[li],
                              kv_caches[li], tid, positions)
            del w  # free per-layer GPU weights before the next layer
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        return X

    # ==== Prefill: every prompt token except the last ====
    prefill_tokens = tokens[:-1]
    print(f"Prefilling {len(prefill_tokens)} prompt tokens...")
    for pos, tid_val in enumerate(prefill_tokens):
        t0 = time.time()
        run_layers(tid_val, pos)
        if pos == 0:
            print(f" Token 0: {time.time()-t0:.1f}s (includes per-layer weight transfer)")
    print(f" Prefill done ({len(prefill_tokens)} tokens, {time.time()-t_compiled:.1f}s)")

    # ==== Decode: the first step consumes the last prompt token ====
    print(f"\nDecoding (max {MAX_NEW_TOKENS} new tokens)...")
    for step in range(MAX_NEW_TOKENS):
        t0 = time.time()
        X = run_layers(tokens[-1], len(tokens) - 1)

        # Read out stream 0 → final RMSNorm → lm_head.
        x_out = X[:, 0, :]  # (1, H)
        if final_norm_w is not None:
            xf = x_out.float()
            rms = xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
            x_out = (xf * rms * final_norm_w.float()).bfloat16()
        logits = torch.nn.functional.linear(x_out, lm_w)

        # Top-5 predictions for debugging.
        top5_vals, top5_ids = torch.topk(logits[0], 5)
        top5_str = ' '.join(f'{tokenizer.decode([cand.item()])}({score.item():.1f})'
                            for cand, score in zip(top5_ids, top5_vals))
        next_id = torch.argmax(logits, dim=-1).item()
        tokens.append(next_id)

        tok_str = tokenizer.decode([next_id])
        dt = time.time() - t0
        logits_f = logits.float()
        has_nan = torch.isnan(logits_f).any().item()
        has_inf = torch.isinf(logits_f).any().item()
        lmin, lmax = logits_f.min().item(), logits_f.max().item()
        x_max = X.abs().max().item()
        print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
              f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
              f"|X|={x_max:.3f} top5: {top5_str}")

        if has_nan or has_inf:
            print(" Numerical issue — stopping")
            break
        if next_id == tokenizer.eos_token_id:
            break

    out = tokenizer.decode(tokens, skip_special_tokens=True)
    total = time.time() - t_start
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {total:.1f}s")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()