Root cause: the checkpoint's `input_scale` was calibrated for training-time FP8 quantization, not for NVFP4 activation quantization. When it is used as the activation global scale `gsa`, the per-16-element block scale `amax_block / (6 · gsa)` can exceed the E4M3 maximum of 448 (equivalently, `|x| / gsa` exceeds 6 · 448 = 2688). The saturated block scales clip activations, causing a systematic magnitude loss in every projection; the error compounds over 61 layers, compresses the logit range, and produces garbage tokens. Fix: compute `gsa` at runtime from the actual activation magnitude, `gsa = max(|x|) / (6.0 · 448.0)`. This guarantees `|x| / gsa ≤ 2688`, the largest magnitude representable as an FP4 (E2M1, max 6) value times an E4M3 block scale (max 448). Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate.
File: 922 lines, 50 KiB, Python.
#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference: the full production pipeline on 8 GPUs.

Kernel coverage:
  * every projection runs on the production NVFP4 GEMM kernels (CuTeDSL);
  * every attention call runs on the production FMHA
    (6-warp TMA multi-tile, with sink bias);
  * every MoE layer runs on the production Nvfp4MoE, Nvfp4SharedExpert
    and Router.

There is no PyTorch SDPA fallback and no dequant+matmul path for the
production projections. This script is the ground truth for the
vLLM / SGLang integration.
"""
import argparse
import json
import logging
import math
import os
import sys
import time
from pathlib import Path

import torch
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("single_shot")
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the single-shot runner.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) reads
            ``sys.argv[1:]`` exactly as before; passing a list lets tests and
            wrapper scripts drive the parser without touching ``sys.argv``.

    Returns:
        argparse.Namespace with ``max_tokens``, ``prompt``, ``seed``,
        ``verbose``, ``prefill_only``, ``num_gpus``, ``checkpoint`` and
        ``prefill_tokens``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--max-tokens', type=int, default=8192,
                   help='Maximum number of new tokens to generate')
    p.add_argument('--prompt', type=str, default=None,
                   help='Prompt text (a built-in default is used when omitted)')
    p.add_argument('--seed', type=int, default=42, help='Torch RNG seed')
    p.add_argument('--verbose', type=int, default=1,
                   help='Diagnostic level: 0 = quiet, 1 = per-layer, 2 = kernel checks')
    p.add_argument('--prefill-only', action='store_true', help='Stop after the prefill pass')
    p.add_argument('--num-gpus', type=int, default=8, help='Number of GPUs to spread layers over')
    p.add_argument('--checkpoint', type=str, default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4",
                   help='Checkpoint directory (config.json + safetensors shards)')
    p.add_argument('--prefill-tokens', type=str, default=None,
                   help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")')
    # argv=None makes argparse fall back to sys.argv[1:], preserving old behavior.
    return p.parse_args(argv)
|
||
|
||
# The CLI is parsed once at import time; the constants below are read by the
# functions in this file.
_args = parse_args()

CHECKPOINT_DIR = _args.checkpoint  # directory holding config.json and the shards
MAX_NEW_TOKENS = _args.max_tokens
PROMPT = _args.prompt if _args.prompt else "The capital of France is"
NUM_GPUS, SEED, VERBOSE = _args.num_gpus, _args.seed, _args.verbose

# Special token ids (NOTE(review): presumably from the DSV4 tokenizer — confirm).
THINK_START, THINK_END = 128821, 128822
USER_TOKEN, ASSISTANT_TOKEN = 128803, 128804

# Magnitudes of the 8 non-negative FP4 (E2M1) codes, indexed by the low three
# bits of a nibble; dequant_nvfp4 applies bit 3 as the sign.
FP4_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
|
||
|
||
# =====================================================================
|
||
# RoPE (FP32 — BF16 destroys cos²+sin²=1)
|
||
# =====================================================================
|
||
def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default",
                     rope_factor=1., orig_max=4096, beta_fast=32, beta_slow=1):
    """Precompute FP32 RoPE cos/sin tables.

    Base frequencies are ``theta ** (-2i / rope_dim)`` for pair index ``i``.
    When ``rope_type == "yarn"`` and ``rope_factor > 1`` they are rescaled with
    YaRN's NTK-by-parts rule (as in the DeepSeek-V3 reference):

      * pairs whose wavelength completes more than ``beta_fast`` rotations over
        ``orig_max`` positions keep their frequency;
      * pairs completing fewer than ``beta_slow`` rotations are divided by
        ``rope_factor``;
      * pairs in between are blended with a linear ramp over the pair index.

    The blend is continuous and keeps every frequency positive.

    Returns:
        ``(cos, sin)``, each float32 of shape ``(max_pos, rope_dim // 2)`` on
        ``device``. FP32 is kept because BF16 breaks cos^2 + sin^2 = 1.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim))
    if rope_type == "yarn" and rope_factor > 1.0:
        def correction_dim(num_rotations):
            # Pair index whose wavelength equals orig_max / num_rotations.
            return rope_dim * math.log(orig_max / (num_rotations * 2 * math.pi)) / (2 * math.log(theta))

        low = max(math.floor(correction_dim(beta_fast)), 0)
        high = min(math.ceil(correction_dim(beta_slow)), rope_dim - 1)
        if low == high:
            high += 0.001  # avoid a zero-width ramp
        ramp = ((torch.arange(rope_dim // 2, dtype=torch.float32) - low) / (high - low)).clamp(0, 1)
        smooth = 1 - ramp  # 1 -> keep frequency, 0 -> fully interpolated
        freqs = freqs / rope_factor * (1 - smooth) + freqs * smooth
    angles = torch.outer(torch.arange(max_pos, dtype=torch.float32), freqs)
    return torch.cos(angles).to(device), torch.sin(angles).to(device)
|
||
|
||
def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False):
|
||
T, nh, hd = x.shape; nope = hd - rope_dim
|
||
if pos.device != cos.device: pos = pos.to(cos.device)
|
||
c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1)
|
||
xr = x[:, :, nope:].float(); ev, od = xr[..., 0::2], xr[..., 1::2]
|
||
if inverse: rev, rod = ev*c + od*s, -ev*s + od*c
|
||
else: rev, rod = ev*c - od*s, ev*s + od*c
|
||
out = x.clone(); ro = torch.empty_like(xr)
|
||
ro[..., 0::2], ro[..., 1::2] = rev, rod
|
||
out[:, :, nope:] = ro.bfloat16(); return out
|
||
|
||
# =====================================================================
|
||
# Weight loading
|
||
# =====================================================================
|
||
def load_all_weights(checkpoint_dir):
    """Load every tensor of a safetensors checkpoint into one dict.

    Shard names come from ``model.safetensors.index.json`` when it exists.
    Without an index (e.g. single-file checkpoints), every ``*.safetensors``
    file in the directory is loaded instead of silently loading nothing.
    Shards listed in the index but missing on disk are skipped with a warning.

    Returns:
        dict mapping tensor name -> tensor, as returned by safetensors'
        ``load_file``.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path) as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_names = sorted(set(weight_map.values()))
    else:
        shard_names = sorted(p.name for p in cdir.glob("*.safetensors"))
        log.info(f"No safetensors index in {cdir}; loading {len(shard_names)} shard file(s) directly")

    all_w = {}
    for name in shard_names:
        shard = cdir / name
        if not shard.exists():
            log.warning(f"Shard {name} listed in the index is missing under {cdir}; skipping")
            continue
        all_w.update(load_file(str(shard)))
    log.info(f"Loaded {len(all_w)} tensors from {len(shard_names)} shards")
    return all_w
|
||
|
||
# =====================================================================
|
||
# RMSNorm
|
||
# =====================================================================
|
||
def rmsnorm(x, weight, eps=1e-6):
    """RMS-normalise ``x`` over its last dim and scale by ``weight``.

    Computation is in FP32; the result is cast to BF16.
    """
    x32 = x.float()
    inv_rms = x32.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    scaled = x32 * inv_rms * weight.float()
    return scaled.bfloat16()
|
||
|
||
def unweighted_rmsnorm(x, eps=1e-6):
    """RMS-normalise ``x`` over its last dim without a learned scale.

    Unlike ``rmsnorm`` the result stays FP32 (no BF16 cast); callers cast.
    """
    x32 = x.float()
    inv_rms = x32.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x32 * inv_rms
|
||
|
||
# =====================================================================
|
||
# NVFP4 ref dequant — compressor/indexer ONLY
|
||
# =====================================================================
|
||
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Reference NVFP4 -> BF16 weight dequantisation (compressor/indexer checks only).

    weight:         (O, I/2) packed bytes holding two FP4 (E2M1) codes each;
                    the low nibble is element 2j, the high nibble element 2j+1.
    weight_scale:   per-16-element block scales, expanded 16x along dim 1 so
                    they line up with the unpacked (O, I) values.
    weight_scale_2: optional per-tensor factor multiplied into every block scale.
    input_scale:    accepted for signature symmetry; not used.
    """
    rows, packed_cols = weight.shape
    cols = packed_cols * 2
    table = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def decode(nibbles):
        # bits 0-2 index the magnitude table; bit 3 is the sign
        magnitude = table[(nibbles & 0x07).long()]
        return magnitude * torch.where((nibbles >> 3).bool(), -1., 1.)

    low = (weight & 0x0F).to(torch.int8)
    high = (weight >> 4).to(torch.int8)
    values = torch.stack([decode(low), decode(high)], dim=-1).reshape(rows, cols)

    scales = weight_scale.float().repeat_interleave(16, dim=1)
    if weight_scale_2 is not None:
        scales = scales * weight_scale_2.float()
    return (values * scales).bfloat16()
|
||
|
||
def nvfp4_linear_ref(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Reference NVFP4 linear: dequantise the weight to BF16, then ``F.linear`` (no bias)."""
    dense_weight = dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)
    return F.linear(x, dense_weight)
|
||
|
||
def get_nvfp4_weight(w, pfx, proj_name):
    """Fetch the NVFP4 tensors of ``{pfx}.{proj_name}`` from the weight dict ``w``.

    Returns a 4-tuple ``(weight, weight_scale, weight_scale_2, input_scale)``;
    any entry missing from ``w`` is ``None``.
    """
    base = f"{pfx}.{proj_name}"
    suffixes = ("weight", "weight_scale", "weight_scale_2", "input_scale")
    return tuple(w.get(f"{base}.{suffix}") for suffix in suffixes)
|
||
|
||
def do_nvfp4_linear_ref(x, w, pfx, proj_name):
    """Run the reference NVFP4 linear for ``{pfx}.{proj_name}`` on ``x``'s device.

    Returns ``None`` when the packed weight is absent from ``w``.
    """
    weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name)
    if weight is None:
        return None
    target = x.device

    def moved(t):
        return None if t is None else t.to(target)

    return nvfp4_linear_ref(x, weight.to(target), ws.to(target), moved(ws2), moved(isc))
|
||
|
||
# =====================================================================
|
||
# Production Nvfp4Linear factory
|
||
# =====================================================================
|
||
def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name,
                      max_num_tokens=8192):
    """Build a production ``Nvfp4Linear`` for ``{pfx}.{proj_name}``.

    ``in_features`` / ``out_features`` are accepted for call-site readability,
    but the layer is sized from the packed checkpoint tensor:
    out = weight.shape[0], in = weight.shape[1] * 2 (two FP4 codes per byte).

    Args:
        max_num_tokens: largest token count one forward call may receive
            (Nvfp4Linear buffer sizing). Defaults to the previously hard-coded
            8192; raise it for longer prefills.

    Raises:
        KeyError: if the packed weight or its block scales are missing.
    """
    from dsv4.layers.linear import Nvfp4Linear

    weight, ws, ws2, _isc = get_nvfp4_weight(all_w, pfx, proj_name)
    if weight is None:
        raise KeyError(f"{pfx}.{proj_name}.weight not found")
    if ws is None:
        raise KeyError(f"{pfx}.{proj_name}.weight_scale not found")

    actual_out = weight.shape[0]     # N = GEMM output dimension
    actual_in = weight.shape[1] * 2  # K = unpacked BF16 input dimension
    lin = Nvfp4Linear(actual_in, actual_out, max_num_tokens=max_num_tokens, device=device)
    lin.fp4 = [weight.to(device)]
    lin.sf = [ws.to(device)]
    lin.gs = [1.0]  # base global scale; finalize_weights multiplies in ws2
    lin.ws2 = [ws2.to(device) if ws2 is not None else None]

    # The checkpoint input_scale (_isc) is deliberately NOT used as the activation
    # global scale: it was calibrated for training-time FP8, and using it lets
    # x/gsa exceed 6 * 448 = 2688, saturating the E4M3 block scales. The value
    # below is a placeholder; _use_runtime_gsa asks the layer to derive
    # gsa = max|x| / (6 * 448) in the forward pass.
    # NOTE(review): relies on Nvfp4Linear honouring this flag — confirm in
    # dsv4.layers.linear.
    lin._activation_global_scale = 1.0 / (6.0 * 448.0)
    lin._use_runtime_gsa = True
    lin.finalize_weights()
    return lin
|
||
|
||
# =====================================================================
|
||
# Compressor — CSA (ratio=4) and HCA (ratio=128) [PRODUCTION KERNELS]
|
||
# =====================================================================
|
||
class Compressor:
    """Production KV compressor: NVFP4 GEMM projections + CUDA softmax/reduce.

    Pipeline:
      1. NVFP4 GEMM: hidden_states @ kv_proj   -> (T, kv_dim) BF16
      2. NVFP4 GEMM: hidden_states @ gate_proj -> (T, kv_dim) BF16
      3. CUDA kernel (csa_compress_production / hca_compress_production):
         token-level softmax + weighted sum + kv_norm per block of ``ratio``
         tokens.

    ``ratio == 4`` selects CSA (kv_dim = 2 * head_dim); any other ratio selects
    HCA (kv_dim = head_dim); ``ratio == 0`` disables compression.
    No PyTorch softmax, no reference fallback.
    """

    def __init__(self, ratio, head_dim, hidden_size, device):
        self.ratio, self.hd, self.H, self.device = ratio, head_dim, hidden_size, device
        self.is_csa = ratio == 4
        self.kv_dim = 2 * head_dim if self.is_csa else head_dim
        self.kv_lin = None    # production Nvfp4Linear for kv_proj (set by load)
        self.gate_lin = None  # production Nvfp4Linear for gate_proj (set by load)
        self.ape = None       # position bias, passed to the compress kernel
        self.kv_norm_w = None
        self._reduce_loaded = False

    def load(self, w, pfx, dev=None):
        """Build the kv/gate Nvfp4Linear layers and fetch position bias / kv_norm from ``w``."""
        if dev is None:
            dev = self.device
        for proj_name, attr in (("kv_proj", "kv_lin"), ("gate_proj", "gate_lin")):
            weight = get_nvfp4_weight(w, pfx, proj_name)[0]
            if weight is not None:
                # out = N_packed, in = K_packed * 2
                setattr(self, attr, make_nvfp4_linear(weight.shape[1] * 2, weight.shape[0],
                                                      dev, w, pfx, proj_name))
        self.ape = w.get(f"{pfx}.position_bias")
        self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")

    @staticmethod
    def _block_end_positions(positions, n_complete, ratio, device):
        """Position of the last token of each complete block; 0 when beyond ``positions``.

        Vectorised: one gather instead of one ``.item()`` host sync per block.
        Returns a (n_complete,) int64 tensor on ``device``.
        """
        flat = positions.reshape(-1).to(device=device, dtype=torch.long)
        ends = torch.arange(1, n_complete + 1, device=device, dtype=torch.long) * ratio - 1
        if flat.numel() == 0:
            return torch.zeros(n_complete, dtype=torch.long, device=device)
        picked = flat[ends.clamp(max=flat.numel() - 1)]
        return torch.where(ends < flat.numel(), picked, torch.zeros_like(picked))

    def forward(self, hidden_states, positions):
        """Compress every complete block of ``ratio`` tokens in ``hidden_states``.

        Returns ``(compressed, comp_pos, block_bias)``, or ``(None, None, None)``
        when compression is disabled, kv_proj is missing, there is no complete
        block, or the kernel returns no rows. ``comp_pos`` holds the position of
        each block's last token; ``block_bias`` is an all-zero
        (1, T, n_blocks) FP32 tensor.

        NOTE(review): only the T // ratio complete blocks of *this call* are
        compressed; trailing tokens are not carried to the next call, so a
        single-token decode step never adds compressed entries. Confirm intent.
        """
        if self.ratio == 0 or self.kv_lin is None:
            return None, None, None
        T = hidden_states.shape[0]
        r = self.ratio
        dev = hidden_states.device
        n_complete = T // r
        if n_complete == 0:
            return None, None, None

        # Steps 1-2: NVFP4 GEMM projections -> BF16, cast to FP32 for the reduce.
        kv = self.kv_lin(hidden_states).float()
        gate = self.gate_lin(hidden_states).float()

        # Step 3: CUDA softmax/reduce kernel (position bias is added inside it).
        from dsv4.kernels.compressor.production_compress import (
            csa_compress_production, hca_compress_production)
        compress = csa_compress_production if self.is_csa else hca_compress_production
        compressed = compress(kv, gate, self.ape, self.kv_norm_w, m=r)
        if compressed.shape[0] == 0:
            return None, None, None

        comp_pos = self._block_end_positions(positions, n_complete, r, dev)
        return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
|
||
|
||
# =====================================================================
|
||
# Indexer — CSA top-k [PRODUCTION NVFP4 GEMMs]
|
||
# =====================================================================
|
||
class Indexer:
    """Production CSA indexer: NVFP4 GEMM projections + ReLU score + top-k.

    Pipeline:
      1. NVFP4 GEMM: q_a (lora) @ q_b_proj     -> (T, n_ih * ihd)
      2. NVFP4 GEMM: hidden_states @ weights_proj -> (T, n_ih) per-head weights
      3. score[t, c] = sum_h w[t, h] * ReLU(q[t, h] . k[c, h]); top-k over c.

    No mask over compressed blocks is applied here.
    """

    def __init__(self, n_ih, ihd, top_k, device):
        self.n_ih = n_ih
        self.ihd = ihd
        self.top_k = top_k
        self.device = device
        self.q_b_lin = None     # production Nvfp4Linear for q_b_proj
        self.wp_lin = None      # production Nvfp4Linear for weights_proj
        self.compressor = None  # CSA (ratio 4) compressor producing indexer keys

    def load(self, w, pfx, dev=None):
        """Build the q_b / weights projections and, if present, the indexer compressor."""
        dev = self.device if dev is None else dev
        qb_weight = get_nvfp4_weight(w, pfx, 'q_b_proj')[0]
        wp_weight = get_nvfp4_weight(w, pfx, 'weights_proj')[0]
        if qb_weight is not None:
            self.q_b_lin = make_nvfp4_linear(qb_weight.shape[1] * 2, qb_weight.shape[0],
                                             dev, w, pfx, 'q_b_proj')
        if wp_weight is not None:
            self.wp_lin = make_nvfp4_linear(wp_weight.shape[1] * 2, wp_weight.shape[0],
                                            dev, w, pfx, 'weights_proj')
        if f"{pfx}.compressor.kv_proj.weight" in w:
            self.compressor = Compressor(4, self.ihd, 7168, dev)
            self.compressor.load(w, f"{pfx}.compressor", dev)

    def forward(self, q_lora, hidden_states, comp_indexer_kv, positions):
        """Return (T, k) indices of the top-k compressed blocks per token, or None.

        None is returned when q_b_proj is missing or there are no compressed keys.
        """
        if self.q_b_lin is None or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0:
            return None
        n_tokens = q_lora.shape[0]
        n_blocks = comp_indexer_kv.shape[0]
        queries = self.q_b_lin(q_lora).reshape(n_tokens, self.n_ih, self.ihd)
        head_weights = self.wp_lin(hidden_states)  # (T, n_ih)
        keys = comp_indexer_kv.reshape(n_blocks, self.n_ih, self.ihd)
        per_head = F.relu(torch.einsum('tnd,cnd->tnc', queries.float(), keys.float()))
        total = (per_head * head_weights.float().unsqueeze(-1)).sum(dim=1)
        _, idx = total.topk(min(self.top_k, n_blocks), dim=-1)
        return idx
|
||
|
||
# =====================================================================
|
||
# KV Cache
|
||
# =====================================================================
|
||
class KVCache:
    """Per-layer KV state.

    Holds a ring buffer of the last ``window_size`` roped KV rows (sliding
    window) plus append-only compressed KV (and optional indexer keys).
    """

    def __init__(self, head_dim, window_size=128, device='cuda:0'):
        self.hd, self.ws, self.dev = head_dim, window_size, device
        self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
        self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
        self.swa_len, self.swa_head = 0, 0  # valid rows / next slot to write
        self.comp_kv, self.comp_pos, self.n_comp = None, None, 0
        self.comp_idx_kv = None

    def append_swa(self, kv, pos):
        """Write the T rows of ``kv`` and their positions into the ring buffer.

        Same final state as writing the rows one by one, but done with a single
        indexed copy: when T exceeds the window only the last ``window_size``
        rows can survive, so only those are written. Inputs are cast to the
        buffer's dtype/device explicitly (indexed assignment does not cast).
        """
        T = kv.shape[0]
        if T == 0:
            return
        keep = min(T, self.ws)
        start = (self.swa_head + T - keep) % self.ws
        slots = (torch.arange(keep, device=self.swa.device) + start) % self.ws
        self.swa[slots] = kv[T - keep:T].to(device=self.swa.device, dtype=self.swa.dtype)
        self.swa_pos[slots] = pos[T - keep:T].to(device=self.swa_pos.device, dtype=self.swa_pos.dtype)
        self.swa_head = (self.swa_head + T) % self.ws
        self.swa_len = min(self.swa_len + T, self.ws)

    def add_compressed(self, ckv, cpos, idx_kv=None):
        """Append compressed KV rows (and optional indexer keys); no-op when ``ckv`` is None."""
        if ckv is None:
            return
        self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv])
        self.comp_pos = cpos if self.comp_pos is None else torch.cat([self.comp_pos, cpos])
        self.n_comp = self.comp_kv.shape[0]
        if idx_kv is not None:
            self.comp_idx_kv = idx_kv if self.comp_idx_kv is None else torch.cat([self.comp_idx_kv, idx_kv])

    def get_swa(self):
        """Return copies of the buffered window ``(kv, pos)``, oldest row first."""
        if self.swa_len == 0:
            return (torch.zeros(0, self.hd, device=self.dev, dtype=torch.bfloat16),
                    torch.zeros(0, device=self.dev, dtype=torch.long))
        if self.swa_len < self.ws:
            # Buffer has not wrapped yet: rows 0..swa_len-1 are in write order.
            return self.swa[:self.swa_len].clone(), self.swa_pos[:self.swa_len].clone()
        # Full buffer: swa_head points at the oldest slot.
        order = torch.arange(self.swa_head, self.swa_head + self.ws) % self.ws
        return self.swa[order].clone(), self.swa_pos[order].clone()
|
||
|
||
# =====================================================================
|
||
# HcHead
|
||
# =====================================================================
|
||
HC_EPS = 1e-6  # added to the sigmoid output of the stream weights


class HcHead:
    """Collapse the ``n_hc`` hyper-connection streams of X into one hidden state.

    forward: X (T, n_hc, hidden_dim) -> (T, hidden_dim) BF16. Each stream is
    weighted per token by
    sigmoid(scale * (rmsnorm(X_flat) @ fn[:n_hc].T) + base[:n_hc]) + HC_EPS.
    ``load`` must be called before ``forward``.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.K = n_hc * hidden_dim  # flattened width of all streams
        self.device = device
        self.n_hc = n_hc

    def load(self, fn, base, scale=None):
        """Store fn/base as contiguous FP32 on the device; scale becomes a Python float (default 1.0)."""
        def as_f32(t):
            return t.to(self.device, torch.float32)

        self.fn = as_f32(fn).contiguous()
        self.base = as_f32(base).contiguous()
        self.scale = 1.0 if scale is None else as_f32(scale).item()

    def forward(self, X):
        n_tok = X.shape[0]
        flat = unweighted_rmsnorm(X.reshape(n_tok, self.K).bfloat16())
        logits = F.linear(flat, self.fn[:self.n_hc]).float()
        weights = torch.sigmoid(logits * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
        return (weights.unsqueeze(-1) * X.float()).sum(dim=1).bfloat16()
|
||
|
||
# =====================================================================
|
||
# Production FMHA
|
||
# =====================================================================
|
||
def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx):
    """Run the production FMHA kernel for one layer (MQA: one KV head for all heads).

    q_heads: (T, n_h, hd) queries; all_kv: (N, hd) rows used as both K and V.
    Optional per-head sink logits are read from ``{pfx}.sinks``.
    Returns the attention output as (T, n_h, hd).
    ``hd``, ``T``, ``seq_len`` and ``li`` are not used here.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    queries = q_heads.permute(1, 0, 2).contiguous()  # (n_h, T, hd): all heads in one launch
    keys = all_kv.unsqueeze(0).contiguous()          # (1, N, hd): single shared KV head
    values = keys.clone()                            # V has K's values in a separate buffer

    sinks = w.get(f"{pfx}.sinks")
    sink_bias = None if sinks is None else sinks.to(device=dev).float().reshape(n_h)

    out = dsv4_attention(q=queries, k=keys, v=values, scale=scale, n_comp=0, sink_bias=sink_bias)
    return out.permute(1, 0, 2)
|
||
|
||
# =====================================================================
|
||
# Attention — ALL production kernels
|
||
# =====================================================================
|
||
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
                      kv_cache, positions, compressor, indexer, prod_lin):
    """One DSV4 attention block on production kernels.

    Steps:
      1. Q: q_a GEMM -> q_a_norm -> q_b GEMM -> unweighted RMSNorm -> RoPE.
      2. KV (single MQA head): kv GEMM -> kv_norm -> RoPE -> sliding window.
      3. Compressor (CSA ratio 4 / HCA otherwise) appends roped compressed KV.
      4. CSA only: indexer top-k over the compressed indexer keys.
      5. Gather KV = [selected/all compressed | sliding window].
      6. Production FMHA.
      7. Inverse RoPE on the attention output.
      8. Output projection: o_a (grouped NVFP4) then o_b.

    Returns:
        ``(F_attn, q_a)`` where F_attn is (T, hidden_size) BF16.
    """
    dev = x_normed.device
    T = x_normed.shape[0]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    ratio = compressor.ratio if compressor is not None else 0
    scale = 1.0 / math.sqrt(hd)
    pfx = f"model.layers.{li}.self_attn"
    if positions.device != rope_cos.device:
        positions = positions.to(rope_cos.device)

    # 1. Q: q_a (NVFP4 GEMM) -> q_a_norm -> q_b (NVFP4 GEMM) -> q_b_norm
    q_a = prod_lin['q_a'](x_normed)
    if VERBOSE >= 2 and li < 3:
        # Compare q_a with the PyTorch dequant reference.
        q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
        if q_a_ref is not None:
            cos_qa = F.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
            print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
    q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
    if q_norm_w is not None:
        q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
    q = prod_lin['q_b'](q_a)
    q = unweighted_rmsnorm(q).bfloat16()
    q_heads = _apply_rope(q.reshape(T, n_h, hd), positions, rope_cos, rope_sin, rd)

    # 2. KV (NVFP4 GEMM, MQA, single KV head)
    kv = prod_lin['kv'](x_normed)
    kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
    kv_3d = _apply_rope(kv.reshape(T, 1, hd), positions, rope_cos, rope_sin, rd)
    kv_roped = kv_3d.reshape(T, hd)
    kv_cache.append_swa(kv_roped, positions)

    # 3. Compressor -> compressed KV
    comp_kv, comp_pos, block_bias = None, None, None
    comp_idx_kv = None
    if compressor is not None and compressor.ratio > 0:
        comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions)
        if comp_kv is not None:
            comp_kv = _apply_rope(comp_kv.unsqueeze(1), comp_pos, rope_cos, rope_sin, rd).squeeze(1)
            if compressor.is_csa and indexer is not None and indexer.compressor is not None:
                comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
            kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv)

    # 4. Indexer top-k (CSA)
    topk_idx = None
    if indexer is not None and ratio == 4:
        topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions)

    # 5. Gather KV
    swa_kv, swa_pos = kv_cache.get_swa()
    if kv_cache.comp_kv is not None and kv_cache.n_comp > 0:
        if ratio == 4 and topk_idx is not None:
            # NOTE(review): the first token's selection is shared by all T
            # query tokens; exact only for single-token decode steps.
            tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
            all_kv = torch.cat([kv_cache.comp_kv[tk], swa_kv], dim=0)
        elif ratio > 4:
            all_kv = torch.cat([kv_cache.comp_kv, swa_kv], dim=0)
        else:
            all_kv = swa_kv
    else:
        all_kv = swa_kv
    seq_len = all_kv.shape[0]
    if seq_len == 0:
        return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a

    # 6. Production FMHA
    attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
    if VERBOSE >= 2 and li < 3:
        # Compare with a plain PyTorch reference (no sink bias, no mask).
        k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous()
        v_exp = k_exp.clone()
        q_in = q_heads.permute(1, 0, 2)
        ref_scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
        ref_attn = torch.matmul(torch.softmax(ref_scores.float(), -1).bfloat16(), v_exp).permute(1, 0, 2)
        cos_sim = F.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item()
        print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True)

    # 7. Inverse RoPE
    attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)

    # 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM)
    wo_a_lin = prod_lin.get('o_a')
    if wo_a_lin is not None:
        # Nvfp4GroupedLinear: (T, n_h, hd) -> (T, n_groups, o_rank) -> flatten for o_b
        g_flat = wo_a_lin.run(attn_out).reshape(T, -1)
        F_attn = prod_lin['o_b'](g_flat)
    else:
        # BF16 grouped-BMM fallback (should not happen in production).
        hpg_fb = n_h // o_groups
        gid_fb = hpg_fb * hd
        oa_full, oa_ws, oa_ws2, _ = get_nvfp4_weight(w, pfx, 'o_a_proj')
        if oa_full is not None:
            if oa_ws is not None:
                # NVFP4-packed weight: unpack first. Casting the packed bytes
                # with .bfloat16() gives wrong values and half the columns.
                oa_bf = dequant_nvfp4(oa_full.to(dev), oa_ws.to(dev),
                                      oa_ws2.to(dev) if oa_ws2 is not None else None)
            else:
                oa_bf = oa_full.bfloat16().to(dev)
            a_grp = attn_out.reshape(T, n_h * hd).reshape(T, o_groups, gid_fb)
            oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
            g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
            g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
            F_attn = prod_lin['o_b'](g_flat)
        else:
            log.warning(f"L{li}: No o_a_proj weight, zero attention output")
            F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
    if VERBOSE >= 2 and li < 3:
        print(f" L{li} F_attn: |F_attn|={F_attn.abs().max().item():.6f}", flush=True)
    return F_attn, q_a
|
||
|
||
# =====================================================================
|
||
# MoE — production kernels
|
||
# =====================================================================
|
||
def moe_forward(x, li, moe_runner, se_runner, router, token_id, n_experts=384):
    """Routed-expert + shared-expert FFN for one layer; returns ``routed + shared``.

    Args:
        x: FFN input (already RMS-normalised by the caller).
        li: layer index; selects which diagnostics print.
        moe_runner, se_runner, router: production MoE, shared-expert and router
            objects.
        token_id: token ids forwarded to the router, moved to ``x``'s device.
        n_experts: routed-expert count used to range-check router ids
            (default 384, the previously hard-coded value).

    Diagnostics follow the global VERBOSE level. The routing-id range check and
    the layer >= 58 summaries (including a second gate GEMM) run only at
    VERBOSE >= 1, the default, so ``--verbose 0`` actually silences them.
    VERBOSE >= 2 adds shared-expert internals for layers 0-2.
    """
    token_id_dev = token_id if token_id.device == x.device else token_id.to(x.device)
    topk_w, topk_ids = router(x, token_ids=token_id_dev)
    # Device-wide sync after routing, kept unconditional.
    # NOTE(review): unclear whether the MoE runner relies on it for
    # cross-stream ordering.
    torch.cuda.synchronize(x.device)

    if VERBOSE >= 1:
        if topk_ids.max().item() >= n_experts or topk_ids.min().item() < 0:
            print(f" L{li} BAD topk_ids: min={topk_ids.min().item()} max={topk_ids.max().item()}", flush=True)
        if li >= 58:
            print(f" L{li} MoE DIAG: topk_ids={topk_ids[0].tolist()} topk_w=[{','.join(f'{w:.3f}' for w in topk_w[0].tolist())}]", flush=True)
            # Also print gate logits for debugging (recomputes the gate GEMM).
            if hasattr(router, '_gate_lin') and router._gate_lin is not None:
                gate_logits = router._gate_lin(x).float()
                print(f" L{li} gate logits: [{gate_logits.min().item():.3f}, {gate_logits.max().item():.3f}] mean={gate_logits.mean().item():.3f}", flush=True)
    if VERBOSE >= 2 and li < 3:
        print(f" L{li} MoE input: |x|={x.abs().max().item():.4f} has_nan={torch.isnan(x).any().item()}", flush=True)

    routed_out = moe_runner.run(x, topk_w, topk_ids)
    shared_out = se_runner.run(x)

    if VERBOSE >= 1 and li >= 58:
        print(f" L{li} MoE DIAG: |routed|={routed_out.abs().max().item():.1f} |shared|={shared_out.abs().max().item():.1f} |x|={x.abs().max().item():.1f}", flush=True)
    if VERBOSE >= 2 and li < 3:
        has_nan = torch.isnan(shared_out).any().item()
        out_max = shared_out.abs().max().item() if not has_nan else float('nan')
        print(f" L{li} MoE shared: |out|={out_max:.4f} has_nan={has_nan}", flush=True)
        # Check weight integrity
        if hasattr(se_runner, '_l1_mat_b') and se_runner._l1_mat_b is not None:
            wb = se_runner._l1_mat_b.view(torch.uint8)
            print(f" L{li} SE l1 weight: shape={list(se_runner._l1_mat_b.shape)} dtype={se_runner._l1_mat_b.dtype} uint8_range=[{wb.min().item()},{wb.max().item()}]", flush=True)
        if hasattr(se_runner, '_l1_scale_b') and se_runner._l1_scale_b is not None:
            sb = se_runner._l1_scale_b.float()
            print(f" L{li} SE l1 scale: shape={list(se_runner._l1_scale_b.shape)} dtype={se_runner._l1_scale_b.dtype} float_range=[{sb.min().item():.6f},{sb.max().item():.6f}] has_nan={torch.isnan(sb).any().item()}", flush=True)
        print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True)
    return routed_out + shared_out
|
||
|
||
# =====================================================================
|
||
# Layer forward
|
||
# =====================================================================
|
||
def _report_layer(li, X_l, X_mid, X_next, F_attn, F_ffn, x_in, x_in_f, attn_mhc, ffn_mhc):
    """Print the per-layer magnitude line, plus mHC diagnostics for late or exploding layers."""
    print(f" L{li}: |X|={X_l.abs().max().item():.1f}->{X_next.abs().max().item():.1f} "
          f"|Fa|={F_attn.abs().max().item():.1f} |Ff|={F_ffn.abs().max().item():.1f}", flush=True)
    # Detailed diagnostics for the last 3 layers or any layer with explosive growth.
    if not (li >= 58 or (li > 0 and X_next.abs().max().item() > 200)):
        return
    A_a, B_a, C_a = attn_mhc._dynamic_params(X_l)
    A_f, B_f, C_f = ffn_mhc._dynamic_params(X_mid)
    print(f" L{li} DIAG: A_attn=[{A_a.min().item():.4f},{A_a.max().item():.4f}] "
          f"C_attn=[{C_a.min().item():.4f},{C_a.max().item():.4f}] "
          f"A_ffn=[{A_f.min().item():.4f},{A_f.max().item():.4f}] "
          f"C_ffn=[{C_f.min().item():.4f},{C_f.max().item():.4f}]", flush=True)
    print(f" L{li} DIAG: B_attn row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] "
          f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}] "
          f"B_ffn row_sum=[{B_f.sum(-1).min().item():.4f},{B_f.sum(-1).max().item():.4f}] "
          f"col_sum=[{B_f.sum(-2).min().item():.4f},{B_f.sum(-2).max().item():.4f}]", flush=True)
    print(f" L{li} DIAG: |x_in_attn|={x_in.abs().max().item():.1f} "
          f"|x_in_ffn|={x_in_f.abs().max().item():.1f} "
          f"|X_l|={X_l.abs().max().item():.1f} "
          f"|X_mid|={X_mid.abs().max().item():.1f} "
          f"|X_next|={X_next.abs().max().item():.1f}", flush=True)


def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w,
                  kv_cache, positions, token_id,
                  compressor=None, indexer=None,
                  moe_runner=None, se_runner=None, router=None,
                  prod_lin=None):
    """One decoder layer: mHC-wrapped attention, then mHC-wrapped MoE FFN.

    Each sub-block is: ``pre_block`` -> RMSNorm -> sub-layer -> ``post_block``.
    Returns the layer output X_next; prints diagnostics when VERBOSE >= 1.
    """
    x_in, ctx_a = attn_mhc.pre_block(X_l)
    attn_input = rmsnorm(x_in, attn_norm_w)
    F_attn, _ = forward_attention(attn_input, w, li, cfg, rope_cos, rope_sin,
                                  kv_cache, positions, compressor, indexer, prod_lin)
    X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)

    x_in_f, ctx_f = ffn_mhc.pre_block(X_mid)
    ffn_input = rmsnorm(x_in_f, ffn_norm_w)
    F_ffn = moe_forward(ffn_input, li, moe_runner, se_runner, router, token_id)
    X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)

    if VERBOSE >= 1:
        _report_layer(li, X_l, X_mid, X_next, F_attn, F_ffn, x_in, x_in_f, attn_mhc, ffn_mhc)
    return X_next
|
||
|
||
# =====================================================================
|
||
# MoE weight loading
|
||
# =====================================================================
|
||
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Stack the routed experts' NVFP4 weights and hand them to ``moe``.

    For each expert, L1 = gate_proj ‖ up_proj (concatenated along the output
    dim) and L2 = down_proj. Packed weights and block scales are stacked along
    a new leading expert dim and passed to ``moe.prepare_weights_from_stacked``
    with base global scales of 1.0.

    Per-expert weight_scale_2 lists are stored on ``moe.l1_ws2`` / ``moe.l2_ws2``.
    For L1 only gate_proj's weight_scale_2 is kept, so a warning is logged when
    up_proj's differs. Expert 0's checkpoint input_scale (or 1/(6*448) when
    absent) is stored as ``moe._saved_l1_gsa`` / ``moe._saved_l2_gsa``.

    NOTE(review): storing the checkpoint input_scale as a gsa conflicts with
    the runtime-gsa approach used for Nvfp4Linear; confirm whether Nvfp4MoE
    overrides ``_saved_l*_gsa`` at runtime.
    """
    n_e = cfg["n_routed_experts"]
    default_gsa = 1.0 / (6.0 * 448.0)
    l1_fp4_list, l1_sf_list, l1_gs_list, l1_ws2_list, l1_gsa_list = [], [], [], [], []
    l2_fp4_list, l2_sf_list, l2_gs_list, l2_ws2_list, l2_gsa_list = [], [], [], [], []
    ws2_mismatch = []  # expert ids whose gate/up weight_scale_2 differ
    for eid in range(n_e):
        ep = f"{pfx}.experts.{eid}"
        gw, gws, gws2, gisc = get_nvfp4_weight(all_w, ep, 'gate_proj')
        uw, uws, uws2, _uisc = get_nvfp4_weight(all_w, ep, 'up_proj')
        if gw is not None and uw is not None:
            l1_fp4_list.append(torch.cat([gw, uw], dim=0).to(dev))
            if gws is not None and uws is not None:
                l1_sf_list.append(torch.cat([gws, uws], dim=0).to(dev))
            l1_gs_list.append(1.0)  # gsb base; ws2 is folded in by _ensure_stacked
            l1_gsa_list.append(gisc.float().item() if gisc is not None else default_gsa)
            # weight_scale_2: scalar, folded into global_scale_b (gate's only)
            l1_ws2_list.append(gws2.to(dev) if gws2 is not None else None)
            if (gws2 is not None and uws2 is not None
                    and not torch.allclose(gws2.float().reshape(-1), uws2.float().reshape(-1))):
                ws2_mismatch.append(eid)
        dw, dws, dws2, disc = get_nvfp4_weight(all_w, ep, 'down_proj')
        if dw is not None:
            l2_fp4_list.append(dw.to(dev))
            if dws is not None:
                l2_sf_list.append(dws.to(dev))
            l2_gs_list.append(1.0)  # gsb base
            l2_gsa_list.append(disc.float().item() if disc is not None else default_gsa)
            l2_ws2_list.append(dws2.to(dev) if dws2 is not None else None)

    if ws2_mismatch:
        log.warning(f"L{li}: gate/up weight_scale_2 differ for {len(ws2_mismatch)} expert(s) "
                    f"(first: {ws2_mismatch[:4]}); only gate_proj's is applied to the fused L1 weight")
    if not l1_fp4_list:
        log.warning(f"L{li}: No expert weights found")
        return
    if len(l1_fp4_list) != n_e or len(l2_fp4_list) != n_e:
        log.warning(f"L{li}: found {len(l1_fp4_list)} L1 / {len(l2_fp4_list)} L2 expert weights, "
                    f"expected {n_e}; stacked indices will not match router expert ids")

    l1_stacked = torch.stack(l1_fp4_list).to(dev)
    l1_sf_stacked = torch.stack(l1_sf_list).to(dev) if l1_sf_list else None
    l2_stacked = torch.stack(l2_fp4_list).to(dev) if l2_fp4_list else None
    l2_sf_stacked = torch.stack(l2_sf_list).to(dev) if l2_sf_list else None
    del l1_fp4_list, l1_sf_list, l2_fp4_list, l2_sf_list
    moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs_list, l2_stacked, l2_sf_stacked, l2_gs_list)
    # _ensure_stacked overrides the activation global scales from l1_gs (1.0),
    # so the saved values must be re-applied after it runs.
    moe._saved_l1_gsa = l1_gsa_list[0] if l1_gsa_list else default_gsa
    moe._saved_l2_gsa = l2_gsa_list[0] if l2_gsa_list else default_gsa
    moe.l1_ws2 = l1_ws2_list
    moe.l2_ws2 = l2_ws2_list
|
||
|
||
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Attach the shared expert's NVFP4 weights to ``se``.

    L1 = gate_proj ‖ up_proj (concatenated along the output dim), L2 = down_proj.
    Base global scales are 1.0 and weight_scale_2 is stored separately; for L1
    only gate_proj's weight_scale_2 is kept, so a warning is logged when
    up_proj's differs. A missing block scale is replaced by a 1-element zero
    E4M3 tensor. Each layer's checkpoint input_scale (or 1/(6*448)) is stored
    both as the activation global scale and as ``_saved_l{1,2}_gsa``.
    """
    gw, gws, gws2, gisc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'gate_proj')
    uw, uws, uws2, _uisc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'up_proj')
    dw, dws, dws2, disc = get_nvfp4_weight(all_w, f"{pfx}.shared_experts", 'down_proj')
    default_gsa = 1.0 / (6.0 * 448.0)
    if gw is not None and uw is not None:
        se.l1_fp4 = [torch.cat([gw, uw], dim=0).to(dev)]
        se.l1_sf = ([torch.cat([gws, uws], dim=0).to(dev)] if gws is not None and uws is not None
                    else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)])
        l1_isc = gisc.float().item() if gisc is not None else default_gsa
        se.l1_gs = [1.0]  # gsb base; ws2 is folded in by finalize_weights
        se.l1_ws2 = [gws2.to(dev) if gws2 is not None else None]
        if (gws2 is not None and uws2 is not None
                and not torch.allclose(gws2.float().reshape(-1), uws2.float().reshape(-1))):
            log.warning(f"L{li}: shared expert gate/up weight_scale_2 differ; "
                        "only gate_proj's is applied to the fused L1 weight")
        se._l1_activation_global_scale = l1_isc  # overridden by _ensure_initialized
        se._saved_l1_gsa = l1_isc                # re-applied after _ensure_initialized
    if dw is not None:
        se.l2_fp4 = [dw.to(dev)]
        se.l2_sf = [dws.to(dev)] if dws is not None else [torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)]
        l2_isc = disc.float().item() if disc is not None else default_gsa
        se.l2_gs = [1.0]  # gsb base
        se.l2_ws2 = [dws2.to(dev) if dws2 is not None else None]
        se._l2_activation_global_scale = l2_isc  # overridden by _ensure_initialized
        se._saved_l2_gsa = l2_isc                # re-applied after _ensure_initialized
|
||
|
||
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
|
||
cached = {}
|
||
for li in range(n_layers):
|
||
dev = devices[li % len(devices)]; pfx = f"model.layers.{li}."
|
||
w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items()
|
||
if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k}
|
||
cached[li] = w
|
||
if (li+1) % 10 == 0: log.info(f" Cached {li+1}/{n_layers} layers")
|
||
return cached
|
||
|
||
# =====================================================================
|
||
# Main
|
||
# =====================================================================
|
||
def kill_stale_gpu_processes():
    """Best-effort SIGKILL of leftover GPU compute processes before starting.

    Runs ``nvidia-smi --query-compute-apps=pid --format=csv,noheader`` and
    kills every PID it reports except the current process. The current process
    would otherwise be listed (and killed) as soon as it owns a CUDA context.
    A PID reported for several GPUs is signalled only once. Nothing is raised:
    if nvidia-smi cannot be run a warning is logged; PIDs that already exited
    are ignored; PIDs that cannot be signalled (e.g. EPERM) are logged and
    skipped so the remaining PIDs are still handled.

    NOTE(review): this kills *any* compute process nvidia-smi reports,
    including other users' jobs. Inside a PID namespace (container) the
    reported PIDs may be host PIDs — confirm before running on shared hosts.
    """
    import subprocess

    try:
        result = subprocess.run(['nvidia-smi', '--query-compute-apps=pid', '--format=csv,noheader'],
                                capture_output=True, text=True, timeout=5)
    except Exception as e:  # e.g. nvidia-smi missing, or the 5 s timeout
        log.warning(f"Could not check GPU processes: {e}")
        return
    if result.returncode != 0 or not result.stdout.strip():
        return

    own_pid = os.getpid()
    # dict.fromkeys de-duplicates while keeping nvidia-smi's order.
    pids = dict.fromkeys(p.strip() for p in result.stdout.strip().split('\n') if p.strip())
    for pid in pids:
        try:
            pid_num = int(pid)
        except ValueError:
            continue  # non-numeric entry in nvidia-smi output
        if pid_num == own_pid:
            continue  # never kill ourselves
        try:
            os.kill(pid_num, 9)  # 9 == SIGKILL
            log.info(f" Killed stale GPU process {pid}")
        except ProcessLookupError:
            pass  # already gone
        except OSError as e:  # e.g. PermissionError for another user's process
            log.warning(f"Could not kill GPU process {pid}: {e}")
|
||
|
||
def main():
    """Run single-shot DSV4 inference end to end on NUM_GPUS GPUs.

    Phase 1 loads every checkpoint tensor with load_all_weights.

    Phase 2 builds the production components:
      * per-layer mHC blocks and RMSNorm weights;
      * NVFP4 attention projections;
      * router, routed MoE and shared expert;
      * embedding, LM head, final norm and hc_head;
      * RoPE caches, KV caches, compressors and indexers.

    Phase 3 prefills the prompt one token at a time, then greedily decodes.
    Decoding stops at MAX_NEW_TOKENS, on NaN logits, or at EOS.

    Layer ``li`` lives on ``cuda:{li % NUM_GPUS}``. The hidden state moves to
    each layer's GPU and returns to cuda:0 for the LM head.

    Every NVFP4 component (o_a grouped linear, router gate, routed MoE, shared
    expert) is switched to a runtime activation global scale
    (``_use_runtime_gsa``). The checkpoint input_scale overflows the E4M3
    block scales when used as gsa.
    """
    t0 = time.time(); torch.manual_seed(SEED)
    print("=" * 70)
    print("DSV4 Single-Shot Inference - PRODUCTION KERNEL STACK")
    print(" FMHA: 6-warp TMA multi-tile + sink bias")
    print(" NVFP4 GEMM (CuTeDSL) for ALL projections")
    print(" Production MoE + Router | Production mHC")
    print(" NO PyTorch SDPA | NO dequant+matmul | NO reference fallback")
    print("=" * 70)

    with open(os.path.join(CHECKPOINT_DIR, "config.json"), encoding="utf-8") as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]; H = cfg["hidden_size"]
    hd = cfg["head_dim"]; n_h = cfg["num_attention_heads"]
    rd = cfg.get("qk_rope_head_dim", 64)
    cr = cfg.get("compress_ratios", [128] * n_layers)
    o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024)
    n_experts = cfg["n_routed_experts"]; top_k = cfg.get("num_experts_per_tok", 6)
    moe_inter = cfg.get("moe_intermediate_size", 3072)
    swiglu_limit = cfg.get("swiglu_limit", 10.0)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Compress ratios: first5={cr[:5]} len={len(cr)}")
    print(f"Experts: {n_experts}, top-{top_k}")

    # ---- Phase 1: Load weights ----
    print(f"\nPhase 1: Loading weights..."); all_w = load_all_weights(CHECKPOINT_DIR)
    print(f" {time.time()-t0:.1f}s")

    # ---- Phase 2: Build production components ----
    print("Building production components...")
    from dsv4.layers.mhc import mHCLayer
    from dsv4.layers.router import Router
    from dsv4.layers.moe import Nvfp4MoE
    from dsv4.layers.shared_expert import Nvfp4SharedExpert

    # Kill stale GPU processes from prior runs (OOM, crash, etc.)
    kill_stale_gpu_processes()
    for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # mHC + norms
    attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
    n = 4  # hyper-connection streams; fn/base are sliced into pre/post/comb parts by it
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"; lp = f"model.layers.{li}"
        for blocks, hc in ((attn_mhcs, "attn_hc"), (ffn_mhcs, "ffn_hc")):
            fn = all_w.get(f"{lp}.{hc}.fn")
            base = all_w.get(f"{lp}.{hc}.base")
            scale = all_w.get(f"{lp}.{hc}.scale")
            if fn is None or base is None or scale is None:
                continue
            m = mHCLayer(hidden_dim=H, n_hc=n, t_max_sinkhorn=20, device=dev)
            m.load_weights(
                W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
                W_comb=fn[2*n:].to(dev, torch.float32),
                S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
                S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
                S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
                alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item(),
            )
            blocks[li] = m
        an_k = f"{lp}.input_layernorm.weight"
        if an_k in all_w: attn_norms[li] = all_w[an_k].to(dev, torch.float32)
        fn_k = f"{lp}.post_attention_layernorm.weight"
        if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)

    # Production Nvfp4Linear for attention projections
    print(" Building production Nvfp4Linear for attention projections...")
    prod_lins = {}
    # Weight dimensions (from checkpoint):
    # q_a_proj: (1536, 3584) uint8 -> in=7168, out=1536
    # q_b_proj: (65536, 768) uint8 -> in=1536, out=65536
    # kv_proj: (512, 3584) uint8 -> in=7168, out=512
    # o_a_proj: (16384, 4096) BF16 -> Nvfp4GroupedLinear (16 groups, 1024×4096 each)
    # o_b_proj: (7168, 8192) uint8 -> in=16384, out=7168
    from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
    from dsv4.layers.linear import Nvfp4Linear
    heads_per_group = n_h // o_groups
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.self_attn"
        torch.cuda.set_device(li % NUM_GPUS)
        pl = {}
        pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj')
        pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj')
        pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj')
        # o_a_proj: Nvfp4GroupedLinear (NVFP4 grouped GEMM)
        wo_a = Nvfp4GroupedLinear(
            n_local_groups=o_groups,
            heads_per_group=heads_per_group,
            head_dim=hd,
            o_lora_rank=o_rank,
            max_num_tokens=8192,
            device=dev,
        )
        oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
        if oa_w_nvfp4 is not None and oa_ws is not None:
            # Checkpoint has NVFP4 weights.
            # TODO: Nvfp4GroupedLinear needs a load_nvfp4_weight method
            # For now, dequant and re-quantize via set_bf16_weight
            oa_bf16 = dequant_nvfp4(oa_w_nvfp4, oa_ws, oa_ws2, oa_isc).to(dev)
            wo_a.set_bf16_weight(oa_bf16)
        else:
            # BF16 checkpoint weight
            oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
            if oa_bf is not None:
                wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
        wo_a._use_runtime_gsa = True  # compute gsa from actual input to avoid E4M3 overflow
        pl['o_a'] = wo_a
        pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
        prod_lins[li] = pl
        if (li+1) % 10 == 0: print(f" {li+1}/{n_layers} layers")
    print(" All attention projections: production NVFP4 GEMM (o_a now NVFP4 grouped)")

    # Routers, MoE, shared experts
    routers, moe_runners, se_runners = {}, {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"; pfx = f"model.layers.{li}.mlp"
        torch.cuda.set_device(li % NUM_GPUS); torch.cuda.synchronize()
        is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w)
        router = Router(hidden_size=H, num_experts=n_experts,
                        top_k=top_k,
                        routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
                        mode="hash" if is_hash else "dense",
                        vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
        if is_hash:
            router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
        else:
            eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
            # NVFP4 production GEMM for router gate. The custom CuTeDSL fused kernel
            # crashes the MLIR optimizer, so we use Nvfp4Linear (proven production path).
            gate_lin, gate_msg = None, None
            gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
            if gate_w is not None and gate_ws is not None:
                # Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
                gate_lin = Nvfp4Linear(in_features=H, out_features=n_experts, device=dev)
                gate_w_dev = gate_w.to(dev)
                if gate_w.dtype == torch.uint8:
                    gate_w_dev = gate_w_dev.view(torch.float4_e2m1fn_x2)
                gate_lin.fp4 = [gate_w_dev]
                gate_lin.sf = [gate_ws.to(dev)]
                ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
                gate_lin.gs = [1.0]
                gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
                # Static value only; superseded by runtime gsa below.
                gate_lin._activation_global_scale = (
                    gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0))
                gate_msg = "NVFP4 router gate (checkpoint)"
            else:
                # BF16 gate weight: quantize to NVFP4
                gw = all_w.get(f"{pfx}.gate.weight")
                if gw is not None:
                    from dsv4.ops.quantize import quantize_to_nvfp4
                    g_bf16 = gw if gw.shape == (n_experts, H) else gw.T.contiguous()
                    g_bf16 = g_bf16.bfloat16().to(dev)
                    g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
                    gate_lin = Nvfp4Linear(in_features=H, out_features=n_experts, device=dev)
                    gate_lin.fp4 = [g_fp4]
                    gate_lin.sf = [g_sf]
                    # gs is the 1.0 base and the per-tensor weight scale lives in ws2, as
                    # in the checkpoint branch and the shared expert. Setting gs = g_gs as
                    # well double-applies the scale if finalize_weights folds ws2 into gs.
                    gate_lin.gs = [1.0]
                    gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
                    gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
                    gate_msg = f"NVFP4 router gate (quantized, gs={g_gs:.6f})"
            if gate_lin is not None:
                gate_lin.finalize_weights()
                # Same fix as o_a / MoE / SE: a static gsa (checkpoint input_scale or
                # 1/(6*448)) overflows the E4M3 block scales for real activations.
                gate_lin._use_runtime_gsa = True
                router.load_nvfp4_gate(gate_lin)
            # The e_score_correction_bias is loaded exactly once for every dense router.
            router.load_weights(e_bias=eb.to(dev, torch.float32))
            if gate_msg is not None and li < 5: print(f" L{li}: {gate_msg}", flush=True)
        router.finalize_weights(); routers[li] = router

        moe = Nvfp4MoE(num_experts=n_experts, hidden_size=H,
                       intermediate_size=moe_inter, top_k=top_k, device=dev)
        moe.set_swiglu_limit(swiglu_limit)
        _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
        # EAGERLY process stacked weights → K-major + swizzle, free raw tensors
        moe._ensure_stacked()
        # Do NOT use the checkpoint input_scale as gsa (E4M3 overflow); compute it at runtime.
        moe._use_runtime_gsa = True
        moe_runners[li] = moe

        se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=moe_inter,
                               device=dev, swiglu_limit=swiglu_limit)
        _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg)
        # EAGERLY process shared expert weights
        se._ensure_initialized()
        # Same runtime gsa for SharedExpert
        se._use_runtime_gsa = True
        se_runners[li] = se
        if (li+1) % 10 == 0: print(f" Built {li+1}/{n_layers} MoE layers")
        torch.cuda.empty_cache()

    # Global weights
    torch.cuda.set_device(0)
    embed_w = all_w.get("model.embed_tokens.weight")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
    lm_w = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
    final_norm_w = all_w.get("model.norm.weight")
    if final_norm_w is not None: final_norm_w = final_norm_w.to('cuda:0', torch.float32)

    hc_head = HcHead(H, 4, 'cuda:0')
    hc_fn = all_w.get("model.hc_head.hc_fn"); hc_base = all_w.get("model.hc_head.hc_base"); hc_scale = all_w.get("model.hc_head.hc_scale")
    if hc_fn is not None and hc_base is not None: hc_head.load(hc_fn, hc_base, hc_scale); print(" hc_head loaded")

    # RoPE (FP32)
    rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {}))
    rt = rp.get("type", rp.get("rope_type", "yarn")); rf = rp.get("factor", 16.0)
    rtheta = cfg.get("rope_theta", 10000.); romax = rp.get("original_max_position_embeddings", 65536)
    rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1)
    # First build_rope_cache argument — assumed to be the number of cached positions.
    max_positions = 8192
    rope_caches = {g: build_rope_cache(max_positions, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow)
                   for g in range(NUM_GPUS)}

    # KV caches, compressors, indexers
    kv_caches, compressors, indexers = {}, {}, {}
    n_ih = cfg.get("index_n_heads", 64); ihd = cfg.get("index_head_dim", 128); itk = cfg.get("index_topk", 1024)
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"; ratio = cr[li] if li < len(cr) else 128
        kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), dev)
        if ratio > 0: compressors[li] = Compressor(ratio, hd, H, dev)
        if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)

    # Cache layer weights (no MoE/SE)
    print("Caching layer weights to GPUs (excluding MoE expert weights)...")
    devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
    del all_w; import gc; gc.collect()
    for g in range(NUM_GPUS): torch.cuda.set_device(g); torch.cuda.empty_cache()
    torch.cuda.set_device(0)
    print(f" {time.time()-t0:.1f}s")

    # Load compressor/indexer weights
    for li in range(n_layers):
        pfx = f"model.layers.{li}.self_attn.compressor"
        if li in compressors: compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}")
        if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}")
    print(" Compressors/indexers loaded")

    # ---- Phase 3: Inference ----
    print(f"\nPhase 3: Inference")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)

    bos = tokenizer.bos_token_id or 0
    if _args.prefill_tokens:
        # Tolerate stray/trailing commas such as "1,2,".
        generated = [int(x) for x in _args.prefill_tokens.split(',') if x.strip()]
    else:
        generated = [bos, USER_TOKEN]
        generated += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
        generated.append(ASSISTANT_TOKEN)
    if not generated:
        raise ValueError("No input tokens to prefill")
    if len(generated) > max_positions:
        raise ValueError(f"Prompt has {len(generated)} tokens; RoPE cache holds {max_positions} positions")
    all_tokens = generated.copy()
    print(f"Input: {len(generated)} tokens")

    def run_layers(token_id, position, trace=False):
        """Run one token at absolute ``position`` through every layer.

        Layer ``li`` runs on cuda:{li % NUM_GPUS}. forward_layer receives that
        layer's KV cache, compressor and indexer. Returns the final mHC state
        on cuda:0 with cuda:0 as the current device. On failure the
        token/layer/GPU are printed and the exception is re-raised.
        """
        tid_int64 = torch.tensor([token_id], dtype=torch.long, device='cuda:0')
        tid = tid_int64.to(torch.int32)  # hash router needs int32
        pos = torch.tensor([position], dtype=torch.long, device='cuda:0')
        X = mHCLayer.init_state(embed(tid_int64))
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
            torch.cuda.set_device(gpu)
            try:
                X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                                  attn_mhcs.get(li), ffn_mhcs.get(li),
                                  attn_norms.get(li), ffn_norms.get(li),
                                  kv_caches[li], pos, tid,
                                  compressors.get(li), indexers.get(li),
                                  moe_runners.get(li), se_runners.get(li), routers.get(li),
                                  prod_lin=prod_lins.get(li))
            except Exception as e:
                # Report first: after a sticky CUDA error synchronize() raises as well.
                print(f" CRASH at token {position} layer {li} gpu {gpu}: {e}", flush=True)
                try:
                    torch.cuda.synchronize()
                except Exception as sync_err:
                    print(f" (torch.cuda.synchronize after crash failed: {sync_err})", flush=True)
                raise
            if trace and li < 3:
                torch.cuda.synchronize(gpu)
                print(f" Token {position} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
        X = X.to('cuda:0'); torch.cuda.set_device(0)
        return X

    # Prefill — one token at a time (decode-style; TODO: batched prefill)
    print(f"Prefilling {len(generated)} tokens...")
    X = None
    for pi, tid_val in enumerate(generated):
        t1 = time.time()
        X = run_layers(tid_val, pi, trace=(VERBOSE >= 2 and pi == 0))
        if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
    print(f" Prefill done ({time.time()-t0:.1f}s)")

    if _args.prefill_only: print("Prefill-only mode, stopping."); return

    # Decode. Step s > 0 forwards the token sampled at step s-1 at position
    # len(generated)+s-1. Cap the step count so those positions stay inside the RoPE cache.
    n_steps = min(MAX_NEW_TOKENS, max_positions - len(generated) + 1)
    print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
    if n_steps < MAX_NEW_TOKENS:
        print(f" (capped at {n_steps} steps by the {max_positions}-position RoPE cache)")
    for step in range(n_steps):
        t1 = time.time()
        if step > 0:
            # At step 0, X already holds the last prompt token's output from prefill.
            # Re-running that token would feed the same position through every layer twice.
            X = run_layers(all_tokens[-1], len(all_tokens) - 1)
        x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
        if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
        logits = F.linear(x_out, lm_w)
        next_id = torch.argmax(logits, -1).item(); all_tokens.append(next_id)
        dt = time.time() - t1
        has_nan = torch.isnan(logits.float()).any().item()
        tv, ti = torch.topk(logits[0], 5)
        top5 = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})' for t, v in zip(ti[:5], tv[:5]))
        print(f" Step {step}: {next_id} '{tokenizer.decode([next_id])}' ({dt:.2f}s) "
              f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] "
              f"nan={has_nan} |X|={X.abs().max().item():.1f} top5: {top5}", flush=True)
        if has_nan: break
        if next_id == tokenizer.eos_token_id: break

    out = tokenizer.decode(all_tokens, skip_special_tokens=True)
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {time.time()-t0:.1f}s")
    print(f"{'='*70}")
|
||
|
||
if __name__ == "__main__":
    # Run the full load -> build -> prefill -> decode pipeline only when executed as a script.
    main()
|