#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — Full production pipeline, 8-GPU.

Exercises the production kernel stack end-to-end:
  - NVFP4 GEMM kernels (CuTeDSL ScaledGroupedGemm) for all projections
  - 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh)
  - CSA/HCA compressor (token-level softmax)
  - Indexer score+topk (indexer_score_topk.cu)
  - Dense/Hash router kernels
  - Production mHC (Sinkhorn-Knopp, B_l transposed, [pre,post,comb])
  - Production Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert

This is NOT a PyTorch reference — it calls the actual kernel stack.
Use as ground truth for vLLM / SGLang integration.
"""
import os, sys, time, json, math, argparse, logging
import torch
import torch.nn.functional as F
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("single_shot")


def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` with ``max_tokens``, ``prompt``, ``seed``,
        ``verbose``, ``prefill_only`` and ``debug_layer``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--max-tokens', type=int, default=8192)
    p.add_argument('--prompt', type=str, default=None)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--verbose', type=int, default=1)
    p.add_argument('--prefill-only', action='store_true')
    p.add_argument('--debug-layer', type=int, default=-1)
    return p.parse_args(argv)


# Parse the real command line only when executed as a script.  On import the
# defaults are used, so importing this module never consumes (or chokes on)
# the importing process's sys.argv.
_args = parse_args() if __name__ == "__main__" else parse_args([])
CHECKPOINT_DIR = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
MAX_NEW_TOKENS = _args.max_tokens
PROMPT = _args.prompt or "The capital of France is"
NUM_GPUS = 8
SEED = _args.seed
VERBOSE = _args.verbose
# Special token ids used to build the chat prompt.
# NOTE(review): values are taken as-is; confirm against the tokenizer config.
THINK_START, THINK_END = 128821, 128822
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804

# Magnitudes addressed by the low 3 bits of an FP4 nibble (bit 3 is the sign,
# see dequant_nvfp4).
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

# =====================================================================
# RoPE cache (FP32 — BF16
# destroys cos²+sin²=1)
# =====================================================================
def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default",
                     rope_factor=1., orig_max=4096, beta_fast=32, beta_slow=1):
    """Build FP32 cos/sin tables for rotary position embedding.

    Args:
        max_pos: Number of positions (rows) in the tables.
        rope_dim: Rotary dimension; ``rope_dim // 2`` frequencies are produced.
        device: Device the returned tables are moved to.
        theta: RoPE base.
        rope_type: ``"yarn"`` enables the frequency rescaling below (only
            when ``rope_factor > 1``); any other value leaves the
            frequencies untouched.
        rope_factor, orig_max, beta_fast, beta_slow: Rescaling parameters.

    Returns:
        ``(cos, sin)``, each of shape ``(max_pos, rope_dim // 2)``, float32.
    """
    freqs = 1. / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    if rope_type == "yarn" and rope_factor > 1.:
        # Vectorised form of the per-frequency loop: same float32 arithmetic,
        # evaluated for all frequencies at once and selected per band.
        wl = 2 * math.pi / freqs
        lo, hi = orig_max / (beta_fast * 2.), orig_max / (beta_slow * 2.)
        # NOTE(review): `sm` is not clamped to [0, 1] and does not reach 1 at
        # wl == lo / 0 at wl == hi (e.g. orig_max=65536, factor=16, beta 32/1
        # gives ~0.097 at wl == lo), so the rescaled frequencies jump at the
        # band edges.  Kept unchanged; confirm against the reference YaRN ramp.
        sm = (orig_max / (wl * beta_slow) - rope_factor) / (rope_factor * (beta_fast / beta_slow - 1))
        blended = (1 - sm) * freqs / rope_factor + sm * freqs
        freqs = torch.where(wl < lo, freqs,
                            torch.where(wl > hi, freqs / rope_factor, blended))
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)

# =====================================================================
# Weight loading
# =====================================================================
def load_all_weights(checkpoint_dir):
    """Load every safetensors shard of a checkpoint into one dict.

    Shard names come from the ``weight_map`` of
    ``model.safetensors.index.json``.  If the index is missing or has no
    ``weight_map``, every ``*.safetensors`` file in the directory is loaded
    instead (previously this case silently returned an empty dict).  Shards
    listed in the index but absent on disk are skipped with a warning.

    Args:
        checkpoint_dir: Directory containing the checkpoint.

    Returns:
        Dict mapping tensor name to tensor, as returned by
        ``safetensors.torch.load_file`` for each shard.
    """
    from safetensors.torch import load_file
    cdir = Path(checkpoint_dir)
    wmap = {}
    idx = cdir / "model.safetensors.index.json"
    if idx.exists():
        with open(idx) as f:
            wmap = json.load(f).get("weight_map", {})
    if wmap:
        shards = sorted(set(wmap.values()))
    else:
        shards = sorted(p.name for p in cdir.glob("*.safetensors"))
        log.warning(f"No weight_map in {idx}; falling back to {len(shards)} *.safetensors files")
    all_w = {}
    missing = []
    for sn in shards:
        shard_path = cdir / sn
        if shard_path.exists():
            all_w.update(load_file(str(shard_path)))
        else:
            missing.append(sn)
    if missing:
        log.warning(f"{len(missing)} shards listed in the index are missing: {missing[:5]}")
    log.info(f"Loaded {len(all_w)} tensors from {len(shards) - len(missing)} shards")
    return all_w

# =====================================================================
# RMSNorm
# =====================================================================
def rmsnorm(x, weight, eps=1e-6):
    """RMS-normalise ``x`` over its last dim in FP32, scale by ``weight``, return BF16."""
    xf = x.float()
    return (xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() * weight.float()).bfloat16()

def unweighted_rmsnorm(x, eps=1e-6):
    """RMS-normalise ``x`` over its last dim without a learned scale; returns FP32."""
    xf = x.float()
    return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
# =====================================================================
# mHC (matches dsv4/layers/mhc.py)
# =====================================================================
HC_EPS = 1e-6

def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS):
    """Turn ``logits`` into an approximately doubly-stochastic matrix.

    Starts from a row softmax (plus ``eps``), normalises columns once, then
    alternates row / column normalisation ``t_max - 1`` times over the last
    two dims.
    """
    M = torch.softmax(logits, -1) + eps
    M = M / (M.sum(-2, keepdim=True) + eps)
    for _ in range(t_max - 1):
        M = M / (M.sum(-1, keepdim=True) + eps)
        M = M / (M.sum(-2, keepdim=True) + eps)
    return M

class mHCBlock:
    """mHC mixing around one sub-block, operating on a state of shape (T, n_hc, hidden_dim)."""

    def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda:0'):
        self.d, self.n_hc, self.K = hidden_dim, n_hc, n_hc * hidden_dim
        self.t_max, self.device = t_max, device

    def load(self, fn, base, scale):
        """Split the packed ``fn`` / ``base`` / ``scale`` tensors into pre / post / comb parts.

        Rows ``[0:n]``, ``[n:2n]`` and ``[2n:]`` of ``fn`` and ``base`` are the
        pre, post and comb parts respectively; ``scale`` holds the three
        scalar multipliers in the same order.
        """
        n = self.n_hc
        # Keep the full projection once; pre_block used to re-concatenate the
        # three slices on every call.  The slices below are views of it.
        self.W_stacked = fn.to(self.device, torch.float32).contiguous()
        self.W_pre = self.W_stacked[0:n]
        self.W_post = self.W_stacked[n:2*n]
        self.W_comb = self.W_stacked[2*n:]
        self.S_pre = base[0:n].reshape(1, n).to(self.device, torch.float32).contiguous()
        self.S_post = base[n:2*n].reshape(n, 1).to(self.device, torch.float32).contiguous()
        self.S_comb = base[2*n:].reshape(n, n).to(self.device, torch.float32).contiguous()
        self.alpha_pre, self.alpha_post, self.alpha_comb = scale[0].item(), scale[1].item(), scale[2].item()

    @staticmethod
    def init_state(emb, n_hc=4):
        """Replicate ``emb`` (T, d) into ``n_hc`` identical streams: (T, n_hc, d)."""
        return emb.unsqueeze(1).expand(-1, n_hc, -1).clone()

    def pre_block(self, X):
        """Collapse the streams of ``X`` (T, n, d) into one sub-block input.

        Returns:
            ``(x_in, ctx)``: ``x_in`` is the (T, d) BF16 input for the
            sub-block, ``ctx`` holds the ``B`` (T, n, n) and ``C`` (T, n)
            coefficients that ``post_block`` needs.
        """
        T, n, d = X.shape
        Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
        proj = Xn.float() @ self.W_stacked.T
        rms_inv = proj.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
        # Round-trip through BF16 before the affine maps, as in the original.
        proj = (proj * rms_inv).bfloat16().float()
        pre_t = self.alpha_pre * proj[:, :n] + self.S_pre.flatten().unsqueeze(0)
        post_t = self.alpha_post * proj[:, n:2*n] + self.S_post.flatten().unsqueeze(0)
        comb_t = self.alpha_comb * proj[:, 2*n:2*n+n*n] + self.S_comb.flatten().unsqueeze(0)
        A = torch.sigmoid(pre_t) + HC_EPS
        C = 2.0 * torch.sigmoid(post_t)
        B = sinkhorn_knopp(comb_t.reshape(T, n, n), t_max=self.t_max)
        x_in = torch.bmm(A.unsqueeze(1), X.float()).squeeze(1).bfloat16()
        return x_in, {'B': B, 'C': C}

    def post_block(self, X, F_out, ctx):
        """Write the sub-block output back: ``C * F_out + B^T @ X``, returned as BF16 (T, n, d)."""
        BX = torch.bmm(ctx['B'].transpose(-1, -2), X.float())
        CF = ctx['C'].unsqueeze(-1) * F_out.unsqueeze(1)
        return (CF.float() + BX).bfloat16()

# =====================================================================
# HcHead
# =====================================================================
class HcHead:
    """Final reduction of the (T, n_hc, d) state to a single (T, d) hidden state."""

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc

    def load(self, fn, base, scale=None):
        """Store the projection, bias and scalar scale (1.0 when ``scale`` is None)."""
        self.fn = fn.to(self.device, torch.float32).contiguous()
        self.base = base.to(self.device, torch.float32).contiguous()
        self.scale = scale.to(self.device, torch.float32).item() if scale is not None else 1.0

    def forward(self, X):
        """Weight each stream by ``sigmoid(mix * scale + base) + HC_EPS`` and sum over streams."""
        T = X.shape[0]
        Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
        mix = F.linear(Xn, self.fn[:self.n_hc]).float()
        pre = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
        return (pre.unsqueeze(-1) * X.float()).sum(1).bfloat16()

# =====================================================================
# NVFP4 dequant (fallback for projections not yet using kernel GEMM)
# =====================================================================
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Dequantise a packed FP4 weight to BF16.

    Args:
        weight: (O, I/2) packed tensor, two 4-bit codes per byte; the low
            nibble is the even element, the high nibble the odd one.  In each
            nibble bit 3 is the sign and bits 0-2 index ``FP4_LUT``.
        weight_scale: (O, I/16) scale, one value per 16 consecutive inputs.
        weight_scale_2: Optional extra scale multiplied onto ``weight_scale``.
        input_scale: Accepted for signature parity; not used here.

    Returns:
        (O, I) BF16 tensor.
    """
    O, I2 = weight.shape
    I = I2 * 2
    lo = (weight & 0x0F).to(torch.int8)
    hi = (weight >> 4).to(torch.int8)
    lut = FP4_LUT.to(device=weight.device, dtype=torch.float32)
    lo_f = lut[(lo & 0x07).long()] * torch.where((lo >> 3).bool(), -1., 1.)
    hi_f = lut[(hi & 0x07).long()] * torch.where((hi >> 3).bool(), -1., 1.)
    w = torch.stack([lo_f, hi_f], -1).reshape(O, I)
    s = weight_scale.float().repeat_interleave(16, 1)
    if weight_scale_2 is not None:
        s = s * weight_scale_2.float()
    return (w * s).bfloat16()

def nvfp4_linear(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
    """``F.linear`` against the BF16 dequantisation of a packed FP4 weight."""
    return F.linear(x, dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale))

def get_nvfp4_weight(w, pfx, proj_name):
    """Fetch ``(weight, weight_scale, weight_scale_2, input_scale)`` for a projection.

    Each entry is ``None`` when the corresponding key is absent from ``w``.
    """
    k = f"{pfx}.{proj_name}"
    return (w.get(f"{k}.weight"), w.get(f"{k}.weight_scale"),
            w.get(f"{k}.weight_scale_2"), w.get(f"{k}.input_scale"))

def do_nvfp4_linear(x, w, pfx, proj_name):
    """Apply the named FP4 projection to ``x``; returns ``None`` if its weight is missing."""
    weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name)
    if weight is None:
        return None
    d = x.device
    return nvfp4_linear(x, weight.to(d), ws.to(d),
                        ws2.to(d) if ws2 is not None else None,
                        isc.to(d) if isc is not None else None)

# =====================================================================
# Compressor — CSA (ratio=4) and HCA (ratio=128)
# =====================================================================
class Compressor:
    """Pools every ``ratio`` consecutive tokens into one compressed KV vector.

    ``ratio == 4`` selects the CSA layout (projection width ``2 * head_dim``,
    each block also pools the previous block's first half); any other ratio
    uses the plain layout with projection width ``head_dim``.
    """

    def __init__(self, ratio, head_dim, hidden_size, device):
        self.ratio, self.hd, self.H, self.device = ratio, head_dim, hidden_size, device
        self.is_csa = (ratio == 4)
        self.kv_dim = 2 * head_dim if self.is_csa else head_dim
        self.wkv_w = self.wkv_ws = self.wkv_ws2 = self.wkv_isc = None
        self.wgate_w = self.wgate_ws = self.wgate_ws2 = self.wgate_isc = None
        self.ape = None
        self.kv_norm_w = None

    def load(self, w, pfx):
        """Pick the kv / gate projections, position bias and norm weight out of ``w``."""
        self.wkv_w, self.wkv_ws, self.wkv_ws2, self.wkv_isc = get_nvfp4_weight(w, pfx, 'kv_proj')
        self.wgate_w, self.wgate_ws, self.wgate_ws2, self.wgate_isc = get_nvfp4_weight(w, pfx, 'gate_proj')
        self.ape = w.get(f"{pfx}.position_bias")
        self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")

    @staticmethod
    def _pool(block_kv, block_gate, norm_w):
        """Softmax-weighted sum over the token axis, optional RMS norm, BF16 result."""
        probs = torch.softmax(block_gate.float(), dim=0)
        pooled = (probs * block_kv.float()).sum(0)
        if norm_w is not None:
            pooled = pooled * pooled.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() * norm_w
        return pooled.bfloat16()

    def forward(self, hidden_states, positions):
        """Compress the complete ``ratio``-sized blocks contained in this call.

        Returns:
            ``(comp_kv, comp_pos, block_bias)``: ``comp_kv`` is
            (n_blocks, head_dim) BF16, ``comp_pos`` holds the position of the
            last token of each block and ``block_bias`` is an all-zero
            (1, T, n_blocks) FP32 tensor.  ``(None, None, None)`` is returned
            when the compressor is disabled, has no weights, or the call
            contains fewer than ``ratio`` tokens.

        NOTE(review): blocks are formed only from the tokens of one call and
        trailing tokens beyond the last complete block are dropped; nothing
        is carried over between calls.  A caller that feeds one token per
        call therefore never receives compressed KV — confirm whether a
        rolling buffer is intended.
        """
        if self.ratio == 0 or self.wkv_w is None:
            return None, None, None
        T = hidden_states.shape[0]
        r = self.ratio
        dev = hidden_states.device
        n_complete = T // r
        if n_complete == 0:
            return None, None, None
        kv = nvfp4_linear(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev),
                          self.wkv_ws2.to(dev) if self.wkv_ws2 is not None else None,
                          self.wkv_isc.to(dev) if self.wkv_isc is not None else None)
        gate = nvfp4_linear(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev),
                            self.wgate_ws2.to(dev) if self.wgate_ws2 is not None else None,
                            self.wgate_isc.to(dev) if self.wgate_isc is not None else None)
        if self.ape is not None:
            ape = self.ape.to(dev)
            # Dtype conversions are block-invariant: do them once.
            ape_kv, ape_gate = ape.to(kv.dtype), ape.to(gate.dtype)
            for bi in range(n_complete):
                s, e = bi * r, (bi + 1) * r
                kv[s:e] += ape_kv
                gate[s:e] += ape_gate
        T_comp = n_complete * r
        norm_w = self.kv_norm_w.to(dev).float() if self.kv_norm_w is not None else None
        comp_list, comp_pos_list = [], []
        if self.is_csa:
            Ca = kv[:T_comp, :self.hd].reshape(n_complete, r, self.hd)
            Cb = kv[:T_comp, self.hd:].reshape(n_complete, r, self.hd)
            Ga = gate[:T_comp, :self.hd].reshape(n_complete, r, self.hd)
            Gb = gate[:T_comp, self.hd:].reshape(n_complete, r, self.hd)
            for bi in range(n_complete):
                if bi > 0:
                    # Pool over the previous block's first half plus this
                    # block's second half (2 * r rows).
                    block_kv = torch.cat([Ca[bi-1], Cb[bi]], dim=0)
                    block_gate = torch.cat([Ga[bi-1], Gb[bi]], dim=0)
                else:
                    block_kv, block_gate = Cb[bi], Gb[bi]
                comp_list.append(self._pool(block_kv, block_gate, norm_w))
                comp_pos_list.append(positions[(bi+1)*r - 1])
        else:
            kv_blocks = kv[:T_comp].reshape(n_complete, r, self.hd)
            gate_blocks = gate[:T_comp].reshape(n_complete, r, self.hd)
            for bi in range(n_complete):
                comp_list.append(self._pool(kv_blocks[bi], gate_blocks[bi], norm_w))
                comp_pos_list.append(positions[(bi+1)*r - 1])
        return (torch.stack(comp_list), torch.stack(comp_pos_list),
                torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev))

# =====================================================================
# Indexer — CSA top-k
# =====================================================================
class Indexer:
    """Scores compressed blocks against the query and returns the top-k block indices."""

    def __init__(self, n_ih, ihd, top_k, device, hidden_size=7168):
        self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device
        # Passed to the indexer's own Compressor (was hard-coded to 7168).
        self.hidden_size = hidden_size
        self.q_b_w = self.q_b_ws = self.q_b_ws2 = self.q_b_isc = None
        self.wp_w = self.wp_ws = self.wp_ws2 = self.wp_isc = None
        self.compressor = None

    def load(self, w, pfx):
        """Load the query / head-weight projections and, if present, the indexer's compressor."""
        self.q_b_w, self.q_b_ws, self.q_b_ws2, self.q_b_isc = get_nvfp4_weight(w, pfx, 'q_b_proj')
        self.wp_w, self.wp_ws, self.wp_ws2, self.wp_isc = get_nvfp4_weight(w, pfx, 'weights_proj')
        if f"{pfx}.compressor.kv_proj.weight" in w:
            self.compressor = Compressor(4, self.ihd, self.hidden_size, self.device)
            self.compressor.load(w, f"{pfx}.compressor")

    def forward(self, q_lora, hidden_states, comp_indexer_kv, positions):
        """Return the indices of the ``top_k`` best-scoring compressed blocks per token.

        Args:
            q_lora: (T, q_rank) input of the indexer query projection.
            hidden_states: (T, hidden) input of the per-head weight projection.
            comp_indexer_kv: Compressed indexer keys, either one shared key
                per block ``(n_comp, ihd)`` — what ``Indexer.compressor``
                produces — or one key per head ``(n_comp, n_ih * ihd)``.
            positions: Unused here; kept for interface parity.

        Returns:
            (T, min(top_k, n_comp)) long tensor, or ``None`` when the query
            projection or the compressed keys are missing / empty.
        """
        if self.q_b_w is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
            return None
        dev = q_lora.device
        T = q_lora.shape[0]
        n_comp = comp_indexer_kv.shape[0]
        q_idx = nvfp4_linear(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev),
                             self.q_b_ws2.to(dev) if self.q_b_ws2 is not None else None,
                             self.q_b_isc.to(dev) if self.q_b_isc is not None else None)
        q_idx = q_idx.reshape(T, self.n_ih, self.ihd)
        w_h = nvfp4_linear(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev),
                           self.wp_ws2.to(dev) if self.wp_ws2 is not None else None,
                           self.wp_isc.to(dev) if self.wp_isc is not None else None)
        if comp_indexer_kv.numel() == n_comp * self.n_ih * self.ihd:
            k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
        else:
            # One shared key per block: broadcast it over the indexer heads.
            # The unconditional per-head reshape raised whenever n_ih > 1.
            k_idx = comp_indexer_kv.reshape(n_comp, 1, self.ihd).expand(-1, self.n_ih, -1)
        scores = torch.einsum('tnd,cnd->tnc', q_idx.float(), k_idx.float())
        scores = F.relu(scores)
        total = (scores * w_h.unsqueeze(-1).float()).sum(1)
        tk = min(self.top_k, n_comp)
        _, idx = total.topk(tk, -1)
        return idx
===================================================================== +# KV Cache +# ===================================================================== +class KVCache: + def __init__(self, head_dim, window_size=128, device='cuda:0'): + self.hd, self.ws, self.dev = head_dim, window_size, device + self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device) + self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device) + self.swa_len, self.swa_head = 0, 0 + self.comp_kv, self.comp_pos, self.n_comp = None, None, 0 + self.comp_idx_kv = None + + def append_swa(self, kv, pos): + T = kv.shape[0] + for i in range(T): + idx = (self.swa_head + i) % self.ws + self.swa[idx], self.swa_pos[idx] = kv[i], pos[i] + self.swa_head = (self.swa_head + T) % self.ws + self.swa_len = min(self.swa_len + T, self.ws) + + def add_compressed(self, ckv, cpos, idx_kv=None): + if ckv is None: return + self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv]) + self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos]) + self.n_comp = self.comp_kv.shape[0] + if idx_kv is not None: + self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv]) + + def get_swa(self): + if self.swa_len == 0: + return torch.zeros(0, self.hd, device=self.dev, dtype=torch.bfloat16), \ + torch.zeros(0, device=self.dev, dtype=torch.long) + if self.swa_len < self.ws: + return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone() + idx = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws + return self.swa[idx].clone(), self.swa_pos[idx].clone() + +# ===================================================================== +# RoPE apply (FP32 cache, partial, GPT-J interleaved) +# ===================================================================== +def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False): + T, nh, hd = x.shape + nope = hd - rope_dim + if pos.device != cos.device: + 
pos = pos.to(cos.device) + c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1) + xr = x[:, :, nope:].float() + ev, od = xr[..., 0::2], xr[..., 1::2] + if inverse: rev, rod = ev*c + od*s, -ev*s + od*c + else: rev, rod = ev*c - od*s, ev*s + od*c + out = x.clone() + ro = torch.empty_like(xr) + ro[..., 0::2], ro[..., 1::2] = rev, rod + out[:, :, nope:] = ro.bfloat16() + return out + +# ===================================================================== +# Production FMHA — 6-warp TMA multi-tile kernel +# ===================================================================== +def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx): + """Run production FMHA kernel via dsv4_attention. + + q_heads: (T, n_h, hd), all_kv: (seq_len, hd) + Returns: (T, n_h, hd) BF16 + """ + from dsv4.kernels.attention.production import dsv4_attention + + # Reshape for kernel: q=(n_h, T, hd), k=(1, seq_len, hd), v same + q = q_heads.permute(1, 0, 2).contiguous() # (n_h, T, hd) + k = all_kv.unsqueeze(0).contiguous() # (1, seq_len, hd) — MQA + v = all_kv.unsqueeze(0).contiguous() + + # Sinks: per-head logit bias + sinks = w.get(f"{pfx}.sinks") + sink_bias = None + if sinks is not None: + sink_bias = sinks.to(device=dev).float().reshape(n_h) + + attn_out = dsv4_attention( + q=q, k=k, v=v, scale=scale, + n_comp=0, # compressed KV already concatenated in all_kv + sink_bias=sink_bias, + ) # (n_h, T, hd) + return attn_out.permute(1, 0, 2) # (T, n_h, hd) + + +# ===================================================================== +# Attention forward — uses production FMHA kernel +# ===================================================================== +def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, + kv_cache, positions, compressor, indexer): + dev = x_normed.device + T = x_normed.shape[0] + n_h = cfg["num_attention_heads"] + hd = cfg["head_dim"] + rd = cfg.get("qk_rope_head_dim", 64) + o_groups = cfg.get("o_groups", 16) + o_rank = cfg.get("o_lora_rank", 
1024) + ratio = compressor.ratio if compressor is not None else 0 + scale = 1.0 / math.sqrt(hd) + pfx = f"model.layers.{li}.self_attn" + if positions.device != rope_cos.device: + positions = positions.to(rope_cos.device) + + # 1. Q projection: q_a → q_a_norm → q_b → q_b_norm + q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj') + if q_a is None: + log.warning(f" L{li}: q_a_proj not found") + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None + q_norm_w = w.get(f"{pfx}.q_a_norm.weight") + if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) + q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj') + q = unweighted_rmsnorm(q).bfloat16() + q_heads = q.reshape(T, n_h, hd) + q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) + + # 2. KV projection (MQA, single KV head, hd dim) + kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj') + if kv is None: + log.warning(f" L{li}: kv_proj not found") + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a + kv_norm_w = w.get(f"{pfx}.kv_norm.weight") + if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) + kv_3d = kv.reshape(T, 1, hd) + kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd) + kv_roped = kv_3d.reshape(T, hd) + kv_cache.append_swa(kv_roped, positions) + + # 3. Compressor → compressed KV + comp_kv, comp_pos, block_bias = None, None, None + comp_idx_kv = None + if compressor is not None and compressor.ratio > 0: + comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions) + if comp_kv is not None: + comp_kv_3d = comp_kv.unsqueeze(1) + comp_kv_3d = _apply_rope(comp_kv_3d, comp_pos, rope_cos, rope_sin, rd) + comp_kv = comp_kv_3d.squeeze(1) + if compressor.is_csa and indexer is not None and indexer.compressor is not None: + comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) + kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv) + + # 4. 
Indexer top-k (CSA only) + topk_idx = None + if indexer is not None and ratio == 4: + topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions) + + # 5. Gather full KV: [compressed, swa] + swa_kv, swa_pos = kv_cache.get_swa() + if kv_cache.comp_kv is not None and kv_cache.n_comp > 0: + if ratio == 4 and topk_idx is not None: + tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1) + sel_comp = kv_cache.comp_kv[tk] + all_kv = torch.cat([sel_comp, swa_kv], dim=0) + elif ratio > 4: + all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0) + else: + all_kv = swa_kv + else: + all_kv = swa_kv + + seq_len = all_kv.shape[0] + if seq_len == 0: + return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a + + # 6. Production FMHA kernel (6-warp TMA multi-tile) + attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx) + + # 7. Inverse RoPE (FP32 cache) + attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True) + + # 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4) + hpg = n_h // o_groups + gid = hpg * hd + oa_w = w.get(f"{pfx}.o_a_proj.weight") + if oa_w is not None: + oa_bf = oa_w.bfloat16().to(dev) + a_flat = attn_out.reshape(T, n_h * hd) + a_grp = a_flat.reshape(T, o_groups, gid) + oa_3d = oa_bf.reshape(o_groups, o_rank, gid) + g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2)) + g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) + F_attn = do_nvfp4_linear(g_flat, w, pfx, 'o_b_proj') + else: + F_attn = do_nvfp4_linear(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj') + return F_attn, q_a + +# ===================================================================== +# MoE forward — uses production Nvfp4MoE + Nvfp4SharedExpert kernels +# ===================================================================== +def moe_forward(x, w, li, cfg, token_id, device, moe_runner, se_runner, router): + """MoE forward using production NVFP4 GEMM kernels. 
+ + Router uses production dense/hash router kernels. + Expert GEMMs use CuTeDSL NVFP4 grouped GEMM (fused SwiGLU). + Shared expert uses CuTeDSL NVFP4 single-group GEMM. + No F.linear. No BF16 matmul. No PyTorch loops over experts. + """ + H = cfg["hidden_size"] + n_e = cfg["n_routed_experts"] + top_k = cfg.get("num_experts_per_tok", 6) + rsc = cfg.get("routed_scaling_factor", 2.5) + lim = cfg.get("swiglu_limit", 10.0) + num_hash = cfg.get("num_hash_layers", 3) + pfx = f"model.layers.{li}.mlp" + + # Production router: returns (topk_weights, topk_ids) via kernel + if router is not None: + try: + topk_w, topk_ids = router(x, token_ids=token_id) + # Production MoE kernel: NVFP4 grouped GEMM with fused SwiGLU + routed_out = moe_runner(x, topk_w, topk_ids) + # Production shared expert: NVFP4 single-group GEMM + shared_out = se_runner(x) + return routed_out + shared_out + except Exception as e: + log.warning(f" L{li}: Production MoE failed ({e}), falling back to reference") + # Fall through to reference path + + # Reference fallback (only if production kernels fail) + return _moe_forward_reference(x, w, li, cfg, token_id, device) + + +def _moe_forward_reference(x, w, li, cfg, token_id, device): + """Reference MoE using dequantized BF16 weights.""" + H = cfg["hidden_size"] + n_e = cfg["n_routed_experts"] + top_k = cfg.get("num_experts_per_tok", 6) + rsc = cfg.get("routed_scaling_factor", 2.5) + lim = cfg.get("swiglu_limit", 10.0) + num_hash = cfg.get("num_hash_layers", 3) + pfx = f"model.layers.{li}.mlp" + + tid2eid_key = f"{pfx}.gate.tid2eid" + e_bias_key = f"{pfx}.gate.e_score_correction_bias" + is_hash = (li < num_hash) and (tid2eid_key in w) + + if is_hash: + tid2eid = w[tid2eid_key] + tid = token_id.item() if token_id.numel() == 1 else token_id[0].item() + expert_ids = tid2eid[tid] + expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k + else: + gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate') + if gate_ww is not 
None and gate_ws is not None: + logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device), + gate_ws2.to(device) if gate_ws2 is not None else None, + gate_isc.to(device) if gate_isc is not None else None) + elif f"{pfx}.gate.weight" in w: + gw = w[f"{pfx}.gate.weight"].bfloat16().to(device) + logits = F.linear(x, gw) + else: + raise ValueError(f"No gate weight for layer {li}") + scores = torch.sqrt(F.softplus(logits.float()) + 1e-6) + sel = scores.clone() + if e_bias_key in w: + sel = sel + w[e_bias_key].to(device=x.device).float().unsqueeze(0) + _, indices = sel.topk(top_k, -1) + expert_weights = torch.gather(scores, -1, indices) + expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True) + expert_ids, expert_weights = indices[0], expert_weights[0] + + expert_outs = [] + for i, eid in enumerate(expert_ids): + ep = f"{pfx}.experts.{eid}" + g = do_nvfp4_linear(x, w, ep, 'gate_proj') + u = do_nvfp4_linear(x, w, ep, 'up_proj') + silu = F.silu(g.float()) + if lim is not None: silu = silu.clamp(-lim, lim); u = u.float().clamp(-lim, lim) + h = (silu * u).bfloat16() + expert_outs.append(do_nvfp4_linear(h, w, ep, 'down_proj')) + + routed = torch.zeros_like(x) + for out, wt in zip(expert_outs, expert_weights): + routed = routed + (out.float() * wt.item()).bfloat16() + routed = (routed.float() * rsc).bfloat16() + + sp = f"{pfx}.shared_experts" + sg = do_nvfp4_linear(x, w, sp, 'gate_proj') + su = do_nvfp4_linear(x, w, sp, 'up_proj') + silu = F.silu(sg.float()) + if lim is not None: silu = silu.clamp(-lim, lim); su = su.float().clamp(-lim, lim) + shared = do_nvfp4_linear((silu * su).bfloat16(), w, sp, 'down_proj') + return routed + shared + +# ===================================================================== +# Layer forward +# ===================================================================== +def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, + attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w, + kv_cache, positions, token_id, + compressor=None, 
indexer=None, + moe_runner=None, se_runner=None, router=None): + dev = X_l.device + # Attention sub-block + x_in, ctx_a = attn_mhc.pre_block(X_l) + x_normed = rmsnorm(x_in, attn_norm_w) + F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, + kv_cache, positions, compressor, indexer) + X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) + # FFN sub-block + x_in_f, ctx_f = ffn_mhc.pre_block(X_mid) + x_ffn = rmsnorm(x_in_f, ffn_norm_w) + F_ffn = moe_forward(x_ffn, w, li, cfg, token_id, dev, moe_runner, se_runner, router) + X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) + if VERBOSE >= 1: + print(f" L{li}: |X|={X_l.abs().max().item():.1f}→{X_next.abs().max().item():.1f} " + f"|Fa|={F_attn.abs().max().item():.1f} |Ff|={F_ffn.abs().max().item():.1f}", flush=True) + return X_next + +# ===================================================================== +# Main +# ===================================================================== +def main(): + t0 = time.time() + torch.manual_seed(SEED) + print("=" * 70) + print("DSV4 Single-Shot Inference — PRODUCTION KERNEL STACK") + print(" FMHA: 6-warp TMA multi-tile | Compressor + Indexer | mHC | MoE") + print(" NVFP4 GEMM (CuTeDSL) | Router kernels | NO PyTorch SDPA") + print("=" * 70) + + with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: + cfg = json.load(f) + n_layers = cfg["num_hidden_layers"] + H = cfg["hidden_size"] + hd = cfg["head_dim"] + rd = cfg.get("qk_rope_head_dim", 64) + cr = cfg.get("compress_ratios", [128] * n_layers) + print(f"Model: {n_layers} layers, {cfg['num_attention_heads']} heads, hd={hd}, rope_dim={rd}") + print(f"Compress ratios: first5={cr[:5]} len={len(cr)}") + print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") + + # Load weights + print(f"\nPhase 1: Loading weights...") + all_w = load_weights(CHECKPOINT_DIR) + print(f" {time.time()-t0:.1f}s") + + # Build production components + print("Building production components...") + from 
dsv4.layers.mhc import mHCLayer + from dsv4.layers.router import Router + from dsv4.layers.moe import Nvfp4MoE + from dsv4.layers.shared_expert import Nvfp4SharedExpert + + # mHC + norms + attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {} + for li in range(n_layers): + dev = f"cuda:{li % NUM_GPUS}" + for tag, blocks, fn_s, base_s, scale_s in [ + ("attn", attn_mhcs, f"model.layers.{li}.attn_hc.fn", + f"model.layers.{li}.attn_hc.base", f"model.layers.{li}.attn_hc.scale"), + ("ffn", ffn_mhcs, f"model.layers.{li}.ffn_hc.fn", + f"model.layers.{li}.ffn_hc.base", f"model.layers.{li}.ffn_hc.scale"), + ]: + fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s) + if fn is not None and base is not None and scale is not None: + m = mHCBlock(H, 4, 20, dev) + m.load(fn, base, scale) + blocks[li] = m + + an_k = f"model.layers.{li}.input_layernorm.weight" + if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32) + fn_k = f"model.layers.{li}.post_attention_layernorm.weight" + if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32) + + # Routers, MoE, shared experts + routers, moe_runners, se_runners = {}, {}, {} + for li in range(n_layers): + dev = f"cuda:{li % NUM_GPUS}" + pfx = f"model.layers.{li}.mlp" + is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w) + + # Router + router = Router( + hidden_size=H, num_experts=cfg["n_routed_experts"], + top_k=cfg.get("num_experts_per_tok", 6), + routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5), + mode="hash" if is_hash else "dense", + vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, + device=dev, + ) + if is_hash: + router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32)) + else: + gw = all_w.get(f"{pfx}.gate.weight") + eb = all_w.get(f"{pfx}.gate.e_score_correction_bias") + if gw is not None and eb is not None: + router.load_weights(W_gate=gw.bfloat16().to(dev), e_bias=eb.to(dev, torch.float32)) + 
router.finalize_weights() + routers[li] = router + + # MoE (production NVFP4 grouped GEMM) + moe = Nvfp4MoE( + num_experts=cfg["n_routed_experts"], hidden_size=H, + intermediate_size=cfg.get("moe_intermediate_size", 3072), + top_k=cfg.get("num_experts_per_tok", 6), device=dev, + ) + moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)) + + # Load expert weights (stacked path) + _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg) + moe_runners[li] = moe + + # Shared expert (production NVFP4 single-group GEMM) + se = Nvfp4SharedExpert( + hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), + device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0), + ) + _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg) + se_runners[li] = se + + if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} layers") + + # Global weights + torch.cuda.set_device(0) + embed_w = all_w.get("model.embed_tokens.weight") + embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0')) + lm_w = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0') + final_norm_w = all_w.get("model.norm.weight") + if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32) + + hc_head = HcHead(H, 4, 'cuda:0') + hc_fn = all_w.get("model.hc_head.hc_fn") + hc_base = all_w.get("model.hc_head.hc_base") + hc_scale = all_w.get("model.hc_head.hc_scale") + if hc_fn is not None and hc_base is not None: + hc_head.load(hc_fn, hc_base, hc_scale) + print(" hc_head loaded") + + # RoPE (FP32) + rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {})) + rt = rp.get("type", rp.get("rope_type", "yarn")) + rf = rp.get("factor", 16.0) + rtheta = cfg.get("rope_theta", 10000.) 
+ romax = rp.get("original_max_position_embeddings", 65536) + rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1) + rope_caches = {g: build_rope_cache(8192, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow) + for g in range(NUM_GPUS)} + + # KV caches, compressors, indexers + kv_caches, compressors, indexers = {}, {}, {} + n_ih = cfg.get("index_n_heads", 64) + ihd = cfg.get("index_head_dim", 128) + itk = cfg.get("index_topk", 1024) + for li in range(n_layers): + dev = f"cuda:{li % NUM_GPUS}" + ratio = cr[li] if li < len(cr) else 128 + kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), dev) + if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev) + if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) + + # Cache layer weights + print("Caching layer weights to GPUs...") + devs = [f"cuda:{g}" for g in range(NUM_GPUS)] + layer_w = cache_layer_weights(all_w, n_layers, devs) + del all_w; import gc; gc.collect() + print(f" {time.time()-t0:.1f}s") + + # Load compressor/indexer weights + for li in range(n_layers): + pfx = f"model.layers.{li}.self_attn.compressor" + if li in compressors: compressors[li].load(layer_w[li], pfx) + if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer") + print(" Compressors/indexers loaded") + + # Phase 2: Inference + print(f"\nPhase 2: Inference") + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) + + bos = tokenizer.bos_token_id or 0 + input_ids = [bos, USER_TOKEN] + input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False) + input_ids.append(ASSISTANT_TOKEN) + generated = input_ids.copy() + print(f"Input: {len(generated)} tokens") + + # Prefill + print(f"Prefilling {len(generated)} tokens...") + for pi, tid_val in enumerate(generated): + t1 = time.time() + tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') + pos = torch.tensor([pi], dtype=torch.long, device='cuda:0') + X = mHCBlock.init_state(embed(tid)) + for li 
in range(n_layers): + gpu = li % NUM_GPUS + if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") + torch.cuda.set_device(gpu) + X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], + attn_mhcs.get(li), ffn_mhcs.get(li), + attn_norms.get(li), ffn_norms.get(li), + kv_caches[li], pos, tid, + compressors.get(li), indexers.get(li), + moe_runners.get(li), se_runners.get(li), routers.get(li)) + X = X.to('cuda:0'); torch.cuda.set_device(0) + if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) + print(f" Prefill done ({time.time()-t0:.1f}s)") + + if _args.prefill_only: + print("Prefill-only mode, stopping.") + return + + # Decode + print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...") + all_tokens = generated.copy() + for step in range(MAX_NEW_TOKENS): + t1 = time.time() + tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') + dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0') + X = mHCBlock.init_state(embed(tid)) + for li in range(n_layers): + gpu = li % NUM_GPUS + if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") + torch.cuda.set_device(gpu) + X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu], + attn_mhcs.get(li), ffn_mhcs.get(li), + attn_norms.get(li), ffn_norms.get(li), + kv_caches[li], dec_pos, tid, + compressors.get(li), indexers.get(li), + moe_runners.get(li), se_runners.get(li), routers.get(li)) + X = X.to('cuda:0'); torch.cuda.set_device(0) + x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] + if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) + logits = F.linear(x_out, lm_w) + next_id = torch.argmax(logits, -1).item() + all_tokens.append(next_id) + dt = time.time() - t1 + has_nan = torch.isnan(logits.float()).any().item() + if step % 5 == 0 or has_nan: + tv, ti = torch.topk(logits[0], 5) + top5 = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})' + for t, v in zip(ti[:5], tv[:5])) + 
# =====================================================================
# MoE weight loading helpers (stacked path for production GEMM)
# =====================================================================

# Fallback global scale used whenever a projection has no ``input_scale``
# tensor in the checkpoint.
# NOTE(review): presumably 6 is the FP4 max magnitude and 448 the FP8-E4M3
# max -- confirm against the quantizer that produced the checkpoint.
_DEFAULT_GLOBAL_SCALE = 1.0 / (6.0 * 448.0)


def _fetch_nvfp4_proj(all_w, base, dev, with_input_scale):
    """Fetch one projection's tensors from ``all_w`` and move them to ``dev``.

    Args:
        all_w: Mapping from checkpoint key to tensor.
        base: Key prefix of the projection; ``.weight``, ``.weight_scale`` and
            ``.input_scale`` are appended to it.
        dev: Target device passed to ``Tensor.to``.
        with_input_scale: When true, also resolve the global scale from
            ``{base}.input_scale`` (falling back to ``_DEFAULT_GLOBAL_SCALE``).
            When false, ``input_scale`` is not read at all.

    Returns:
        ``(weight, weight_scale, global_scale)`` with both tensors on ``dev``,
        or ``None`` if the weight or its scale is missing. ``global_scale`` is
        ``None`` when ``with_input_scale`` is false.
    """
    w = all_w.get(f"{base}.weight")
    ws = all_w.get(f"{base}.weight_scale")
    if w is None or ws is None:
        return None
    gs = None
    if with_input_scale:
        isc = all_w.get(f"{base}.input_scale")
        gs = isc.float().item() if isc is not None else _DEFAULT_GLOBAL_SCALE
    return w.to(dev), ws.to(dev), gs


def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Load MoE expert weights into Nvfp4MoE via stacked path.

    For every expert ``0 .. cfg["n_routed_experts"] - 1`` the gate, up and
    down projections are looked up under ``{pfx}.experts.{eid}.{proj}``.
    Gate and up are concatenated along dim 0 per expert (gate first) and the
    per-expert results are stacked along a new leading dim; down projections
    are stacked directly. The result is passed to
    ``moe.prepare_weights_from_stacked``.

    Args:
        all_w: Mapping from checkpoint key to tensor.
        li: Layer index; only used in log messages.
        pfx: Key prefix of the MoE module.
        dev: Device the expert tensors are moved to.
        moe: Object exposing ``prepare_weights_from_stacked(l1, l1_sf, l1_gs,
            l2, l2_sf, l2_gs)``.
        cfg: Model config; ``n_routed_experts`` is required.

    Returns:
        None. If gate/up weights are absent, or any projection is present for
        only some of the experts, a warning is logged and ``moe`` is left
        untouched.
    """
    n_e = cfg["n_routed_experts"]

    # The fused L1 GEMM takes its global scale from gate_proj only; the
    # up_proj input_scale is never read.
    plan = (("gate_proj", True), ("up_proj", False), ("down_proj", True))
    found = {proj: [] for proj, _ in plan}
    for eid in range(n_e):
        for proj, with_scale in plan:
            item = _fetch_nvfp4_proj(all_w, f"{pfx}.experts.{eid}.{proj}", dev, with_scale)
            if item is not None:
                found[proj].append(item)

    # A projection that exists for only some experts cannot be stacked safely:
    # the surviving entries would shift to lower indices (and zip() below would
    # silently truncate), so expert k's weights would be used for another id.
    partial = [proj for proj, _ in plan if 0 < len(found[proj]) < n_e]
    if partial:
        log.warning(f"L{li}: MoE weights incomplete for {', '.join(partial)} "
                    f"(expected {n_e} experts each); skipping to avoid misaligned expert indices")
        return

    gate, up, down = found["gate_proj"], found["up_proj"], found["down_proj"]
    if not (gate and up):
        log.warning(f"L{li}: MoE weight stacking failed")
        return

    l1_stacked = torch.stack([torch.cat([g[0], u[0]], dim=0) for g, u in zip(gate, up)])
    l1_sf_stacked = torch.stack([torch.cat([g[1], u[1]], dim=0) for g, u in zip(gate, up)])
    l1_gs = [g[2] for g in gate]

    if down:
        l2_stacked = torch.stack([d[0] for d in down])
        l2_sf_stacked = torch.stack([d[1] for d in down])
        l2_gs = [d[2] for d in down]
    else:
        # No down_proj tensors at all: hand over None plus default scales.
        l2_stacked = None
        l2_sf_stacked = None
        l2_gs = [_DEFAULT_GLOBAL_SCALE] * n_e

    moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs,
                                     l2_stacked, l2_sf_stacked, l2_gs)


def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Load shared expert weights.

    Looks up gate/up/down projections under ``{pfx}.shared_experts.{proj}``.
    When both gate and up are present, their weights and weight scales are
    concatenated along dim 0 (gate first) and stored as one-element lists on
    ``se.l1_fp4`` / ``se.l1_sf`` with gate's global scale in ``se.l1_gs``.
    When down is present it is stored on ``se.l2_fp4`` / ``se.l2_sf`` /
    ``se.l2_gs``. ``se.finalize_weights()`` is always called afterwards.

    Args:
        all_w: Mapping from checkpoint key to tensor.
        li: Layer index; only used in log messages.
        pfx: Key prefix of the MoE module.
        dev: Device the tensors are moved to.
        se: Shared-expert runner receiving the attributes listed above.
        cfg: Unused; kept so the signature matches the MoE loader.
    """
    gate = _fetch_nvfp4_proj(all_w, f"{pfx}.shared_experts.gate_proj", dev, True)
    up = _fetch_nvfp4_proj(all_w, f"{pfx}.shared_experts.up_proj", dev, False)
    down = _fetch_nvfp4_proj(all_w, f"{pfx}.shared_experts.down_proj", dev, True)

    if gate is not None and up is not None:
        se.l1_fp4 = [torch.cat([gate[0], up[0]], dim=0)]
        se.l1_sf = [torch.cat([gate[1], up[1]], dim=0)]
        se.l1_gs = [gate[2]]
    else:
        log.warning(f"L{li}: shared expert gate/up weights missing; L1 not set")

    if down is not None:
        se.l2_fp4 = [down[0]]
        se.l2_sf = [down[1]]
        se.l2_gs = [down[2]]
    else:
        log.warning(f"L{li}: shared expert down_proj weights missing; L2 not set")

    se.finalize_weights()


def cache_layer_weights(all_w, n_layers, devices):
    """Copy every per-layer tensor to the device that owns its layer.

    Keys of the form ``model.layers.<li>.<rest>`` with ``0 <= li < n_layers``
    are grouped by layer; layer ``li`` is placed on
    ``devices[li % len(devices)]``. All other keys are ignored.

    Args:
        all_w: Mapping from checkpoint key to tensor.
        n_layers: Number of layers to cache.
        devices: Non-empty sequence of devices (required when ``n_layers > 0``).

    Returns:
        ``{li: {key: tensor_on_device}}`` with an entry (possibly empty) for
        every layer index in ``range(n_layers)``.

    Raises:
        ValueError: If ``n_layers > 0`` and ``devices`` is empty.
    """
    if n_layers > 0 and not devices:
        raise ValueError("cache_layer_weights: 'devices' must not be empty")

    prefix = "model.layers."
    # Map the textual layer index exactly as it appears in the key to the int
    # index, so one pass over all_w replaces a full key scan per layer.
    layer_of = {str(li): li for li in range(n_layers)}
    cached = {li: {} for li in range(n_layers)}
    for key, tensor in all_w.items():
        if not key.startswith(prefix):
            continue
        idx_text, dot, _ = key[len(prefix):].partition(".")
        li = layer_of.get(idx_text) if dot else None
        if li is None:
            continue
        # Copies are requested as non-blocking; no synchronization happens here.
        cached[li][key] = tensor.to(device=devices[li % len(devices)], non_blocking=True)
    return cached


def load_weights(checkpoint_dir):
    """Load all safetensors shards listed in the checkpoint's index file.

    Reads ``model.safetensors.index.json`` in ``checkpoint_dir``, collects the
    distinct shard file names from its ``weight_map`` and merges the tensors
    of every shard that exists on disk, in sorted shard-name order, into one
    dict.

    Args:
        checkpoint_dir: Directory containing the index file and the shards.

    Returns:
        Dict mapping tensor name to tensor. Empty if the index file is missing
        or lists no shards; a warning is logged in that case and for every
        listed shard that is not found.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    idx = cdir / "model.safetensors.index.json"
    shards = set()
    if idx.exists():
        with open(idx, encoding="utf-8") as f:
            shards = set(json.load(f).get("weight_map", {}).values())
        if not shards:
            log.warning(f"{idx} has an empty weight_map; no weights loaded")
    else:
        log.warning(f"{idx} not found; no weights loaded")

    all_w = {}
    for sn in sorted(shards):
        shard_path = cdir / sn
        if not shard_path.exists():
            log.warning(f"Shard {shard_path} is listed in the index but missing; skipped")
            continue
        all_w.update(load_file(str(shard_path)))
    return all_w


if __name__ == "__main__":
    main()