Files
nvfp4-megamoe-kernel/single_shot_inference.py
biondizzle d8e17d70c1 P0+P1+P2: Enable fused SwiGLU (MoE+SE), fix SE _run_l1_fused, remove per-call gsa fill_
P0: Enable fused SwiGLU for MoE (set_fused_swiglu(True))
  - Saves 240+ unfused BF16 kernel launches per token
  - SiLU + clamp in kernel registers instead of separate launches

P1: Fix shared expert _run_l1_fused + enable fused SwiGLU
  - Fixed: _l1_sf_view -> _l1_scale_b, _l1_gs_view -> _l1_gsb
  - Fixed: expert_offsets dtype int64 -> int32
  - Added proper padded buffer + scale assembly (matching unfused path)
  - Added runtime gsa support (quantize_nvfp4_gpu_fused)

P2: Remove per-call gsa_buf.fill_() in Nvfp4Linear
  - fill_() was H2D transfer every forward pass (~5µs × 244 calls = ~1.2ms/token)
  - _gsa_buf now initialized with _activation_global_scale (not zeros)
  - After warmup_gsa, buffer already has correct value — no fill needed
2026-06-02 07:57:39 +00:00

1375 lines
73 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""Single-shot DSV4-Pro inference — Full production pipeline, 8-GPU.
ALL projections use production NVFP4 GEMM kernels (CuTeDSL).
ALL attention uses production FMHA (6-warp TMA multi-tile + sink bias).
ALL MoE uses production Nvfp4MoE + Nvfp4SharedExpert + Router.
NO PyTorch SDPA fallback. NO dequant+matmul for production projections.
This is the ground truth for vLLM / SGLang integration.
"""
import os, sys, time, json, math, argparse, logging
import torch
import torch.nn.functional as F
from pathlib import Path
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger("single_shot")
def parse_args(argv=None):
    """Parse the command-line options of the single-shot inference script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is exactly what the former
            zero-argument form did; passing a list makes the parser usable
            from tests and other programs.

    Returns:
        argparse.Namespace with one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()

    # Generation length and sampling controls.
    parser.add_argument('--max-tokens', type=int, default=512)
    parser.add_argument('--temperature', type=float, default=0.6,
                        help='Sampling temperature (0=greedy)')
    parser.add_argument('--repetition-penalty', type=float, default=1.1,
                        help='Repetition penalty factor (>1 penalizes repeats)')
    parser.add_argument('--top-k', type=int, default=50,
                        help='Top-k filtering (0=disabled)')
    parser.add_argument('--top-p', type=float, default=0.95,
                        help='Top-p (nucleus) filtering (1.0=disabled)')

    # Prompt, reproducibility and diagnostics.
    parser.add_argument('--prompt', type=str, default=None)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--verbose', type=int, default=1)
    parser.add_argument('--prefill-only', action='store_true')
    parser.add_argument('--warmup-gsa', action='store_true',
                        help='Fix gsa values after first decode step (eliminates amax kernel launches)')
    parser.add_argument('--profile', action='store_true',
                        help='Profile per-component GPU time using CUDA events')

    # Hardware / checkpoint selection.
    parser.add_argument('--num-gpus', type=int, default=8)
    parser.add_argument('--checkpoint', type=str,
                        default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
    parser.add_argument('--prefill-tokens', type=str, default=None,
                        help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")')
    parser.add_argument('--cuda-graph', action='store_true',
                        help='Capture CUDA graph per layer for decode (eliminates Python dispatch overhead)')
    parser.add_argument('--max-context', type=int, default=8192,
                        help='Target max context length (determines KV cache pre-allocation)')

    return parser.parse_args(argv)
# The command line is parsed once, at import time; the values the rest of the
# file needs are then exposed as module-level constants.
_args = parse_args()
CHECKPOINT_DIR = _args.checkpoint
MAX_NEW_TOKENS = _args.max_tokens
# An absent (or empty) --prompt falls back to a fixed default prompt.
PROMPT = _args.prompt if _args.prompt else "The capital of France is"
NUM_GPUS = _args.num_gpus
SEED = _args.seed
VERBOSE = _args.verbose

# Hard-coded special token ids (presumably tokenizer control tokens matching
# their names; they are not used in this part of the file — verify against
# the tokenizer before changing).
THINK_START = 128821
THINK_END = 128822
USER_TOKEN = 128803
ASSISTANT_TOKEN = 128804

# Magnitude table for 4-bit codes: dequant_nvfp4 indexes it with the low
# three bits of each nibble and takes the sign from the fourth bit.
FP4_LUT = torch.tensor([0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
# =====================================================================
# RoPE (FP32 — BF16 destroys cos²+sin²=1)
# =====================================================================
def build_rope_cache(max_pos, rope_dim, device, theta=10000., rope_type="default",
                     rope_factor=1., orig_max=4096, beta_fast=32, beta_slow=1):
    """Precompute FP32 cosine/sine tables for rotary position embedding.

    Args:
        max_pos: Number of positions (rows of the tables).
        rope_dim: Rotary dimension; the tables have ``rope_dim // 2`` columns.
        device: Device the finished tables are moved to.
        theta: Base of the inverse-frequency schedule.
        rope_type: ``"yarn"`` enables per-frequency rescaling (only when
            ``rope_factor > 1``); any other value keeps the base frequencies.
        rope_factor, orig_max, beta_fast, beta_slow: Parameters of the
            ``"yarn"`` rescaling below.

    Returns:
        ``(cos, sin)``: two float32 tensors of shape
        ``(max_pos, rope_dim // 2)`` on ``device``.
    """
    exponents = torch.arange(0, rope_dim, 2, dtype=torch.float32) / rope_dim
    freqs = 1. / (theta ** exponents)

    def _rescale(freq):
        # Classify the frequency by its wavelength: short wavelengths are
        # kept, long ones are divided by rope_factor, the band in between is
        # a blend of the two.
        wavelength = 2 * math.pi / freq
        short_cut = orig_max / (beta_fast * 2.)
        long_cut = orig_max / (beta_slow * 2.)
        if wavelength < short_cut:
            return freq
        if wavelength > long_cut:
            return freq / rope_factor
        sm = (orig_max / (wavelength * beta_slow) - rope_factor) / (rope_factor * (beta_fast / beta_slow - 1))
        return (1 - sm) * freq / rope_factor + sm * freq

    if rope_type == "yarn" and rope_factor > 1.:
        freqs = torch.tensor([_rescale(freq) for freq in freqs], dtype=torch.float32)

    position_ids = torch.arange(max_pos, dtype=torch.float32)
    angles = torch.outer(position_ids, freqs)
    cos_table = torch.cos(angles).to(device)
    sin_table = torch.sin(angles).to(device)
    return cos_table, sin_table
def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False):
"""In-place RoPE — mutates x, no full clone, no empty_like allocation.
P5: Eliminates x.clone() + empty_like per RoPE call.
Old: 183 calls/token × 128KB clone = 23MB pointless memcpy + 183 kernel launches.
New: Operates on the rope dims in-place, one slice copy back.
"""
T, nh, hd = x.shape; nope = hd - rope_dim
if pos.device != cos.device: pos = pos.to(cos.device)
c, s = cos[pos].unsqueeze(1), sin[pos].unsqueeze(1)
xr = x[:, :, nope:] # view, not copy
ev = xr[..., 0::2].clone() # need original ev for the mix
od = xr[..., 1::2] # view; will be overwritten below
if inverse:
xr[..., 0::2] = (ev * c + od * s).bfloat16()
xr[..., 1::2] = (-ev * s + od * c).bfloat16()
else:
xr[..., 0::2] = (ev * c - od * s).bfloat16()
xr[..., 1::2] = (ev * s + od * c).bfloat16()
return x # mutated in place
# =====================================================================
# Weight loading
# =====================================================================
def load_all_weights(checkpoint_dir):
    """Load every safetensors shard listed in the checkpoint index.

    The shard list is taken from the ``weight_map`` of
    ``model.safetensors.index.json`` inside ``checkpoint_dir``. Loading stays
    best-effort (a missing index or shard does not raise), but every such
    gap is now logged as a warning and the summary line counts only shards
    that were actually read, so an incomplete checkpoint is visible in the
    log instead of surfacing later as a missing tensor.

    Args:
        checkpoint_dir: Directory containing the index file and the shards.

    Returns:
        dict mapping tensor name to tensor, merged over all shards read
        (empty when nothing could be loaded).
    """
    from safetensors.torch import load_file

    ckpt_dir = Path(checkpoint_dir)
    index_path = ckpt_dir / "model.safetensors.index.json"

    weight_map = {}
    if index_path.exists():
        with open(index_path, encoding="utf-8") as fh:
            weight_map = json.load(fh).get("weight_map", {})
    else:
        log.warning(f"No safetensors index found at {index_path}; no weights loaded")

    all_w = {}
    shards_read = 0
    for shard_name in sorted(set(weight_map.values())):
        shard_path = ckpt_dir / shard_name
        if not shard_path.exists():
            log.warning(f"Shard {shard_path} is listed in the index but missing; skipping")
            continue
        all_w.update(load_file(str(shard_path)))
        shards_read += 1

    log.info(f"Loaded {len(all_w)} tensors from {shards_read} shards")
    return all_w
# =====================================================================
# RMSNorm
# =====================================================================
def rmsnorm(x, weight, eps=1e-6):
    """Weighted RMS normalisation over the last axis.

    The arithmetic is done in float32; the result is returned as bfloat16.

    Args:
        x: Input tensor.
        weight: Per-feature gain, broadcast against ``x``.
        eps: Added to the mean square before the reciprocal square root.
    """
    x32 = x.float()
    mean_sq = torch.mean(x32.pow(2), dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + eps)
    scaled = x32 * inv_rms * weight.float()
    return scaled.to(torch.bfloat16)
def unweighted_rmsnorm(x, eps=1e-6):
    """RMS-normalise ``x`` over its last axis without a gain.

    Unlike ``rmsnorm`` the result stays in float32; callers cast it
    themselves when they need another dtype.
    """
    x32 = x.float()
    mean_sq = torch.mean(x32.pow(2), dim=-1, keepdim=True)
    return x32 * torch.rsqrt(mean_sq + eps)
# =====================================================================
# CUDA Graph Decoder — capture per-layer graphs for zero-dispatch decode
# =====================================================================
class CUDAGraphDecoder:
    """Records one CUDA graph per transformer layer for the decode step.

    Layer ``li`` is captured on GPU ``li % num_gpus`` around a single call to
    ``forward_layer`` that reads the layer's fixed input buffer and copies
    its result into the layer's fixed output buffer.

    Requirements stated by the original author for capture to be valid:
      - every tensor touched has a fixed, pre-allocated address;
      - shapes are static (single-token decode);
      - no CPU<->GPU synchronisation happens inside a graph;
      - cross-GPU hand-off of the hidden state happens outside the graphs.

    Attributes:
        graphs: layer index -> captured ``torch.cuda.CUDAGraph``.
        lm_graph: graph object reserved for hc_head + norm + lm_head.
        x_in_bufs / x_out_bufs: layer index -> fixed I/O tensor on that
            layer's device.
        logits_buf: fixed logits tensor on ``cuda:0``.
        captured: True once ``capture`` has finished.
    """

    def __init__(self, n_layers, num_gpus, devices):
        self.n_layers = n_layers
        self.num_gpus = num_gpus
        self.devices = devices
        self.captured = False
        # Captured graphs.
        self.graphs = {}
        self.lm_graph = None
        # Fixed-address I/O buffers, filled in by pre_allocate().
        self.x_in_bufs = {}
        self.x_out_bufs = {}
        self.logits_buf = None

    def pre_allocate(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
                     kv_caches, compressors, indexers, moe_runners, se_runners,
                     routers, prod_lins, layer_w, rope_caches, hc_head,
                     final_norm_w, lm_head_lin):
        """Create the fixed-address I/O buffers.

        Only ``cfg`` is read (``hidden_size`` and optionally ``vocab_size``);
        the remaining parameters mirror ``capture`` and are unused here.
        """
        def _hidden_buffer(device):
            # One decode token's hidden state: (1, 4, hidden_size) bfloat16.
            return torch.zeros(1, 4, cfg["hidden_size"], dtype=torch.bfloat16, device=device)

        for layer in range(self.n_layers):
            device = self.devices[layer % self.num_gpus]
            self.x_in_bufs[layer] = _hidden_buffer(device)
            self.x_out_bufs[layer] = _hidden_buffer(device)

        vocab = cfg.get("vocab_size", 129280)
        self.logits_buf = torch.zeros(1, vocab, dtype=torch.bfloat16, device='cuda:0')

    def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms,
                kv_caches, compressors, indexers, moe_runners, se_runners,
                routers, prod_lins, layer_w, rope_caches, hc_head,
                final_norm_w, lm_head_lin, positions, token_id):
        """Record the per-layer graphs (and the lm-head placeholder graph).

        Per the original author this must run after a warmup decode step, so
        that kernels are already compiled and the gsa values are fixed.
        ``pre_allocate`` must have been called first: the I/O buffers are
        looked up here, not created.
        """
        print(" Capturing CUDA graphs for decode...", flush=True)

        for li in range(self.n_layers):
            gpu = li % self.num_gpus
            torch.cuda.set_device(gpu)

            layer_graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(layer_graph):
                X_out = forward_layer(
                    self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
                    attn_mhcs.get(li), ffn_mhcs.get(li),
                    attn_norms.get(li), ffn_norms.get(li),
                    kv_caches[li], positions, token_id,
                    compressors.get(li), indexers.get(li),
                    moe_runners.get(li), se_runners.get(li), routers.get(li),
                    prod_lin=prod_lins.get(li)
                )
                # The copy into the fixed output buffer is part of the graph.
                self.x_out_bufs[li].copy_(X_out)
            self.graphs[li] = layer_graph

            if (li + 1) % 10 == 0:
                print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)

        # hc_head + norm + lm_head graph on cuda:0.
        # NOTE(review): this section is a placeholder kept as-is. The only
        # work issued inside the capture is a device transfer of the last
        # layer's output, whose result is discarded; the original author
        # notes that a cross-device copy is not expected to work inside a
        # CUDA graph and that the transfer has to move outside the capture.
        torch.cuda.set_device(0)
        self.lm_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.lm_graph):
            x_out = self.x_out_bufs[self.n_layers - 1]  # may live on another GPU
            x_cuda0 = x_out.to('cuda:0')

        self.captured = True
        print(f" Captured {len(self.graphs)} layer graphs", flush=True)
# =====================================================================
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Expand packed 4-bit weights into a dense bfloat16 matrix (reference path).

    Each byte of ``weight`` holds two 4-bit codes; the low nibble becomes the
    even output column and the high nibble the odd one. Within a code the low
    three bits index ``FP4_LUT`` and the fourth bit is the sign. Every entry
    of ``weight_scale`` scales 16 consecutive output columns.

    Args:
        weight: 2-D packed tensor ``(rows, cols // 2)``.
        weight_scale: Per-block scales, repeated 16x along axis 1.
        weight_scale_2: Optional extra scale multiplied into ``weight_scale``.
        input_scale: Accepted for signature compatibility; not used.

    Returns:
        bfloat16 tensor of shape ``(rows, cols)``.
    """
    rows, packed_cols = weight.shape
    table = FP4_LUT.to(device=weight.device, dtype=torch.float32)

    def _decode(nibble):
        magnitude = table[(nibble & 0x07).long()]
        sign = torch.where((nibble >> 3).bool(), -1., 1.)
        return magnitude * sign

    low = _decode((weight & 0x0F).to(torch.int8))
    high = _decode((weight >> 4).to(torch.int8))
    values = torch.stack([low, high], -1).reshape(rows, packed_cols * 2)

    scale = weight_scale.float().repeat_interleave(16, 1)
    if weight_scale_2 is not None:
        scale = scale * weight_scale_2.float()
    return (values * scale).bfloat16()
def nvfp4_linear_ref(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
    """Reference linear layer: dequantise the packed weight, then ``F.linear``.

    No bias is applied. The arguments after ``x`` are forwarded unchanged to
    ``dequant_nvfp4``.
    """
    dense_weight = dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)
    return F.linear(x, dense_weight)
def get_nvfp4_weight(w, pfx, proj_name):
    """Fetch the four checkpoint entries that describe one projection.

    Args:
        w: Mapping from tensor name to tensor.
        pfx: Module prefix; ``proj_name`` is appended with a dot.
        proj_name: Projection name, e.g. ``'kv_proj'``.

    Returns:
        Tuple ``(weight, weight_scale, weight_scale_2, input_scale)``; any
        entry that is absent from ``w`` is ``None``.
    """
    base = f"{pfx}.{proj_name}"
    fields = ("weight", "weight_scale", "weight_scale_2", "input_scale")
    return tuple(w.get(f"{base}.{field}") for field in fields)
def do_nvfp4_linear_ref(x, w, pfx, proj_name):
    """Run the reference NVFP4 linear for a projection looked up by name.

    The projection's tensors are moved to ``x``'s device before use.

    Returns:
        The reference output, or ``None`` when the projection has no
        ``weight`` entry in ``w``.
    """
    weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name)
    if weight is None:
        return None

    target = x.device

    def _optional_to_device(tensor):
        return None if tensor is None else tensor.to(target)

    return nvfp4_linear_ref(
        x,
        weight.to(target),
        ws.to(target),
        _optional_to_device(ws2),
        _optional_to_device(isc),
    )
# =====================================================================
# Production Nvfp4Linear factory
# =====================================================================
def make_nvfp4_linear(in_features, out_features, device, all_w, pfx, proj_name):
    """Build a production ``Nvfp4Linear`` from checkpoint tensors.

    The layer is sized from the packed weight itself (rows = output size,
    2 x packed columns = input size); ``in_features`` and ``out_features``
    are accepted but not consulted.

    Args:
        in_features, out_features: Unused (see above).
        device: Device for the layer and its tensors.
        all_w: Mapping from tensor name to tensor.
        pfx, proj_name: Identify the projection inside ``all_w``.

    Returns:
        The finalized ``Nvfp4Linear`` instance.

    Raises:
        AssertionError: when ``{pfx}.{proj_name}.weight`` is not in ``all_w``.
    """
    from dsv4.layers.linear import Nvfp4Linear

    weight, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj_name)
    assert weight is not None, f"{pfx}.{proj_name}.weight not found"

    n_out = weight.shape[0]
    n_in = weight.shape[1] * 2  # two 4-bit values per packed column
    lin = Nvfp4Linear(n_in, n_out, max_num_tokens=8192, device=device)

    lin.fp4 = [weight.to(device)]
    lin.sf = [ws.to(device)]
    # Base global scale; per the original note, finalize_weights() folds ws2 in.
    lin.gs = [1.0]
    lin.ws2 = [None if ws2 is None else ws2.to(device)]

    # The checkpoint's input_scale is deliberately not used as the activation
    # global scale: the original author reports that it overflows the block
    # scale. A placeholder is stored and the layer is flagged to derive the
    # value from the real activations at run time.
    lin._activation_global_scale = 1.0 / (6.0 * 448.0)
    lin._use_runtime_gsa = True

    lin.finalize_weights()
    return lin
# =====================================================================
# Compressor — CSA (ratio=4) and HCA (ratio=128) [PRODUCTION KERNELS]
# =====================================================================
class Compressor:
    """Compresses every ``ratio`` consecutive tokens into one KV entry.

    Pipeline for each group of complete blocks:
      1. ``kv_lin`` and ``gate_lin`` (production NVFP4 GEMMs) project the
         hidden states; both results are cast to float32.
      2. ``csa_compress_production`` (ratio == 4) or
         ``hca_compress_production`` (any other ratio) reduces each block,
         receiving the position bias ``ape`` and ``kv_norm_w``.

    Tokens that do not yet complete a block are held in an internal buffer
    (``_hs_buffer`` / ``_pos_buffer`` / ``_buf_len``) and consumed by later
    calls, so no GEMM is launched until a block is complete.

    Attributes:
        ratio: Tokens per compressed block; 0 disables the compressor.
        hd, H: Head dimension and hidden size given at construction.
        is_csa: True when ``ratio == 4``.
        kv_dim: ``2 * head_dim`` for CSA, ``head_dim`` otherwise.
        kv_lin, gate_lin: Projection layers, set by ``load``.
        ape, kv_norm_w: Optional tensors read from the checkpoint by ``load``.
    """

    def __init__(self, ratio, head_dim, hidden_size, device):
        self.ratio, self.hd, self.H, self.device = ratio, head_dim, hidden_size, device
        self.is_csa = (ratio == 4)
        self.kv_dim = 2 * head_dim if self.is_csa else head_dim
        self.kv_lin = None
        self.gate_lin = None
        self.ape = None
        self.kv_norm_w = None
        self._reduce_loaded = False
        # Pending tokens of the block currently being filled.
        self._hs_buffer = None   # (ratio, H) bfloat16, allocated on first use
        self._pos_buffer = None  # (ratio,) long
        self._buf_len = 0

    def load(self, w, pfx, dev=None):
        """Build the two projection layers and fetch the auxiliary tensors.

        A projection whose weight is absent from ``w`` is left as ``None``.
        """
        if dev is None:
            dev = self.device
        kv_w = get_nvfp4_weight(w, pfx, 'kv_proj')[0]
        gate_w = get_nvfp4_weight(w, pfx, 'gate_proj')[0]
        if kv_w is not None:
            self.kv_lin = make_nvfp4_linear(kv_w.shape[1] * 2, kv_w.shape[0], dev, w, pfx, 'kv_proj')
        if gate_w is not None:
            self.gate_lin = make_nvfp4_linear(gate_w.shape[1] * 2, gate_w.shape[0], dev, w, pfx, 'gate_proj')
        self.ape = w.get(f"{pfx}.position_bias")
        self.kv_norm_w = w.get(f"{pfx}.kv_norm.weight")

    @staticmethod
    def _row_positions(positions, n_rows):
        """Return a 1-D position tensor with one entry per input row.

        When fewer positions than rows are supplied, the last position is
        repeated (the same clamping the block-position lookup uses).
        """
        flat = positions.reshape(-1)
        if flat.numel() == n_rows:
            return flat
        idx = torch.arange(n_rows, device=flat.device).clamp(max=flat.numel() - 1)
        return flat[idx]

    def _stash(self, rows, row_positions, dev):
        """Append ``rows`` and their positions to the pending-token buffer."""
        if self._hs_buffer is None:
            self._hs_buffer = torch.zeros(self.ratio, self.H, dtype=torch.bfloat16, device=dev)
            self._pos_buffer = torch.zeros(self.ratio, dtype=torch.long, device=dev)
        start = self._buf_len
        stop = start + rows.shape[0]
        self._hs_buffer[start:stop] = rows
        self._pos_buffer[start:stop] = row_positions
        self._buf_len = stop

    def forward(self, hidden_states, positions):
        """Consume new tokens and emit compressed entries for completed blocks.

        Every incoming token is accounted for: tokens that do not complete a
        block are buffered (all of them, also for multi-token calls shorter
        than ``ratio``), and the trailing ``T % ratio`` tokens of a longer
        call are carried over so that later blocks stay aligned.

        Args:
            hidden_states: ``(T, H)`` hidden states of the new tokens.
            positions: Positions of those tokens.

        Returns:
            ``(compressed, comp_pos, bias)`` where ``comp_pos`` holds the
            position of the last token of each block and ``bias`` is an
            all-zero float32 tensor of shape ``(1, rows, n_blocks)``; or
            ``(None, None, None)`` when the compressor is disabled, has no
            ``kv_lin``, or no block was completed by this call.
        """
        if self.ratio == 0 or self.kv_lin is None:
            return None, None, None
        r = self.ratio
        dev = hidden_states.device
        n_new = hidden_states.shape[0]
        pending = self._buf_len

        if pending == 0 and n_new >= r:
            # Block-aligned start (e.g. prefill): the input is processed
            # directly, exactly as before; only the incomplete tail is kept
            # for the following calls.
            block_hs, block_pos = hidden_states, positions
            tail = n_new % r
            if tail:
                row_pos = self._row_positions(positions, n_new)
                self._stash(hidden_states[n_new - tail:], row_pos[n_new - tail:], dev)
        else:
            row_pos = self._row_positions(positions, n_new)
            if pending + n_new < r:
                # Still no complete block: remember the tokens and skip the GEMMs.
                self._stash(hidden_states, row_pos, dev)
                return None, None, None
            fill = r - pending  # new rows needed to complete the pending block
            self._stash(hidden_states[:fill], row_pos[:fill], dev)
            self._buf_len = 0
            if n_new == fill:
                # Typical decode step: the buffer is exactly one block.
                block_hs = self._hs_buffer[:r]
                block_pos = self._pos_buffer[:r]
            else:
                # More rows than needed: join buffer and remainder, process
                # the whole blocks and carry the new tail over.
                merged_hs = torch.cat(
                    [self._hs_buffer[:r], hidden_states[fill:].to(self._hs_buffer.dtype)], 0)
                merged_pos = torch.cat(
                    [self._pos_buffer[:r],
                     row_pos[fill:].to(device=self._pos_buffer.device, dtype=self._pos_buffer.dtype)], 0)
                n_use = (merged_hs.shape[0] // r) * r
                if n_use < merged_hs.shape[0]:
                    self._stash(merged_hs[n_use:], merged_pos[n_use:], dev)
                block_hs = merged_hs[:n_use]
                block_pos = merged_pos[:n_use]

        T = block_hs.shape[0]
        n_complete = T // r
        if n_complete == 0:
            return None, None, None

        # NVFP4 GEMM projections, cast to float32 for the reduction kernel.
        kv = self.kv_lin(block_hs).float()
        gate = self.gate_lin(block_hs).float()

        from dsv4.kernels.compressor.production_compress import csa_compress_production, hca_compress_production
        reduce_fn = csa_compress_production if self.is_csa else hca_compress_production
        compressed = reduce_fn(kv, gate, self.ape, self.kv_norm_w, m=r)
        if compressed.shape[0] == 0:
            return None, None, None

        # Position of the last token of every block, computed on device.
        bi = torch.arange(n_complete, device=dev)
        pos_idx = ((bi + 1) * r - 1).clamp(max=block_pos.numel() - 1)
        comp_pos = block_pos[pos_idx]
        return compressed, comp_pos, torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
# =====================================================================
# Indexer — CSA top-k [PRODUCTION NVFP4 GEMMs]
# =====================================================================
class Indexer:
    """Selects the top-k compressed blocks for each query token.

    Pipeline:
      1. ``q_b_lin`` projects the low-rank query to ``(T, n_ih, ihd)``.
      2. ``wp_lin`` projects the hidden state to per-head weights ``(T, n_ih)``.
      3. Scores are ``sum_h w_h * ReLU(q_h . k)`` against the stored indexer
         keys, followed by ``topk``.

    Attributes:
        n_ih, ihd: Number of indexer heads and their width.
        top_k: Maximum number of blocks returned per token.
        q_b_lin, wp_lin: Projection layers, set by ``load``.
        compressor: Optional ``Compressor`` producing the indexer keys.
    """

    def __init__(self, n_ih, ihd, top_k, device):
        self.n_ih, self.ihd, self.top_k, self.device = n_ih, ihd, top_k, device
        self.q_b_lin = None
        self.wp_lin = None
        self.compressor = None

    def load(self, w, pfx, dev=None):
        """Build the projections (and the key compressor, if present) from ``w``."""
        if dev is None:
            dev = self.device
        qb_w = get_nvfp4_weight(w, pfx, 'q_b_proj')[0]
        wp_w = get_nvfp4_weight(w, pfx, 'weights_proj')[0]
        if qb_w is not None:
            self.q_b_lin = make_nvfp4_linear(qb_w.shape[1] * 2, qb_w.shape[0], dev, w, pfx, 'q_b_proj')
        if wp_w is not None:
            self.wp_lin = make_nvfp4_linear(wp_w.shape[1] * 2, wp_w.shape[0], dev, w, pfx, 'weights_proj')
        # The key compressor's weights sit directly under the indexer prefix
        # (e.g. *.indexer.kv_proj.weight), not under *.indexer.compressor.
        if f"{pfx}.kv_proj.weight" in w:
            self.compressor = Compressor(4, self.ihd, 7168, dev)
            self.compressor.load(w, pfx, dev)

    def forward(self, q_lora, hidden_states, comp_indexer_kv, positions, layer_idx=None):
        """Return the indices of the best-scoring compressed blocks.

        Args:
            q_lora: Low-rank query activations, input of ``q_b_lin``.
            hidden_states: Input of ``wp_lin``.
            comp_indexer_kv: ``(n_comp, ihd)`` indexer keys, one vector per
                compressed block shared by all indexer heads; may be None.
            positions: Unused.
            layer_idx: Layer number; shape diagnostics are printed for layer
                0 when ``VERBOSE >= 2``.

        Returns:
            ``(T, min(top_k, n_comp))`` index tensor, or ``None`` when either
            projection is missing or there are no keys to score.
        """
        # A missing weights projection is treated like a missing query
        # projection instead of failing later on a None call.
        if (self.q_b_lin is None or self.wp_lin is None
                or comp_indexer_kv is None or comp_indexer_kv.shape[0] == 0):
            return None
        T = q_lora.shape[0]
        n_comp = comp_indexer_kv.shape[0]

        # Diagnostics follow the file-wide convention of printing only at
        # VERBOSE >= 2, so the decode loop is not flooded at default verbosity.
        li = layer_idx
        probe = li == 0 and VERBOSE >= 2
        if probe:
            print(f"\n=== INDEXER PROBE L0 ===", flush=True)
            print(f" q_lora: shape={tuple(q_lora.shape)} dtype={q_lora.dtype}", flush=True)
            print(f" comp_idx_kv: shape={tuple(comp_indexer_kv.shape)} "
                  f"dtype={comp_indexer_kv.dtype} stride={comp_indexer_kv.stride()} "
                  f"contig={comp_indexer_kv.is_contiguous()}", flush=True)
            print(f" self.n_ih={self.n_ih} self.ihd={self.ihd} n_ih*ihd={self.n_ih * self.ihd}", flush=True)
            print(f" self.q_b_lin.in_features={self.q_b_lin.in_features} out_features={self.q_b_lin.out_features}", flush=True)
            print(f" self.wp_lin.in_features={self.wp_lin.in_features} out_features={self.wp_lin.out_features}", flush=True)
            if self.compressor is not None:
                print(f" self.compressor.kv_dim={self.compressor.kv_dim} ratio={self.compressor.ratio} hd={self.compressor.hd}", flush=True)

        q_idx = self.q_b_lin(q_lora).reshape(T, self.n_ih, self.ihd)  # (T, n_ih, ihd)
        w_h = self.wp_lin(hidden_states)                              # (T, n_ih)
        k_idx = comp_indexer_kv                                       # (n_comp, ihd)

        if probe:
            print(f"--- INDEXER L0 SCORING TENSORS ---", flush=True)
            print(f" q_idx: shape={tuple(q_idx.shape)} dtype={q_idx.dtype}", flush=True)
            print(f" k_idx: shape={tuple(k_idx.shape)} dtype={k_idx.dtype}", flush=True)
            print(f" w_h: shape={tuple(w_h.shape)} dtype={w_h.dtype}", flush=True)

        # score(t, c) = sum_h w_h(t, h) * ReLU(q(t, h) . k(c)); the key is
        # shared across heads, hence 'tnd,cd->tnc'.
        scores = torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float())  # (T, n_ih, n_comp)
        scores = F.relu(scores)
        total = (scores * w_h.unsqueeze(-1).float()).sum(1)                 # (T, n_comp)
        tk = min(self.top_k, n_comp)
        _, idx = total.topk(tk, -1)
        return idx
# =====================================================================
# KV Cache
# =====================================================================
class KVCache:
    """Per-layer KV storage: a sliding-window ring plus compressed entries.

    All storage is pre-allocated in ``__init__``:
      - ``swa`` / ``swa_pos``: ring buffer holding the last ``window_size``
        tokens; ``swa_head`` is the next write slot, ``swa_len`` the number
        of valid rows.
      - ``comp_kv_buf`` / ``comp_pos_buf`` / ``comp_idx_buf``: append-only
        buffers with room for ``max_comp`` compressed entries, of which the
        first ``n_comp`` rows are valid. ``comp_idx_buf`` rows have width
        ``indexer_key_dim``.
      - ``gather_buf``: scratch of ``indexer_top_k + window_size`` rows that
        callers fill with the keys they attend over.
    """

    def __init__(self, head_dim, window_size=128, max_comp=65536, device='cuda:0',
                 indexer_key_dim=128, compress_ratio=4, indexer_top_k=1024):
        self.hd, self.ws, self.dev = head_dim, window_size, device
        self.idx_key_dim = indexer_key_dim
        self.ratio = compress_ratio
        # Sliding-window ring buffer.
        self.swa = torch.zeros(window_size, head_dim, dtype=torch.bfloat16, device=device)
        self.swa_pos = torch.zeros(window_size, dtype=torch.long, device=device)
        self.swa_len, self.swa_head = 0, 0
        # Compressed entries.
        self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=device)
        self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=device)
        self.comp_idx_buf = torch.zeros(max_comp, indexer_key_dim, dtype=torch.bfloat16, device=device)
        # Gather scratch.
        self.gather_buf = torch.zeros(indexer_top_k + window_size, head_dim,
                                      dtype=torch.bfloat16, device=device)
        self.n_comp = 0
        self._has_idx = False

    def append_swa(self, kv, pos):
        """Append ``T`` tokens to the sliding-window ring.

        When more tokens than the window are supplied, only the last
        ``window_size`` of them are written. Writing all of them would hand
        ``index_copy_`` duplicate slots, for which the surviving value is
        not defined; the ring ends up in the same state as if every token
        had been appended one by one.

        Args:
            kv: ``(T, head_dim)`` rows.
            pos: ``(T,)`` positions of those rows.
        """
        T = kv.shape[0]
        if T == 0:
            return
        start = self.swa_head
        if T > self.ws:
            skipped = T - self.ws
            kv, pos = kv[skipped:], pos[skipped:]
            start = (start + skipped) % self.ws
        slots = (start + torch.arange(kv.shape[0], device=self.dev)) % self.ws
        self.swa.index_copy_(0, slots, kv)
        self.swa_pos.index_copy_(0, slots, pos)
        self.swa_head = (self.swa_head + T) % self.ws
        self.swa_len = min(self.swa_len + T, self.ws)

    def add_compressed(self, ckv, cpos, idx_kv=None):
        """Append compressed entries (no-op when ``ckv`` is None).

        Args:
            ckv: ``(n, head_dim)`` compressed KV rows.
            cpos: ``(n,)`` positions of those rows.
            idx_kv: Optional ``(n, indexer_key_dim)`` indexer keys.

        Raises:
            RuntimeError: when the entries do not fit. Without the check a
                single-row write past the end broadcasts into an empty slice
                and is dropped while ``n_comp`` still advances.
        """
        if ckv is None:
            return
        end = self.n_comp + ckv.shape[0]
        capacity = self.comp_kv_buf.shape[0]
        if end > capacity:
            raise RuntimeError(
                f"compressed KV cache overflow: need {end} entries, capacity is {capacity}")
        self.comp_kv_buf[self.n_comp:end] = ckv
        self.comp_pos_buf[self.n_comp:end] = cpos
        if idx_kv is not None:
            self.comp_idx_buf[self.n_comp:end] = idx_kv
            self._has_idx = True
        self.n_comp = end

    @property
    def comp_kv(self):
        """View of the valid compressed KV rows, or None when empty."""
        return self.comp_kv_buf[:self.n_comp] if self.n_comp > 0 else None

    @property
    def comp_pos(self):
        """View of the valid compressed positions, or None when empty."""
        return self.comp_pos_buf[:self.n_comp] if self.n_comp > 0 else None

    @property
    def comp_idx_kv(self):
        """View of the valid indexer keys, or None if none were ever stored."""
        return self.comp_idx_buf[:self.n_comp] if self._has_idx and self.n_comp > 0 else None

    def get_swa(self):
        """Return ``(kv, pos)`` of the window, oldest token first.

        While the ring has not filled up the result is a view of the first
        ``swa_len`` rows; once it is full the rows are gathered in ring
        order starting at ``swa_head`` (this produces new tensors).
        """
        if self.swa_len == 0:
            return self.swa[:0], self.swa_pos[:0]
        if self.swa_len < self.ws:
            return self.swa[:self.swa_len], self.swa_pos[:self.swa_len]
        order = torch.arange(self.swa_head, self.swa_head + self.ws, device=self.dev) % self.ws
        return self.swa[order], self.swa_pos[order]
# =====================================================================
# HcHead
# =====================================================================
# Small constant added to the sigmoid gate in HcHead.forward.
HC_EPS = 1e-6


class HcHead:
    """Collapses the ``n_hc`` parallel hidden streams into a single one.

    ``forward`` derives one sigmoid gate per stream from the RMS-normalised,
    flattened input and returns the gate-weighted sum of the streams.
    ``load`` must be called before ``forward``.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.n_hc = n_hc
        self.device = device
        self.K = n_hc * hidden_dim  # width of the flattened (stream, hidden) input

    def load(self, fn, base, scale=None):
        """Store the mixing matrix ``fn``, the bias ``base`` and the scalar ``scale``.

        ``fn`` and ``base`` are kept as contiguous float32 tensors on the
        head's device; ``scale`` becomes a Python float (1.0 when omitted).
        """
        self.fn = fn.to(self.device, torch.float32).contiguous()
        self.base = base.to(self.device, torch.float32).contiguous()
        if scale is None:
            self.scale = 1.0
        else:
            self.scale = scale.to(self.device, torch.float32).item()

    def forward(self, X):
        """Gate and sum the streams of ``X``.

        Args:
            X: Tensor whose axis 0 is the token axis, axis 1 the streams, and
                which flattens to ``(T, K)``.

        Returns:
            bfloat16 tensor with the stream axis summed out.
        """
        T = X.shape[0]
        flattened = X.reshape(T, self.K).bfloat16()
        Xn = unweighted_rmsnorm(flattened)
        mix = F.linear(Xn, self.fn[:self.n_hc]).float()
        gate = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
        weighted = gate.unsqueeze(-1) * X.float()
        return weighted.sum(1).bfloat16()
# =====================================================================
# Production FMHA
# =====================================================================
def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx):
    """Run the production attention kernel with a single shared KV head.

    Args:
        q_heads: ``(T, n_h, hd)`` queries.
        all_kv: ``(N, hd)`` gathered keys; the same tensor is passed to the
            kernel as both K and V.
        n_h: Number of query heads (shape of the optional sink bias).
        scale: Softmax scale forwarded to the kernel.
        dev: Device the sink tensor is moved to.
        w, pfx: Weight mapping and prefix used to look up ``{pfx}.sinks``.
        hd, T, seq_len, li: Accepted for call-site symmetry; unused here.

    Returns:
        Kernel output permuted back to ``(T, n_h, hd)``.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    # Head-major layout for the kernel: all heads in a single launch.
    q = q_heads.permute(1, 0, 2).contiguous()   # (n_h, T, hd)
    shared_kv = all_kv.unsqueeze(0).contiguous()  # (1, N, hd)

    sink_bias = None
    sinks = w.get(f"{pfx}.sinks")
    if sinks is not None:
        sink_bias = sinks.to(device=dev).float().reshape(n_h)

    # K and V are the same tensor object; no separate V copy is made.
    attn_out = dsv4_attention(q=q, k=shared_kv, v=shared_kv, scale=scale,
                              n_comp=0, sink_bias=sink_bias)
    return attn_out.permute(1, 0, 2)
# =====================================================================
# Attention — ALL production kernels
# =====================================================================
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
                      kv_cache, positions, compressor, indexer, prod_lin,
                      _profile_detail=False, _profile_times=None):
    """Attention sub-layer of layer ``li`` built from the production kernels.

    Steps: Q projection (q_a -> norm -> q_b -> norm -> RoPE), shared KV
    projection (norm -> RoPE -> sliding-window cache), compressed-KV update,
    indexer top-k (ratio 4 only), gathering of the attended keys into the
    cache's scratch buffer, FMHA, inverse RoPE and the output projection.

    Args:
        x_normed: ``(T, hidden)`` normalised layer input.
        w: Mapping from tensor name to tensor (norm weights, sinks, fallbacks).
        li: Layer index.
        cfg: Model config dict.
        rope_cos, rope_sin: RoPE tables.
        kv_cache: ``KVCache`` of this layer (mutated).
        positions: Positions of the ``T`` tokens.
        compressor: ``Compressor`` or None.
        indexer: ``Indexer`` or None.
        prod_lin: Mapping with the production layers ``'q_a'``, ``'q_b'``,
            ``'kv'``, ``'o_b'`` and optionally ``'o_a'``.
        _profile_detail, _profile_times: When enabled, ``(tag, li, time)``
            tuples are appended after a CUDA synchronise at each stage.

    Returns:
        ``(F_attn, q_a)``: the attention output ``(T, hidden)`` and the
        normalised low-rank query.
    """
    dev = x_normed.device
    T = x_normed.shape[0]
    n_h = cfg["num_attention_heads"]
    hd = cfg["head_dim"]
    rd = cfg.get("qk_rope_head_dim", 64)
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    ratio = compressor.ratio if compressor is not None else 0
    scale = 1.0 / math.sqrt(hd)
    pfx = f"model.layers.{li}.self_attn"
    if positions.device != rope_cos.device:
        positions = positions.to(rope_cos.device)
    debug_layer = VERBOSE >= 2 and li < 3

    def _pt(tag):
        """Record a synchronised timestamp for ``tag`` when profiling is on."""
        if not _profile_detail or _profile_times is None:
            return
        torch.cuda.synchronize()
        _profile_times.append((tag, li, time.perf_counter()))

    # ---- 1. Query path -------------------------------------------------
    _pt('q_a_start')
    q_a = prod_lin['q_a'](x_normed)
    _pt('q_a_end')
    if debug_layer:
        # Cross-check the production GEMM against the dequantised reference.
        q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
        if q_a_ref is not None:
            cos_qa = F.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
            print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
    q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
    if q_norm_w is not None:
        q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
    _pt('q_b_start')
    q = unweighted_rmsnorm(prod_lin['q_b'](q_a)).bfloat16()
    _pt('q_b_end')
    q_heads = _apply_rope(q.reshape(T, n_h, hd), positions, rope_cos, rope_sin, rd)
    _pt('rope_q_end')

    # ---- 2. Shared KV path ---------------------------------------------
    _pt('kv_start')
    kv = prod_lin['kv'](x_normed)
    _pt('kv_end')
    kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
    if kv_norm_w is not None:
        kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
    kv_3d = _apply_rope(kv.reshape(T, 1, hd), positions, rope_cos, rope_sin, rd)
    _pt('rope_kv_end')
    kv_cache.append_swa(kv_3d.reshape(T, hd), positions)

    # ---- 3. Compressed KV ----------------------------------------------
    _pt('compress_start')
    comp_kv = comp_pos = block_bias = None
    comp_idx_kv = None
    if compressor is not None and compressor.ratio > 0:
        comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions)
        if comp_kv is not None:
            rotated = _apply_rope(comp_kv.unsqueeze(1), comp_pos, rope_cos, rope_sin, rd)
            comp_kv = rotated.squeeze(1)
        # The indexer's own compressor is fed on every call: both
        # compressors buffer tokens internally and have to see the same
        # token stream to emit their blocks together.
        if compressor.is_csa and indexer is not None and indexer.compressor is not None:
            comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions)
        kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv)
    _pt('compress_end')

    # ---- 4. Indexer top-k (ratio 4 only) --------------------------------
    topk_idx = None
    if indexer is not None and ratio == 4:
        topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions, layer_idx=li)

    # ---- 5. Gather attended keys into the pre-allocated scratch ---------
    _pt('gather_start')
    swa_kv, _swa_pos = kv_cache.get_swa()
    swa_len = swa_kv.shape[0]
    gbuf = kv_cache.gather_buf
    n_prefix = 0  # compressed rows placed in front of the window rows
    if kv_cache.comp_kv is not None and kv_cache.n_comp > 0:
        if ratio == 4:
            assert topk_idx is not None, f"CSA layer {li}: indexer returned no top-k — indexer is broken"
            # Only the first token's selection is used for the gather.
            tk = topk_idx[0].clamp(0, kv_cache.n_comp - 1)
            n_prefix = tk.shape[0]
            gbuf[:n_prefix] = kv_cache.comp_kv[tk]
        elif ratio > 4:
            n_prefix = kv_cache.n_comp
            gbuf[:n_prefix] = kv_cache.comp_kv
    gbuf[n_prefix:n_prefix + swa_len] = swa_kv
    all_kv = gbuf[:n_prefix + swa_len]
    seq_len = all_kv.shape[0]
    if seq_len == 0:
        return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a

    # ---- 6. Production FMHA ---------------------------------------------
    _pt('fmha_start')
    attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
    _pt('fmha_end')
    if debug_layer:
        # Dense softmax attention over the same keys, for comparison only.
        k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous()
        v_exp = k_exp.clone()
        q_in = q_heads.permute(1, 0, 2)
        ref_scores = torch.matmul(q_in, k_exp.transpose(-1, -2)) * scale
        ref_attn = torch.matmul(torch.softmax(ref_scores.float(), -1).bfloat16(), v_exp).permute(1, 0, 2)
        cos_sim = F.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item()
        print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True)

    # ---- 7. Inverse RoPE on the attention output ------------------------
    _pt('inv_rope_start')
    attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
    _pt('inv_rope_end')

    # ---- 8. Output projection: o_a (grouped) then o_b -------------------
    _pt('o_proj_start')
    wo_a_lin = prod_lin.get('o_a')
    if wo_a_lin is not None:
        g_3d = wo_a_lin.run(attn_out)  # (T, n_groups, o_rank) per the original note
        F_attn = prod_lin['o_b'](g_3d.reshape(T, -1))
    else:
        # BF16 grouped-BMM fallback when no production o_a layer is supplied.
        hpg_fb = n_h // o_groups
        gid_fb = hpg_fb * hd
        oa_full = w.get(f"{pfx}.o_a_proj.weight")
        if oa_full is None:
            log.warning(f"L{li}: No o_a_proj weight, zero attention output")
            F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev)
        else:
            oa_bf = oa_full.bfloat16().to(dev)
            a_grp = attn_out.reshape(T, n_h * hd).reshape(T, o_groups, gid_fb)
            oa_3d = oa_bf.reshape(o_groups, o_rank, gid_fb)
            g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
            g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
            F_attn = prod_lin['o_b'](g_flat)
    _pt('o_proj_end')

    if debug_layer:
        print(f" L{li} F_attn: |F_attn|={F_attn.abs().max().item():.6f}", flush=True)
    return F_attn, q_a
# =====================================================================
# MoE — production kernels
# =====================================================================
def moe_forward(x, li, moe_runner, se_runner, router, token_id):
    """Route ``x`` through the MoE block and return routed + shared expert output.

    Args:
        x: activations fed to the router, the routed experts and the shared expert.
        li: layer index; only used to gate and label debug output.
        moe_runner: object whose ``run(x, topk_w, topk_ids)`` gives the routed output.
        se_runner: object whose ``run(x)`` gives the shared-expert output.
        router: callable ``router(x, token_ids=...)`` returning ``(topk_w, topk_ids)``.
        token_id: token-id tensor; moved to ``x``'s device when it lives elsewhere.

    Returns:
        ``routed_out + shared_out``.

    All diagnostics require ``VERBOSE >= 2`` and are limited to the first three
    layers (``li < 3``) and/or the layers with ``li >= 58``.
    """
    # The router needs the token ids on the same device as the activations.
    token_id_dev = token_id
    if token_id_dev.device != x.device:
        token_id_dev = token_id_dev.to(x.device)
    topk_w, topk_ids = router(x, token_ids=token_id_dev)

    debug = VERBOSE >= 2
    head_layer = li < 3
    tail_layer = li >= 58

    # Sanity-check the expert ids (first 3 and last layers only).
    if debug and (head_layer or tail_layer):
        ids_out_of_range = topk_ids.max().item() >= 384 or topk_ids.min().item() < 0
        if ids_out_of_range:
            print(f" L{li} BAD topk_ids: min={topk_ids.min().item()} max={topk_ids.max().item()}", flush=True)

    if debug and tail_layer:
        print(f" L{li} MoE DIAG: topk_ids={topk_ids[0].tolist()} topk_w=[{','.join(f'{w:.3f}' for w in topk_w[0].tolist())}]", flush=True)
        # Re-run the gate to expose the raw logits when the router carries one.
        if getattr(router, '_gate_lin', None) is not None:
            gate_logits = router._gate_lin(x).float()
            print(f" L{li} gate logits: [{gate_logits.min().item():.3f}, {gate_logits.max().item():.3f}] mean={gate_logits.mean().item():.3f}", flush=True)

    if debug and head_layer:
        print(f" L{li} MoE input: |x|={x.abs().max().item():.4f} has_nan={torch.isnan(x).any().item()}", flush=True)

    routed_out = moe_runner.run(x, topk_w, topk_ids)
    shared_out = se_runner.run(x)

    if debug and tail_layer:
        print(f" L{li} MoE DIAG: |routed|={routed_out.abs().max().item():.1f} |shared|={shared_out.abs().max().item():.1f} |x|={x.abs().max().item():.1f}", flush=True)

    if debug and head_layer:
        has_nan = torch.isnan(shared_out).any().item()
        if has_nan:
            out_max = float('nan')
        else:
            out_max = shared_out.abs().max().item()
        print(f" L{li} MoE shared: |out|={out_max:.4f} has_nan={has_nan}", flush=True)
        # Dump the shared-expert weight/scale ranges to spot corrupted buffers.
        if getattr(se_runner, '_l1_mat_b', None) is not None:
            wb = se_runner._l1_mat_b.view(torch.uint8)
            print(f" L{li} SE l1 weight: shape={list(se_runner._l1_mat_b.shape)} dtype={se_runner._l1_mat_b.dtype} uint8_range=[{wb.min().item()},{wb.max().item()}]", flush=True)
        if getattr(se_runner, '_l1_scale_b', None) is not None:
            sb = se_runner._l1_scale_b.float()
            print(f" L{li} SE l1 scale: shape={list(se_runner._l1_scale_b.shape)} dtype={se_runner._l1_scale_b.dtype} float_range=[{sb.min().item():.6f},{sb.max().item():.6f}] has_nan={torch.isnan(sb).any().item()}", flush=True)
        print(f" L{li} SE gsa: l1={se_runner._l1_activation_global_scale:.6f} l2={se_runner._l2_activation_global_scale:.6f} gsb: l1={se_runner._l1_gsb[0].item():.6f} l2={se_runner._l2_gsb[0].item():.6f}", flush=True)

    return routed_out + shared_out
# =====================================================================
# Layer forward
# =====================================================================
def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
                  attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w,
                  kv_cache, positions, token_id,
                  compressor=None, indexer=None,
                  moe_runner=None, se_runner=None, router=None,
                  prod_lin=None, _profile_detail=False, _profile_times=None):
    """Run one decoder layer: mHC-wrapped attention followed by mHC-wrapped MoE.

    Each half follows the same pattern: ``pre_block`` of the mHC layer produces
    the block input plus a context object, the input is RMS-normalised, the
    block (``forward_attention`` / ``moe_forward``) runs, and ``post_block``
    folds the block output back into the layer state.

    Args:
        X_l: layer state entering this layer.
        w, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer,
            prod_lin: forwarded unchanged to ``forward_attention``.
        li: layer index (also gates the diagnostics).
        attn_mhc, ffn_mhc: mHC layers wrapping the attention and FFN halves.
        attn_norm_w, ffn_norm_w: RMSNorm weights for the two halves.
        token_id, moe_runner, se_runner, router: forwarded to ``moe_forward``.
        _profile_detail: when true, synchronise CUDA around both halves and
            print wall-clock timings for selected layers.
        _profile_times: forwarded to ``forward_attention``.

    Returns:
        The layer state after the FFN ``post_block``.
    """
    def _synced_clock():
        # Drain pending GPU work so the timestamp brackets finished kernels.
        torch.cuda.synchronize()
        return time.perf_counter()

    chatty = VERBOSE >= 2

    # ---- attention half ----
    x_in, attn_ctx = attn_mhc.pre_block(X_l)
    attn_input = rmsnorm(x_in, attn_norm_w)
    if _profile_detail:
        attn_begin = _synced_clock()
    F_attn, _ = forward_attention(attn_input, w, li, cfg, rope_cos, rope_sin,
                                  kv_cache, positions, compressor, indexer, prod_lin,
                                  _profile_detail=_profile_detail, _profile_times=_profile_times)
    if _profile_detail:
        attn_end = _synced_clock()
    X_mid = attn_mhc.post_block(X_l, F_attn, attn_ctx)

    # ---- FFN (MoE) half ----
    x_in_f, ffn_ctx = ffn_mhc.pre_block(X_mid)
    ffn_input = rmsnorm(x_in_f, ffn_norm_w)
    if _profile_detail:
        ffn_begin = _synced_clock()
    F_ffn = moe_forward(ffn_input, li, moe_runner, se_runner, router, token_id)
    if _profile_detail:
        ffn_end = _synced_clock()
    X_next = ffn_mhc.post_block(X_mid, F_ffn, ffn_ctx)

    if chatty and (li < 3 or li >= 58):
        print(f" L{li}: |X|={X_l.abs().max().item():.1f}->{X_next.abs().max().item():.1f} "
              f"|Fa|={F_attn.abs().max().item():.1f} |Ff|={F_ffn.abs().max().item():.1f}", flush=True)

    # Detailed diagnostics: gated on VERBOSE first so the .item() sync in the
    # magnitude test never runs on the hot path.
    if chatty and (li >= 58 or (li > 0 and X_next.abs().max().item() > 200)):
        A_a, B_a, C_a = attn_mhc._dynamic_params(X_l)
        A_f, B_f, C_f = ffn_mhc._dynamic_params(X_mid)
        print(f" L{li} DIAG: A_attn=[{A_a.min().item():.4f},{A_a.max().item():.4f}] "
              f"C_attn=[{C_a.min().item():.4f},{C_a.max().item():.4f}] "
              f"A_ffn=[{A_f.min().item():.4f},{A_f.max().item():.4f}] "
              f"C_ffn=[{C_f.min().item():.4f},{C_f.max().item():.4f}]", flush=True)
        print(f" L{li} DIAG: B_attn row_sum=[{B_a.sum(-1).min().item():.4f},{B_a.sum(-1).max().item():.4f}] "
              f"col_sum=[{B_a.sum(-2).min().item():.4f},{B_a.sum(-2).max().item():.4f}] "
              f"B_ffn row_sum=[{B_f.sum(-1).min().item():.4f},{B_f.sum(-1).max().item():.4f}] "
              f"col_sum=[{B_f.sum(-2).min().item():.4f},{B_f.sum(-2).max().item():.4f}]", flush=True)
        print(f" L{li} DIAG: |x_in_attn|={x_in.abs().max().item():.1f} "
              f"|x_in_ffn|={x_in_f.abs().max().item():.1f} "
              f"|X_l|={X_l.abs().max().item():.1f} "
              f"|X_mid|={X_mid.abs().max().item():.1f} "
              f"|X_next|={X_next.abs().max().item():.1f}", flush=True)

    # Per-layer timing summary for a sample of layers.
    if _profile_detail and (li < 3 or li == 30 or li >= 58):
        torch.cuda.synchronize()
        attn_ms = (attn_end - attn_begin) * 1000
        ffn_ms = (ffn_end - ffn_begin) * 1000
        print(f" L{li}: attn={attn_ms:.2f}ms ffn={ffn_ms:.2f}ms", flush=True)

    return X_next
# =====================================================================
# MoE weight loading
# =====================================================================
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Collect all routed-expert NVFP4 tensors for one layer and stack them into ``moe``.

    For every expert ``0 .. cfg["n_routed_experts"] - 1`` the gate and up
    projections are concatenated along dim 0 (layer 1) and the down projection
    is taken as-is (layer 2). The stacked tensors go to
    ``moe.prepare_weights_from_stacked``; per-expert ``weight_scale_2`` tensors
    and the first expert's input scales are stored on ``moe`` as attributes.

    Args:
        all_w: mapping of checkpoint tensor names to tensors.
        li: layer index (used only in the warning message).
        pfx: key prefix of the layer's MLP; experts live under ``{pfx}.experts.{eid}``.
        dev: target device for every tensor.
        moe: runner receiving the weights.
        cfg: model config; only ``n_routed_experts`` is read.

    Returns:
        None. Returns early (after a warning) when no expert has both gate and
        up weights.
    """
    fallback_input_scale = 1.0 / (6.0 * 448.0)

    def _input_scale(t):
        # Checkpoint input_scale as a Python float, or the fallback when absent.
        return t.float().item() if t is not None else fallback_input_scale

    def _to_dev_or_none(t):
        return t.to(dev) if t is not None else None

    w13_fp4, w13_sf, w13_gs, w13_ws2, w13_gsa = [], [], [], [], []
    w2_fp4, w2_sf, w2_gs, w2_ws2, w2_gsa = [], [], [], [], []

    for eid in range(cfg["n_routed_experts"]):
        expert_pfx = f"{pfx}.experts.{eid}"
        gate_w, gate_sf, gate_ws2, gate_isc = get_nvfp4_weight(all_w, expert_pfx, 'gate_proj')
        up_w, up_sf, _up_ws2, _up_isc = get_nvfp4_weight(all_w, expert_pfx, 'up_proj')
        if gate_w is not None and up_w is not None:
            w13_fp4.append(torch.cat([gate_w, up_w], dim=0).to(dev))
            if gate_sf is not None and up_sf is not None:
                w13_sf.append(torch.cat([gate_sf, up_sf], dim=0).to(dev))
            w13_gs.append(1.0)  # gsb base; ws2 is folded in later by the runner
            w13_gsa.append(_input_scale(gate_isc))  # gsa = input_scale
            # NOTE(review): only the gate projection's weight_scale_2 is kept;
            # the up projection's value is dropped — confirm they always match.
            w13_ws2.append(_to_dev_or_none(gate_ws2))

        down_w, down_sf, down_ws2, down_isc = get_nvfp4_weight(all_w, expert_pfx, 'down_proj')
        if down_w is not None:
            w2_fp4.append(down_w.to(dev))
            if down_sf is not None:
                w2_sf.append(down_sf.to(dev))
            w2_gs.append(1.0)  # gsb base
            w2_gsa.append(_input_scale(down_isc))  # gsa = input_scale
            w2_ws2.append(_to_dev_or_none(down_ws2))

    if not w13_fp4:
        log.warning(f"L{li}: No expert weights found")
        return

    def _stack_or_none(tensors):
        return torch.stack(tensors).to(dev) if tensors else None

    l1_stacked = torch.stack(w13_fp4).to(dev)
    l1_sf_stacked = _stack_or_none(w13_sf)
    l2_stacked = _stack_or_none(w2_fp4)
    l2_sf_stacked = _stack_or_none(w2_sf)
    # Drop the per-expert references so only the stacked copies stay alive.
    del w13_fp4, w13_sf, w2_fp4, w2_sf

    moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, w13_gs,
                                     l2_stacked, l2_sf_stacked, w2_gs)

    # Keep the checkpoint activation scales (first expert) so a caller can
    # restore them after the runner's own initialisation.
    moe._saved_l1_gsa = w13_gsa[0] if w13_gsa else fallback_input_scale
    moe._saved_l2_gsa = w2_gsa[0] if w2_gsa else fallback_input_scale
    moe.l1_ws2 = w13_ws2
    moe.l2_ws2 = w2_ws2
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Attach the shared expert's NVFP4 tensors for one layer to ``se``.

    Layer 1 is the gate and up projections concatenated along dim 0; layer 2
    is the down projection. Each layer is populated only when its weights are
    present under ``{pfx}.shared_experts``.

    Args:
        all_w: mapping of checkpoint tensor names to tensors.
        li: layer index (unused here; kept for a uniform loader signature).
        pfx: key prefix of the layer's MLP.
        dev: target device for every tensor.
        se: shared-expert runner whose attributes are filled in.
        cfg: model config (unused here).
    """
    se_pfx = f"{pfx}.shared_experts"
    fallback_input_scale = 1.0 / (6.0 * 448.0)

    def _placeholder_sf():
        # Stand-in scale-factor tensor used when the checkpoint provides none.
        return torch.zeros(1, device=dev, dtype=torch.float8_e4m3fn)

    gate_w, gate_sf, gate_ws2, gate_isc = get_nvfp4_weight(all_w, se_pfx, 'gate_proj')
    up_w, up_sf, _up_ws2, _up_isc = get_nvfp4_weight(all_w, se_pfx, 'up_proj')
    down_w, down_sf, down_ws2, down_isc = get_nvfp4_weight(all_w, se_pfx, 'down_proj')

    if gate_w is not None and up_w is not None:
        se.l1_fp4 = [torch.cat([gate_w, up_w], dim=0).to(dev)]
        if gate_sf is not None and up_sf is not None:
            se.l1_sf = [torch.cat([gate_sf, up_sf], dim=0).to(dev)]
        else:
            se.l1_sf = [_placeholder_sf()]
        if gate_isc is not None:
            l1_input_scale = gate_isc.float().item()
        else:
            l1_input_scale = fallback_input_scale
        se.l1_gs = [1.0]  # gsb base; ws2 is folded in later by the runner
        se.l1_ws2 = [gate_ws2.to(dev) if gate_ws2 is not None else None]
        se._l1_activation_global_scale = l1_input_scale
        se._saved_l1_gsa = l1_input_scale  # kept so it can be restored later

    if down_w is not None:
        se.l2_fp4 = [down_w.to(dev)]
        se.l2_sf = [down_sf.to(dev)] if down_sf is not None else [_placeholder_sf()]
        if down_isc is not None:
            l2_input_scale = down_isc.float().item()
        else:
            l2_input_scale = fallback_input_scale
        se.l2_gs = [1.0]  # gsb base
        se.l2_ws2 = [down_ws2.to(dev) if down_ws2 is not None else None]
        se._l2_activation_global_scale = l2_input_scale
        se._saved_l2_gsa = l2_input_scale  # kept so it can be restored later
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
cached = {}
for li in range(n_layers):
dev = devices[li % len(devices)]; pfx = f"model.layers.{li}."
w = {k: v.to(device=dev, non_blocking=True) for k, v in all_w.items()
if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k}
cached[li] = w
if (li+1) % 10 == 0: log.info(f" Cached {li+1}/{n_layers} layers")
return cached
# =====================================================================
# Main
# =====================================================================
def kill_stale_gpu_processes():
    """Best-effort SIGKILL of every *other* process holding a GPU compute context.

    PIDs come from ``nvidia-smi --query-compute-apps=pid``. Failures never
    propagate: if nvidia-smi cannot be run, a warning is logged and nothing is
    killed.

    Differences from a naive "kill everything listed" loop:
      * The current process is skipped, so calling this after a CUDA context
        exists cannot kill the caller itself.
      * PIDs are de-duplicated, so a PID that appears on several output lines
        is signalled once.
      * A PID that cannot be signalled (already gone, or owned by another
        user) is skipped without aborting the remaining kills.
    """
    import signal
    import subprocess

    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-compute-apps=pid', '--format=csv,noheader'],
            capture_output=True, text=True, timeout=5)
    except Exception as e:
        # nvidia-smi missing, timed out, etc. — cleanup is optional.
        log.warning(f"Could not check GPU processes: {e}")
        return
    if result.returncode != 0:
        return

    own_pid = os.getpid()
    # dict.fromkeys de-duplicates while keeping nvidia-smi's order.
    for field in dict.fromkeys(line.strip() for line in result.stdout.splitlines()):
        if not field:
            continue
        try:
            pid = int(field)
        except ValueError:
            continue  # not a PID (e.g. a stray message line)
        if pid == own_pid:
            continue
        try:
            os.kill(pid, signal.SIGKILL)
        except ProcessLookupError:
            continue  # exited between the query and the kill
        except PermissionError:
            log.warning(f"Could not kill GPU process {pid}: permission denied")
            continue
        log.info(f" Killed stale GPU process {pid}")
def main():
    """Run the whole single-shot pipeline: load, build, prefill, decode, report.

    Phases, in order:
      1. Read ``config.json`` from ``CHECKPOINT_DIR`` and load all weights.
      2. Build per-layer components (mHC blocks, norms, attention projections,
         routers, MoE and shared-expert runners, KV caches, compressors,
         indexers). Layer ``li`` is placed on GPU ``li % NUM_GPUS``.
      3. Prefill the prompt one token at a time, then decode up to
         ``MAX_NEW_TOKENS`` tokens with greedy argmax (temperature 0) or
         ``CUDASampler``.

    Progress, diagnostics and the decoded text are printed to stdout; nothing
    is returned. Command-line options are read from the module-level ``_args``.
    """
    t0 = time.time()
    torch.manual_seed(SEED)
    print("=" * 70)
    print("DSV4 Single-Shot Inference - PRODUCTION KERNEL STACK")
    print(" FMHA: 6-warp TMA multi-tile + sink bias")
    print(" NVFP4 GEMM (CuTeDSL) for ALL projections")
    print(" Production MoE + Router | Production mHC")
    print(" NO PyTorch SDPA | NO dequant+matmul | NO reference fallback")
    print("=" * 70)
    with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
        cfg = json.load(f)
    n_layers = cfg["num_hidden_layers"]
    H = cfg["hidden_size"]
    hd = cfg["head_dim"]
    n_h = cfg["num_attention_heads"]
    rd = cfg.get("qk_rope_head_dim", 64)
    cr = cfg.get("compress_ratios", [128] * n_layers)
    o_groups = cfg.get("o_groups", 16)
    o_rank = cfg.get("o_lora_rank", 1024)
    print(f"Model: {n_layers} layers, {n_h} heads, hd={hd}, rope_dim={rd}")
    print(f"Compress ratios: first5={cr[:5]} len={len(cr)}")
    print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")

    # ---- Phase 1: Load weights ----
    print(f"\nPhase 1: Loading weights...")
    all_w = load_all_weights(CHECKPOINT_DIR)
    print(f" {time.time()-t0:.1f}s")

    # ---- Phase 2: Build production components ----
    print("Building production components...")
    from dsv4.layers.mhc import mHCLayer
    from dsv4.layers.router import Router
    from dsv4.layers.moe import Nvfp4MoE
    from dsv4.layers.shared_expert import Nvfp4SharedExpert
    from dsv4.layers.linear import Nvfp4Linear
    # Kill stale GPU processes from prior runs (OOM, crash, etc.)
    kill_stale_gpu_processes()
    for g in range(NUM_GPUS):
        torch.cuda.set_device(g)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)

    # mHC + norms
    n_hc = 4
    attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        for hc_name, blocks in (("attn_hc", attn_mhcs), ("ffn_hc", ffn_mhcs)):
            hc_pfx = f"model.layers.{li}.{hc_name}"
            fn = all_w.get(f"{hc_pfx}.fn")
            base = all_w.get(f"{hc_pfx}.base")
            scale = all_w.get(f"{hc_pfx}.scale")
            if fn is not None and base is not None and scale is not None:
                m = mHCLayer(hidden_dim=H, n_hc=n_hc, t_max_sinkhorn=20, device=dev)
                n = n_hc
                # fn/base are packed as [pre | post | comb] along dim 0.
                m.load_weights(
                    W_pre=fn[0:n].to(dev, torch.float32), W_post=fn[n:2*n].to(dev, torch.float32),
                    W_comb=fn[2*n:].to(dev, torch.float32),
                    S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
                    S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
                    S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
                    alpha_pre=scale[0].item(), alpha_post=scale[1].item(), alpha_comb=scale[2].item(),
                )
                blocks[li] = m
        an_k = f"model.layers.{li}.input_layernorm.weight"
        if an_k in all_w:
            attn_norms[li] = all_w[an_k].to(dev, torch.float32)
        fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
        if fn_k in all_w:
            ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)

    # Production Nvfp4Linear for attention projections
    print(" Building production Nvfp4Linear for attention projections...")
    prod_lins = {}
    # Weight dimensions (from checkpoint):
    # q_a_proj: (1536, 3584) uint8 -> in=7168, out=1536
    # q_b_proj: (65536, 768) uint8 -> in=1536, out=65536
    # kv_proj: (512, 3584) uint8 -> in=7168, out=512
    # o_a_proj: (16384, 4096) BF16 -> Nvfp4GroupedLinear (16 groups, 1024×4096 each)
    # o_b_proj: (7168, 8192) uint8 -> in=16384, out=7168
    from dsv4.layers.grouped_linear import Nvfp4GroupedLinear
    heads_per_group = n_h // o_groups
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        pfx = f"model.layers.{li}.self_attn"
        torch.cuda.set_device(li % NUM_GPUS)
        pl = {}
        pl['q_a'] = make_nvfp4_linear(7168, 1536, dev, all_w, pfx, 'q_a_proj')
        pl['q_b'] = make_nvfp4_linear(1536, 65536, dev, all_w, pfx, 'q_b_proj')
        pl['kv'] = make_nvfp4_linear(7168, 512, dev, all_w, pfx, 'kv_proj')
        # o_a_proj: Nvfp4GroupedLinear (NVFP4 grouped GEMM)
        wo_a = Nvfp4GroupedLinear(
            n_local_groups=o_groups,
            heads_per_group=heads_per_group,
            head_dim=hd,
            o_lora_rank=o_rank,
            max_num_tokens=8192,
            device=dev,
        )
        oa_w_nvfp4, oa_ws, oa_ws2, oa_isc = get_nvfp4_weight(all_w, pfx, 'o_a_proj')
        if oa_w_nvfp4 is not None and oa_ws is not None:
            # Checkpoint has NVFP4 weights — load directly (no dequant/re-quant)
            wo_a.load_nvfp4_weight(oa_w_nvfp4.to(dev), oa_ws.to(dev),
                                   oa_ws2.to(dev) if oa_ws2 is not None else None,
                                   oa_isc.to(dev) if oa_isc is not None else None)
        else:
            # BF16 checkpoint weight
            oa_bf = all_w.get(f"{pfx}.o_a_proj.weight")
            if oa_bf is not None:
                wo_a.set_bf16_weight(oa_bf.bfloat16().to(dev))
        pl['o_a'] = wo_a
        wo_a._use_runtime_gsa = True  # compute gsa from actual input to avoid E4M3 overflow
        pl['o_b'] = make_nvfp4_linear(16384, 7168, dev, all_w, pfx, 'o_b_proj')
        prod_lins[li] = pl
        if (li+1) % 10 == 0:
            print(f" {li+1}/{n_layers} layers")
    print(" All attention projections: production NVFP4 GEMM (o_a now NVFP4 grouped)")

    # Routers, MoE, shared experts
    routers, moe_runners, se_runners = {}, {}, {}
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        pfx = f"model.layers.{li}.mlp"
        torch.cuda.set_device(li % NUM_GPUS)
        torch.cuda.synchronize()
        is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w)
        router = Router(hidden_size=H, num_experts=cfg["n_routed_experts"],
                        top_k=cfg.get("num_experts_per_tok", 6),
                        routed_scaling_factor=cfg.get("routed_scaling_factor", 2.5),
                        mode="hash" if is_hash else "dense",
                        vocab_size=cfg.get("vocab_size", 128000) if is_hash else None, device=dev)
        if is_hash:
            router.load_weights(hash_lut=all_w[f"{pfx}.gate.tid2eid"].to(dev, torch.int32))
        else:
            eb = all_w.get(f"{pfx}.gate.e_score_correction_bias")
            # NVFP4 production GEMM for router gate
            # Custom CuTeDSL fused kernel crashes MLIR optimizer,
            # so we use Nvfp4Linear (proven production path).
            gate_w, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(all_w, pfx, 'gate')
            E = cfg["n_routed_experts"]
            if gate_w is not None and gate_ws is not None:
                # Checkpoint has NVFP4 gate weight (N_packed, K_packed) — correct layout
                gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                gate_w_view = gate_w.to(dev).view(torch.float4_e2m1fn_x2) if gate_w.dtype == torch.uint8 else gate_w.to(dev)
                gate_lin.fp4 = [gate_w_view]
                gate_lin.sf = [gate_ws.to(dev)]
                ws2_v = gate_ws2.float().item() if gate_ws2 is not None else 1.0
                isc_v = gate_isc.float().item() if gate_isc is not None else 1.0/(6.0*448.0)
                gate_lin.gs = [1.0]
                gate_lin.ws2 = [torch.tensor([ws2_v], device=dev, dtype=torch.float32)]
                gate_lin._activation_global_scale = isc_v  # placeholder — runtime gsa overrides this
                gate_lin._use_runtime_gsa = True  # compute gsa from actual input to avoid E4M3 overflow
                gate_lin.finalize_weights()
                router.load_nvfp4_gate(gate_lin)
                if li < 5:
                    print(f" L{li}: NVFP4 router gate (checkpoint)", flush=True)
            else:
                # BF16 gate weight: quantize to NVFP4
                gw = all_w.get(f"{pfx}.gate.weight")
                if gw is not None:
                    g_bf16 = gw if gw.shape == (E, H) else gw.T.contiguous()
                    g_bf16 = g_bf16.bfloat16().to(dev)
                    from dsv4.ops.quantize import quantize_to_nvfp4
                    g_fp4, g_sf, g_gs = quantize_to_nvfp4(g_bf16)
                    gate_lin = Nvfp4Linear(in_features=H, out_features=E, device=dev)
                    gate_lin.fp4 = [g_fp4]
                    gate_lin.sf = [g_sf]
                    # NOTE(review): g_gs is stored both as gs and as ws2 here, unlike
                    # the checkpoint path (gs=1.0) — confirm it is not applied twice.
                    gate_lin.gs = [g_gs]
                    gate_lin.ws2 = [torch.tensor([g_gs], device=dev, dtype=torch.float32)]
                    gate_lin._activation_global_scale = 1.0 / (6.0 * 448.0)  # placeholder — runtime gsa overrides
                    gate_lin._use_runtime_gsa = True  # compute gsa from actual input to avoid E4M3 overflow
                    gate_lin.finalize_weights()
                    router.load_nvfp4_gate(gate_lin)
                    if li < 5:
                        print(f" L{li}: NVFP4 router gate (quantized, gs={g_gs:.6f})", flush=True)
            # Every dense path needs the score-correction bias; load it once
            # here instead of once per branch plus a trailing duplicate call.
            router.load_weights(e_bias=eb.to(dev, torch.float32))
        router.finalize_weights()
        routers[li] = router

        moe = Nvfp4MoE(num_experts=cfg["n_routed_experts"], hidden_size=H,
                       intermediate_size=cfg.get("moe_intermediate_size", 3072),
                       top_k=cfg.get("num_experts_per_tok", 6), device=dev)
        moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0))
        # P0: ENABLE fused SwiGLU — NVFP4 GEMM + SiLU in kernel registers.
        # Saves 240+ unfused BF16 kernel launches per token (gate_silu, clamp, mul, quantize).
        moe.set_fused_swiglu(True)
        _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
        # EAGERLY process stacked weights → K-major + swizzle, free raw tensors
        moe._ensure_stacked()
        # Do NOT use the checkpoint input_scale as gsa — it causes E4M3 overflow.
        # Runtime gsa (computed from the actual activation magnitude) is enabled
        # for both MoE and SharedExpert instead.
        moe._use_runtime_gsa = True
        moe_runners[li] = moe

        se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
                               device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0))
        _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg)
        # P1: ENABLE fused SwiGLU for shared expert (1-group variant of MoE fused kernel)
        se.set_fused_swiglu(True)
        # EAGERLY process shared expert weights
        se._ensure_initialized()
        # Same runtime gsa policy as the routed experts.
        se._use_runtime_gsa = True
        se_runners[li] = se
        if (li+1) % 10 == 0:
            print(f" Built {li+1}/{n_layers} MoE layers")
        torch.cuda.empty_cache()

    # Global weights
    torch.cuda.set_device(0)
    embed_w = all_w.get("model.embed_tokens.weight")
    embed = torch.nn.Embedding.from_pretrained(embed_w.bfloat16().to('cuda:0'))
    # lm_head: NVFP4 production GEMM (falls back to the embedding matrix when
    # the checkpoint has no separate lm_head weight)
    lm_w_raw = all_w.get("lm_head.weight", embed_w).bfloat16().to('cuda:0')
    lm_head_lin = Nvfp4Linear(lm_w_raw.shape[1], lm_w_raw.shape[0], max_num_tokens=8192, device='cuda:0')
    from dsv4.ops.quantize import quantize_weight_to_nvfp4
    lm_fp4, lm_sf, lm_gs = quantize_weight_to_nvfp4(lm_w_raw.T.contiguous())
    lm_head_lin.fp4 = [lm_fp4.permute(1, 0).contiguous()]
    lm_head_lin.sf = [lm_sf.permute(1, 0).contiguous()]
    lm_head_lin.gs = [lm_gs]
    lm_head_lin.ws2 = [None]
    lm_head_lin._activation_global_scale = 1.0 / (6.0 * 448.0)
    lm_head_lin._use_runtime_gsa = True
    lm_head_lin.finalize_weights()
    print(" lm_head: NVFP4 production GEMM")
    final_norm_w = all_w.get("model.norm.weight")
    if final_norm_w is not None:
        final_norm_w = final_norm_w.to('cuda:0', torch.float32)
    hc_head = HcHead(H, 4, 'cuda:0')
    hc_fn = all_w.get("model.hc_head.hc_fn")
    hc_base = all_w.get("model.hc_head.hc_base")
    hc_scale = all_w.get("model.hc_head.hc_scale")
    if hc_fn is not None and hc_base is not None:
        hc_head.load(hc_fn, hc_base, hc_scale)
        print(" hc_head loaded")

    # RoPE (FP32) — one cache per GPU
    rp = cfg.get("rope_scaling", cfg.get("rope_parameters", {}))
    rt = rp.get("type", rp.get("rope_type", "yarn"))
    rf = rp.get("factor", 16.0)
    rtheta = cfg.get("rope_theta", 10000.)
    romax = rp.get("original_max_position_embeddings", 65536)
    rbfast, rbslow = rp.get("beta_fast", 32), rp.get("beta_slow", 1)
    rope_caches = {g: build_rope_cache(romax, rd, f"cuda:{g}", rtheta, rt, rf, romax, rbfast, rbslow) for g in range(NUM_GPUS)}

    # KV caches, compressors, indexers
    kv_caches, compressors, indexers = {}, {}, {}
    n_ih = cfg.get("index_n_heads", 64)
    ihd = cfg.get("index_head_dim", 128)
    itk = cfg.get("index_topk", 1024)
    max_ctx = _args.max_context
    print(f" Max context: {max_ctx} tokens (governs KV cache pre-allocation)")
    for li in range(n_layers):
        dev = f"cuda:{li % NUM_GPUS}"
        ratio = cr[li] if li < len(cr) else 128
        # C1: max_comp derived from target context and compress ratio (ceil division)
        max_comp = (max_ctx + ratio - 1) // ratio if ratio > 0 else 0
        kv_caches[li] = KVCache(hd, cfg.get("sliding_window", 128), max_comp=max_comp, device=dev,
                                indexer_key_dim=ihd, compress_ratio=ratio, indexer_top_k=itk)
        if ratio > 0:
            compressors[li] = Compressor(ratio, hd, H, dev)
        if ratio == 4:
            indexers[li] = Indexer(n_ih, ihd, itk, dev)

    # Cache layer weights (no MoE/SE)
    print("Caching layer weights to GPUs (excluding MoE expert weights)...")
    devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
    layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
    # The host-side weight dict is no longer needed; release it before inference.
    del all_w
    import gc
    gc.collect()
    for g in range(NUM_GPUS):
        torch.cuda.set_device(g)
        torch.cuda.empty_cache()
    torch.cuda.set_device(0)
    print(f" {time.time()-t0:.1f}s")

    # Load compressor/indexer weights
    for li in range(n_layers):
        pfx = f"model.layers.{li}.self_attn.compressor"
        if li in compressors:
            compressors[li].load(layer_w[li], pfx, dev=f"cuda:{li % NUM_GPUS}")
        if li in indexers:
            indexers[li].load(layer_w[li], f"{pfx}.indexer", dev=f"cuda:{li % NUM_GPUS}")
    print(" Compressors/indexers loaded")

    # ---- Phase 3: Inference ----
    print(f"\nPhase 3: Inference")
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
    bos = tokenizer.bos_token_id or 0
    if _args.prefill_tokens:
        generated = [int(x) for x in _args.prefill_tokens.split(',')]
    else:
        input_ids = [bos, USER_TOKEN]
        input_ids += tokenizer.encode('\n\n' + PROMPT, add_special_tokens=False)
        input_ids.append(ASSISTANT_TOKEN)
        generated = input_ids
    all_tokens = generated.copy()
    print(f"Input: {len(generated)} tokens")

    # Prefill — one token at a time (decode-style; TODO: batched prefill)
    # NOTE(review): prefill feeds every prompt token, and decode step 0 feeds
    # all_tokens[-1] again at the same position — confirm the KV cache and
    # compressor tolerate the last prompt token being processed twice.
    print(f"Prefilling {len(generated)} tokens...")
    # Pre-allocate prefill buffers — no per-step torch.tensor()
    pre_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
    pre_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
    pre_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
    for pi, tid_val in enumerate(generated):
        t1 = time.time()
        pre_tid_buf[0] = tid_val
        pre_tid32_buf[0] = tid_val
        pre_pos_buf[0] = pi
        X = mHCLayer.init_state(embed(pre_tid_buf))
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            if X.device != torch.device(f"cuda:{gpu}"):
                X = X.to(f"cuda:{gpu}")
            torch.cuda.set_device(gpu)
            try:
                X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                                  attn_mhcs.get(li), ffn_mhcs.get(li),
                                  attn_norms.get(li), ffn_norms.get(li),
                                  kv_caches[li], pre_pos_buf, pre_tid32_buf,
                                  compressors.get(li), indexers.get(li),
                                  moe_runners.get(li), se_runners.get(li), routers.get(li),
                                  prod_lin=prod_lins.get(li))
            except Exception as e:
                # Report the location before anything else touches CUDA: a
                # synchronize here could raise again and hide where we crashed.
                print(f" CRASH at token {pi} layer {li} gpu {gpu}: {e}", flush=True)
                raise
            if VERBOSE >= 2 and pi == 0 and li < 3:
                torch.cuda.synchronize(gpu)
                print(f" Token {pi} L{li}: OK |X|={X.abs().max().item():.1f}", flush=True)
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        if pi % 10 == 0:
            print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
    print(f" Prefill done ({time.time()-t0:.1f}s)")
    if _args.prefill_only:
        print("Prefill-only mode, stopping.")
        return

    # ---- Build sampler ----
    from dsv4.model.sampler import CUDASampler
    sampler = CUDASampler(device='cuda:0', max_penalty_tokens=256)
    sample_temp = _args.temperature
    sample_topk = _args.top_k
    sample_topp = _args.top_p
    sample_rep_pen = _args.repetition_penalty
    is_greedy = (sample_temp == 0.0)
    print(f" Sampler: temp={sample_temp} top_k={sample_topk} top_p={sample_topp} "
          f"rep_pen={sample_rep_pen} greedy={is_greedy}")
    print(f" DSV4 reasoning model: thinking_start={THINK_START} thinking_end={THINK_END}")
    print(f" Thinking tokens are NOT garbage — model uses )、... format")

    # Pre-allocate decode buffers — zero per-step allocation
    dec_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
    dec_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
    dec_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')

    def _freeze_runtime_gsa(lin):
        """Pin ``lin``'s activation scale to the value in its ``_gsa_buf``.

        Returns True when the scale was pinned, False when ``lin`` has no
        ``_gsa_buf`` or is not using runtime gsa.
        """
        if not (hasattr(lin, '_gsa_buf') and getattr(lin, '_use_runtime_gsa', False)):
            return False
        lin._activation_global_scale = lin._gsa_buf.item()  # one-time host sync
        lin._use_runtime_gsa = False
        return True

    # Decode
    print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
    in_thinking = False
    profile = _args.profile
    warmup_gsa = _args.warmup_gsa
    prof_embed_layers = 0.0
    prof_lm_head = 0.0
    prof_sample = 0.0
    prof_steps = 0  # steps that actually contributed to the totals above
    # (tag, layer, timestamp) records; passed to forward_layer on step 1 only
    cuda_layer_events = []
    for step in range(MAX_NEW_TOKENS):
        t1 = time.time()
        dec_tid_buf[0] = all_tokens[-1]
        dec_tid32_buf[0] = all_tokens[-1]
        dec_pos_buf[0] = len(all_tokens) - 1
        t_e = time.perf_counter()
        X = mHCLayer.init_state(embed(dec_tid_buf))
        for li in range(n_layers):
            gpu = li % NUM_GPUS
            if X.device != torch.device(f"cuda:{gpu}"):
                X = X.to(f"cuda:{gpu}")
            torch.cuda.set_device(gpu)
            X = forward_layer(X, layer_w[li], li, cfg, *rope_caches[gpu],
                              attn_mhcs.get(li), ffn_mhcs.get(li),
                              attn_norms.get(li), ffn_norms.get(li),
                              kv_caches[li], dec_pos_buf, dec_tid32_buf,
                              compressors.get(li), indexers.get(li),
                              moe_runners.get(li), se_runners.get(li), routers.get(li),
                              prod_lin=prod_lins.get(li),
                              _profile_detail=(profile and step == 1),
                              _profile_times=cuda_layer_events if (profile and step == 1) else None)
        X = X.to('cuda:0')
        torch.cuda.set_device(0)
        # Without this sync the layer time would only cover kernel launches and
        # the GPU work would be billed to the lm_head bucket.
        if profile:
            torch.cuda.synchronize()
        t_layers = time.perf_counter()

        # After the first decode step: freeze gsa values measured at runtime.
        # This removes the amax_gsa kernel launches on later steps. It applies
        # only to attention linears, the router gate and lm_head; MoE/SE keep
        # runtime gsa because theirs varies per token.
        if warmup_gsa and step == 0:
            torch.cuda.synchronize()
            n_fixed = 0
            for li in range(n_layers):
                pl = prod_lins.get(li)
                if pl is None:
                    continue
                for lin in pl.values():
                    if _freeze_runtime_gsa(lin):
                        n_fixed += 1
                # Router gate
                router = routers.get(li)
                if router and hasattr(router, '_gate_lin') and router._gate_lin is not None:
                    if _freeze_runtime_gsa(router._gate_lin):
                        n_fixed += 1

        x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
        if final_norm_w is not None:
            x_out = rmsnorm(x_out, final_norm_w)
        logits = lm_head_lin(x_out)

        if warmup_gsa and step == 0:
            # lm_head is frozen only now, after its first forward pass: it does
            # not run during prefill, so before this point its buffer could
            # only hold the initial placeholder, not a measured scale.
            if _freeze_runtime_gsa(lm_head_lin):
                n_fixed += 1
            print(f" Warmup gsa: fixed {n_fixed} projection gsa values from step 0 (MoE/SE keep runtime gsa)", flush=True)

        if profile:
            torch.cuda.synchronize()
        t_lm = time.perf_counter()

        # Check thinking start token logit on first step
        if step == 0:
            ls = logits.float()
            for tid, name in [(THINK_START, 'think_start'), (THINK_END, 'think_end'), (USER_TOKEN, 'user'), (ASSISTANT_TOKEN, 'assistant')]:
                print(f" {name}({tid}) logit={ls[0, tid].item():.2f}", flush=True)
            # Paris token check — only check known token IDs, no 129K iteration
            for t in [11111, 51119, 60107]:
                if t < ls.shape[-1]:
                    print(f" Paris-candidate({t}) logit={ls[0, t].item():.2f}", flush=True)

        # Sync for profiling and error check
        if profile:
            torch.cuda.synchronize()
        t_sample_start = time.perf_counter()

        # Only sync + validate on first 3 steps and every 20th step (reduces pipeline stalls)
        if step < 3 or (step + 1) % 20 == 0:
            torch.cuda.synchronize()  # catch CUDA errors at source
            ls = logits.float()
            has_nan = torch.isnan(ls).any().item()
            has_inf = torch.isinf(ls).any().item()
            print(f" logits: shape={list(logits.shape)} dtype={logits.dtype} "
                  f"min={ls.min().item():.1f} max={ls.max().item():.1f} "
                  f"nan={has_nan} inf={has_inf}", flush=True)
            if has_nan or has_inf:
                print(f" NaN/Inf in logits at step {step}, aborting", flush=True)
                break

        # Sampling — fused CUDA kernel (or greedy argmax for temp=0)
        if is_greedy:
            next_id = torch.argmax(logits, -1).item()
        else:
            sampled = sampler(
                logits,
                temperature=sample_temp,
                top_k=sample_topk,
                top_p=sample_topp,
                repetition_penalty=sample_rep_pen,
                recent_tokens=all_tokens[-256:],
                seed=SEED,
            )
            # Check for async CUDA errors from sampler
            if step < 3:
                torch.cuda.synchronize()
            next_id = sampled[0].item()
        all_tokens.append(next_id)
        dt = time.time() - t1
        if profile:
            torch.cuda.synchronize()
        t_s = time.perf_counter()

        # Track thinking state
        if next_id == THINK_START:
            in_thinking = True
        elif next_id == THINK_END:
            in_thinking = False

        if profile:
            prof_embed_layers += (t_layers - t_e)
            prof_lm_head += (t_lm - t_layers)
            prof_sample += (t_s - t_sample_start)
            prof_steps += 1

        # Diagnostics — reduce CPU syncs, only top-5 every 5 steps
        if step % 5 == 0 or step < 5:
            tv, ti = torch.topk(logits[0].float(), 5)
            top5 = ' '.join(f'{tokenizer.decode([t.item()])}({v.item():.1f})' for t, v in zip(ti[:5], tv[:5]))
            think_tag = " [THINKING]" if in_thinking else ""
            print(f" Step {step}: {next_id} '{tokenizer.decode([next_id])}' ({dt:.2f}s) "
                  f"logits=[{logits.float().min().item():.1f},{logits.float().max().item():.1f}] "
                  f"|X|={X.abs().max().item():.1f} top5: {top5}{think_tag}", flush=True)

        # NaN safety — periodic check only (kept after sampling, in case the
        # sampler touched the logits)
        if step == 0 or (step+1) % 20 == 0:
            if torch.isnan(logits.float()).any().item():
                print(f" NaN at step {step}", flush=True)
                break
        if next_id == tokenizer.eos_token_id:
            print(f" EOS at step {step}", flush=True)
            break

    # Average over the steps that were actually timed; decoding may stop early
    # on EOS or NaN, so MAX_NEW_TOKENS would understate the per-token cost.
    if profile and prof_steps > 0:
        n = prof_steps
        print(f"\n PROFILE (sync'd wall clock, {n} steps):")
        print(f" Embed + {n_layers} layers: {prof_embed_layers:.3f}s total, {prof_embed_layers/n*1000:.1f}ms/token")
        print(f" hc_head + norm + lm_head: {prof_lm_head:.3f}s total, {prof_lm_head/n*1000:.1f}ms/token")
        print(f" Sampling: {prof_sample:.3f}s total, {prof_sample/n*1000:.1f}ms/token")

    # Fine-grained attention profile (from step 1); the list is only filled
    # when profiling is on, so it is empty otherwise.
    if len(cuda_layer_events) >= 2:
        print(f"\n FINE-GRAINED ATTENTION PROFILE (step 1, CUDA-sync'd):")
        prev_t = None
        for tag, li, t in cuda_layer_events:
            if prev_t is not None:
                dt_ms = (t - prev_t) * 1000
                if li <= 2 or li >= 58:  # Only print for first/last layers
                    print(f" L{li} {tag}: {dt_ms:.2f}ms")
            prev_t = t

    out = tokenizer.decode(all_tokens, skip_special_tokens=True)
    print(f"\n{'='*70}")
    print(f"Input: '{PROMPT}'")
    print(f"Output: '{out}'")
    print(f"Total: {time.time()-t0:.1f}s")
    print(f"{'='*70}")
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()