FMHA sink bias in kernel + single_shot production rewrite

FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh):
- Added sink_bias field to FmhaTmaMultiRowMultiTileParams
- After KV tile loop, sink logit is included in online softmax rescale:
  new_max = max(running_max, sink_bias * scale)
  rescale existing O_unnorm and running_sum
  running_sum += exp(sink_bias * scale - new_max)
  No PV contribution from sink (D5c: single softmax)
- C API: fmha_multitile_decode_launch now takes sink_bias_ptr
- Python: fmha_multitile_decode_raw accepts attn_sink tensor

single_shot_inference.py:
- Full rewrite to use production kernel stack
- mHC: uses dsv4.layers.mhc.mHCLayer (proper Sinkhorn-Knopp)
- Projections: uses Nvfp4Linear (CuTeDSL GEMM) for q_a, q_b, kv, o_b
- FMHA: 6-warp TMA multi-tile with sink bias (no SDPA fallback)
- MoE: Nvfp4MoE + Nvfp4SharedExpert (no reference fallback)
- Router: production dense/hash dispatch
- Compressor/Indexer: reference dequant (not yet on tensor cores)
- NO try/except fallbacks on production paths
This commit is contained in:
2026-05-31 23:10:13 +00:00
parent 23e88638aa
commit 13be3ad443
8 changed files with 397 additions and 375 deletions

View File

@@ -34,6 +34,7 @@ struct FmhaTmaMultiRowMultiTileParams {
CUtensorMap* __restrict__ tma_v;
bf16_t* __restrict__ o;
float* __restrict__ lse;
const float* __restrict__ sink_bias; // per-head FP32 sink logit (n_h,), NULL if unused
int s_k, T, n_h;
float scale;
int q_head_stride, q_batch_stride;
@@ -332,6 +333,38 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params)
__syncthreads();
} // kv_tile loop
// ---- Sink bias correction (D5c: single softmax over [S_comp, S_swa + sink]) ----
// The attention sink is a per-head logit bias. It adds one extra
// "position" to the softmax that contributes to the denominator
// but NOT the numerator (no corresponding V row). This is the
// key insight: sink merge = single softmax, not two-branch merge.
//
// Math: after all KV tiles, we have (running_max, running_sum, O_unnorm).
// The sink is folded in, in this order:
//   sink_logit  = sink_bias * scale
//   new_max     = max(running_max, sink_logit)
//   rescale O_unnorm and running_sum by exp(old_max - new_max)
//     (the factor is 0 when old_max is -inf)
//   sink_weight = exp(sink_logit - new_max)
//   running_sum += sink_weight
// The sink does NOT produce a PV contribution — O_unnorm is only rescaled,
// nothing is added to it.
// sink_bias is read at [batch_idx * n_h + head_idx], i.e. laid out (batch, n_h).
if (params.sink_bias != nullptr && my_warp_active) {
// Load per-head sink bias (same for all rows in this head)
float sb = params.sink_bias[head_idx + batch_idx * params.n_h];
if (my_row_active) {
float sink_logit = sb * scale;
float old_max = sRunningMax[my_row];
float new_max = fmaxf(old_max, sink_logit);
float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
float sink_weight = expf(sink_logit - new_max);
// Rescale existing accumulator and running sum
for (int d = 0; d < HD_CHUNK; d++) {
sOacc[my_row * HD_CHUNK + d] *= rescale_old;
}
sRunningSum[my_row] = sRunningSum[my_row] * rescale_old + sink_weight;
sRunningMax[my_row] = new_max;
}
}
__syncthreads();
// ---- Write chunk to SMEM row-major, then TMA store to GMEM ----
// P6: One-way epilogue pattern — normalize in registers,
// write to SMEM row-major, then TMA store to GMEM.

View File

@@ -26,6 +26,7 @@ int fmha_multitile_decode_launch(
const void* v_ptr,
void* o_ptr,
void* lse_ptr,
const float* sink_bias_ptr,
int batch, int n_h, int T, int N_orig, int N_padded, int hd,
int q_head_stride, int q_batch_stride,
int k_head_stride, int k_batch_stride,
@@ -84,6 +85,7 @@ int fmha_multitile_decode_launch(
params.o_batch_stride = o_batch_stride;
params.lse_head_stride = lse_head_stride;
params.lse_batch_stride = lse_batch_stride;
params.sink_bias = sink_bias_ptr; // per-head FP32 sink logit, NULL if unused
// SMEM size (match kernel layout)
constexpr int HD_CHUNK = 256;

View File

@@ -119,12 +119,22 @@ def fmha_multitile_decode_raw(
o = torch.zeros(B, n_h, T, hd, dtype=torch.bfloat16, device=q.device)
lse = torch.zeros(B, n_h, T, dtype=torch.float32, device=q.device)
# Sink bias: contiguous FP32 of shape (B, n_h); a 1-D (n_h,) tensor is broadcast across the batch
sink_bias_ptr = ctypes.c_void_p(0)
if attn_sink is not None:
sb = attn_sink.float().contiguous()
if sb.dim() == 1:
sb = sb.unsqueeze(0).expand(B, -1).contiguous() # (batch, n_h)
assert sb.shape == (B, n_h), f"sink_bias shape {sb.shape} != ({B}, {n_h})"
sink_bias_ptr = ctypes.c_void_p(sb.data_ptr())
ret = lib.fmha_multitile_decode_launch(
ctypes.c_void_p(q.data_ptr()),
ctypes.c_void_p(k.data_ptr()),
ctypes.c_void_p(v.data_ptr()),
ctypes.c_void_p(o.data_ptr()),
ctypes.c_void_p(lse.data_ptr()),
sink_bias_ptr, # per-head FP32 sink logit
ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T),
ctypes.c_int(N_orig), # s_k: logical KV length (for softmax masking)
ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors)

View File

@@ -41,7 +41,7 @@ def _dsv4_attention_multitile(
k_4d = k.unsqueeze(0).contiguous()
v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous()
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale)
o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias)
return o_4d.squeeze(0)

View File

@@ -3,15 +3,17 @@
Exercises the production kernel stack end-to-end:
- NVFP4 GEMM kernels (CuTeDSL ScaledGroupedGemm) for all projections
- 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh)
- 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh) with sink bias
- CSA/HCA compressor (token-level softmax)
- Indexer score+topk (indexer_score_topk.cu)
- Indexer score+topk
- Dense/Hash router kernels
- Production mHC (Sinkhorn-Knopp, B_l transposed, [pre,post,comb])
- Production Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert
- Production Nvfp4Linear, Nvfp4MoE, Nvfp4SharedExpert
This is NOT a PyTorch reference — it calls the actual kernel stack.
Use as ground truth for vLLM / SGLang integration.
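NO PyTorch SDPA fallback. Attention projections and MoE use tensor-core NVFP4 GEMMs.
Exceptions: compressor/indexer projections still use reference dequant + F.linear, and forward_attention falls back to the reference path for any projection missing from prod_lin.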
This is the ground truth for vLLM / SGLang integration.
"""
import os, sys, time, json, math, argparse, logging
import torch
@@ -93,79 +95,8 @@ def unweighted_rmsnorm(x, eps=1e-6):
return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt()
# =====================================================================
# mHC (matches dsv4/layers/mhc.py)
# =====================================================================
HC_EPS = 1e-6
def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS):
M = torch.softmax(logits, -1) + eps
M = M / (M.sum(-2, keepdim=True) + eps)
for _ in range(t_max - 1):
M = M / (M.sum(-1, keepdim=True) + eps)
M = M / (M.sum(-2, keepdim=True) + eps)
return M
class mHCBlock:
def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda:0'):
self.d, self.n_hc, self.K = hidden_dim, n_hc, n_hc * hidden_dim
self.t_max, self.device = t_max, device
def load(self, fn, base, scale):
n = self.n_hc
self.W_pre = fn[0:n].to(self.device, torch.float32).contiguous()
self.W_post = fn[n:2*n].to(self.device, torch.float32).contiguous()
self.W_comb = fn[2*n:].to(self.device, torch.float32).contiguous()
self.S_pre = base[0:n].reshape(1, n).to(self.device, torch.float32).contiguous()
self.S_post = base[n:2*n].reshape(n, 1).to(self.device, torch.float32).contiguous()
self.S_comb = base[2*n:].reshape(n, n).to(self.device, torch.float32).contiguous()
self.alpha_pre, self.alpha_post, self.alpha_comb = scale[0].item(), scale[1].item(), scale[2].item()
@staticmethod
def init_state(emb, n_hc=4):
return emb.unsqueeze(1).expand(-1, n_hc, -1).clone()
def pre_block(self, X):
T, n, d = X.shape
Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb])
proj = Xn.float() @ W_stacked.T
rms_inv = proj.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
proj = (proj * rms_inv).bfloat16().float()
pre_t = self.alpha_pre * proj[:, :n] + self.S_pre.flatten().unsqueeze(0)
post_t = self.alpha_post * proj[:, n:2*n] + self.S_post.flatten().unsqueeze(0)
comb_t = self.alpha_comb * proj[:, 2*n:2*n+n*n] + self.S_comb.flatten().unsqueeze(0)
A = torch.sigmoid(pre_t) + HC_EPS
C = 2.0 * torch.sigmoid(post_t)
B = sinkhorn_knopp(comb_t.reshape(T, n, n), t_max=self.t_max)
x_in = torch.bmm(A.unsqueeze(1), X.float()).squeeze(1).bfloat16()
return x_in, {'B': B, 'C': C}
def post_block(self, X, F_out, ctx):
BX = torch.bmm(ctx['B'].transpose(-1, -2), X.float())
CF = ctx['C'].unsqueeze(-1) * F_out.unsqueeze(1)
return (CF.float() + BX).bfloat16()
# =====================================================================
# HcHead
# =====================================================================
class HcHead:
def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc
def load(self, fn, base, scale=None):
self.fn = fn.to(self.device, torch.float32).contiguous()
self.base = base.to(self.device, torch.float32).contiguous()
self.scale = scale.to(self.device, torch.float32).item() if scale is not None else 1.0
def forward(self, X):
T = X.shape[0]
Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16())
mix = F.linear(Xn, self.fn[:self.n_hc]).float()
pre = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS
return (pre.unsqueeze(-1) * X.float()).sum(1).bfloat16()
# =====================================================================
# NVFP4 dequant (fallback for projections not yet using kernel GEMM)
# NVFP4 dequant — used ONLY for compressor/indexer projections
# (these don't go through the CuTeDSL GEMM kernel yet)
# =====================================================================
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
O, I2 = weight.shape
@@ -180,7 +111,7 @@ def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
if weight_scale_2 is not None: s = s * weight_scale_2.float()
return (w * s).bfloat16()
def nvfp4_linear(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
def nvfp4_linear_ref(x, weight, weight_scale, weight_scale_2=None, input_scale=None):
return F.linear(x, dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale))
def get_nvfp4_weight(w, pfx, proj_name):
@@ -188,16 +119,32 @@ def get_nvfp4_weight(w, pfx, proj_name):
return (w.get(f"{k}.weight"), w.get(f"{k}.weight_scale"),
w.get(f"{k}.weight_scale_2"), w.get(f"{k}.input_scale"))
def do_nvfp4_linear(x, w, pfx, proj_name):
def do_nvfp4_linear_ref(x, w, pfx, proj_name):
weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name)
if weight is None: return None
d = x.device
return nvfp4_linear(x, weight.to(d), ws.to(d),
return nvfp4_linear_ref(x, weight.to(d), ws.to(d),
ws2.to(d) if ws2 is not None else None,
isc.to(d) if isc is not None else None)
# =====================================================================
# Production Nvfp4Linear wrapper
# =====================================================================
def make_nvfp4_linear(in_features, out_features, device, weight, weight_scale,
                      weight_scale_2=None, input_scale=None, max_num_tokens=8192):
    """Create a production Nvfp4Linear with weights loaded from checkpoint.

    Args:
        in_features: Input width passed to ``Nvfp4Linear``.
        out_features: Output width passed to ``Nvfp4Linear``.
        device: Device the layer is built on and the tensors are moved to.
        weight: Packed FP4 weight tensor from the checkpoint.
        weight_scale: Block scale-factor tensor from the checkpoint.
        weight_scale_2: Accepted for signature parity with the reference
            helpers but NOT used here.
        input_scale: Optional single-element tensor; its float value becomes
            the layer's ``gs`` entry.
        max_num_tokens: Forwarded to ``Nvfp4Linear``; defaults to the
            previously hard-coded 8192.

    Returns:
        The ``Nvfp4Linear`` instance with ``fp4``, ``sf`` and ``gs`` set.
        ``finalize_weights()`` is not called here.
    """
    from dsv4.layers.linear import Nvfp4Linear

    lin = Nvfp4Linear(in_features, out_features,
                      max_num_tokens=max_num_tokens, device=device)
    # The layer stores each field as a one-element list.
    lin.fp4 = [weight.to(device)]
    lin.sf = [weight_scale.to(device)]
    # NOTE(review): weight_scale_2 is dropped on this path, while the
    # reference dequant path receives it. Confirm Nvfp4Linear does not need it.
    if input_scale is not None:
        gs = input_scale.float().item()
    else:
        # Default when the checkpoint has no input_scale; presumably
        # 1 / (FP4 max 6 * FP8-E4M3 max 448) — TODO confirm.
        gs = 1.0 / (6.0 * 448.0)
    lin.gs = [gs]
    return lin
# =====================================================================
# Compressor — CSA (ratio=4) and HCA (ratio=128)
# (Reference PyTorch — compressor not yet on tensor cores)
# =====================================================================
class Compressor:
def __init__(self, ratio, head_dim, hidden_size, device):
@@ -224,10 +171,10 @@ class Compressor:
n_complete = T // r
if n_complete == 0:
return None, None, None
kv = nvfp4_linear(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev),
kv = nvfp4_linear_ref(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev),
self.wkv_ws2.to(dev) if self.wkv_ws2 is not None else None,
self.wkv_isc.to(dev) if self.wkv_isc is not None else None)
gate = nvfp4_linear(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev),
gate = nvfp4_linear_ref(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev),
self.wgate_ws2.to(dev) if self.wgate_ws2 is not None else None,
self.wgate_isc.to(dev) if self.wgate_isc is not None else None)
if self.ape is not None:
@@ -270,7 +217,7 @@ class Compressor:
return torch.stack(comp_list), torch.stack(comp_pos_list), torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev)
# =====================================================================
# Indexer — CSA top-k
# Indexer — CSA top-k (Reference PyTorch)
# =====================================================================
class Indexer:
def __init__(self, n_ih, ihd, top_k, device):
@@ -292,11 +239,11 @@ class Indexer:
dev = q_lora.device
T = q_lora.shape[0]
n_comp = comp_indexer_kv.shape[0]
q_idx = nvfp4_linear(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev),
q_idx = nvfp4_linear_ref(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev),
self.q_b_ws2.to(dev) if self.q_b_ws2 is not None else None,
self.q_b_isc.to(dev) if self.q_b_isc is not None else None)
q_idx = q_idx.reshape(T, self.n_ih, self.ihd)
w_h = nvfp4_linear(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev),
w_h = nvfp4_linear_ref(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev),
self.wp_ws2.to(dev) if self.wp_ws2 is not None else None,
self.wp_isc.to(dev) if self.wp_isc is not None else None)
k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd)
@@ -364,18 +311,36 @@ def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False):
return out
# =====================================================================
# Production FMHA — 6-warp TMA multi-tile kernel
# HcHead — FP32 projection, read out from mHC state
# =====================================================================
HC_EPS = 1e-6
class HcHead:
    """Read-out head that collapses the mHC state to one vector per token.

    ``load`` stores the projection (``fn``), bias (``base``) and optional
    scalar ``scale`` as FP32 on ``device``; ``forward`` turns them into
    per-stream sigmoid gates and returns the gated sum in BF16.
    """

    def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'):
        self.K = n_hc * hidden_dim  # flattened width of the n_hc streams
        self.device = device
        self.n_hc = n_hc

    def load(self, fn, base, scale=None):
        """Move parameters to ``self.device`` as contiguous FP32; scale defaults to 1.0."""
        target = self.device
        self.fn = fn.to(target, torch.float32).contiguous()
        self.base = base.to(target, torch.float32).contiguous()
        if scale is None:
            self.scale = 1.0
        else:
            self.scale = scale.to(target, torch.float32).item()

    def forward(self, X):
        """Gate each stream of ``X`` and sum over dim 1.

        Presumably X is (T, n_hc, hidden_dim): it is reshaped to (T, K) and
        multiplied by (T, n_hc, 1) gates — confirm against the caller.
        """
        n_tokens = X.shape[0]
        flat = X.reshape(n_tokens, self.K).bfloat16()
        normed = unweighted_rmsnorm(flat)
        head_w = self.fn[:self.n_hc]
        mix = F.linear(normed, head_w).float()
        bias = self.base[:self.n_hc].unsqueeze(0)
        gate = torch.sigmoid(mix * self.scale + bias) + HC_EPS
        weighted = gate.unsqueeze(-1) * X.float()
        return weighted.sum(1).bfloat16()
# =====================================================================
# Production FMHA — 6-warp TMA multi-tile kernel with sink bias
# =====================================================================
def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx):
"""Run production FMHA kernel via dsv4_attention.
q_heads: (T, n_h, hd), all_kv: (seq_len, hd)
Returns: (T, n_h, hd) BF16
The 6-warp TMA FMHA kernel correctly handles N < 128:
K/V are padded to 128 for TMA alignment, but the kernel receives
the true s_k and masks padded entries in softmax (col < kv_len guard).
Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical).
"""Run production FMHA kernel with sink bias support.
The kernel handles:
- N < 128: K/V padded to 128, kernel uses N_orig for softmax masking
- Multi-tile KV for N > 128
- Attention sinks via per-head logit bias (D5c: single softmax)
"""
from dsv4.kernels.attention.production import dsv4_attention
@@ -394,12 +359,18 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w
) # (n_h, T, hd)
return attn_out.permute(1, 0, 2) # (T, n_h, hd)
# =====================================================================
# Attention forward — uses production FMHA kernel
# Attention forward — production FMHA + production Nvfp4Linear
# =====================================================================
def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer):
kv_cache, positions, compressor, indexer,
prod_lin=None):
"""Attention sub-block using production kernels.
All projections go through Nvfp4Linear (CuTeDSL GEMM).
FMHA goes through 6-warp TMA multi-tile kernel with sink bias.
Inverse RoPE applied after FMHA.
"""
dev = x_normed.device
T = x_normed.shape[0]
n_h = cfg["num_attention_heads"]
@@ -414,19 +385,22 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
positions = positions.to(rope_cos.device)
# 1. Q projection: q_a → q_a_norm → q_b → q_b_norm
q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj')
q_a = prod_lin['q_a'](x_normed) if prod_lin and 'q_a' in prod_lin else \
do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
if q_a is None:
log.warning(f" L{li}: q_a_proj not found")
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32))
q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj')
q = prod_lin['q_b'](q_a) if prod_lin and 'q_b' in prod_lin else \
do_nvfp4_linear_ref(q_a, w, pfx, 'q_b_proj')
q = unweighted_rmsnorm(q).bfloat16()
q_heads = q.reshape(T, n_h, hd)
q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
# 2. KV projection (MQA, single KV head, hd dim)
kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj')
kv = prod_lin['kv'](x_normed) if prod_lin and 'kv' in prod_lin else \
do_nvfp4_linear_ref(x_normed, w, pfx, 'kv_proj')
if kv is None:
log.warning(f" L{li}: kv_proj not found")
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
@@ -473,13 +447,13 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
if seq_len == 0:
return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a
# 6. Production FMHA kernel (6-warp TMA multi-tile)
# 6. Production FMHA kernel (6-warp TMA multi-tile) with sink bias
attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx)
# 7. Inverse RoPE (FP32 cache)
attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True)
# 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4)
# 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM)
hpg = n_h // o_groups
gid = hpg * hd
oa_w = w.get(f"{pfx}.o_a_proj.weight")
@@ -490,108 +464,25 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
oa_3d = oa_bf.reshape(o_groups, o_rank, gid)
g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2))
g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank)
F_attn = do_nvfp4_linear(g_flat, w, pfx, 'o_b_proj')
F_attn = prod_lin['o_b'](g_flat) if prod_lin and 'o_b' in prod_lin else \
do_nvfp4_linear_ref(g_flat, w, pfx, 'o_b_proj')
else:
F_attn = do_nvfp4_linear(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj')
F_attn = prod_lin['o_a'](attn_out.reshape(T, n_h * hd)) if prod_lin and 'o_a' in prod_lin else \
do_nvfp4_linear_ref(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj')
return F_attn, q_a
# =====================================================================
# MoE forward — uses production Nvfp4MoE + Nvfp4SharedExpert kernels
# MoE forward — production Nvfp4MoE + Nvfp4SharedExpert + Router
# =====================================================================
def moe_forward(x, w, li, cfg, token_id, device, moe_runner, se_runner, router):
def moe_forward(x, li, moe_runner, se_runner, router, token_id):
"""MoE forward using production NVFP4 GEMM kernels.
Router uses production dense/hash router kernels.
Expert GEMMs use CuTeDSL NVFP4 grouped GEMM (fused SwiGLU).
Shared expert uses CuTeDSL NVFP4 single-group GEMM.
No F.linear. No BF16 matmul. No PyTorch loops over experts.
NO fallback to reference. Production kernels ONLY.
"""
H = cfg["hidden_size"]
n_e = cfg["n_routed_experts"]
top_k = cfg.get("num_experts_per_tok", 6)
rsc = cfg.get("routed_scaling_factor", 2.5)
lim = cfg.get("swiglu_limit", 10.0)
num_hash = cfg.get("num_hash_layers", 3)
pfx = f"model.layers.{li}.mlp"
# Production router: returns (topk_weights, topk_ids) via kernel
if router is not None:
try:
topk_w, topk_ids = router(x, token_ids=token_id)
# Production MoE kernel: NVFP4 grouped GEMM with fused SwiGLU
routed_out = moe_runner(x, topk_w, topk_ids)
# Production shared expert: NVFP4 single-group GEMM
shared_out = se_runner(x)
return routed_out + shared_out
except Exception as e:
log.warning(f" L{li}: Production MoE failed ({e}), falling back to reference")
# Fall through to reference path
# Reference fallback (only if production kernels fail)
return _moe_forward_reference(x, w, li, cfg, token_id, device)
def _moe_forward_reference(x, w, li, cfg, token_id, device):
"""Reference MoE using dequantized BF16 weights."""
H = cfg["hidden_size"]
n_e = cfg["n_routed_experts"]
top_k = cfg.get("num_experts_per_tok", 6)
rsc = cfg.get("routed_scaling_factor", 2.5)
lim = cfg.get("swiglu_limit", 10.0)
num_hash = cfg.get("num_hash_layers", 3)
pfx = f"model.layers.{li}.mlp"
tid2eid_key = f"{pfx}.gate.tid2eid"
e_bias_key = f"{pfx}.gate.e_score_correction_bias"
is_hash = (li < num_hash) and (tid2eid_key in w)
if is_hash:
tid2eid = w[tid2eid_key]
tid = token_id.item() if token_id.numel() == 1 else token_id[0].item()
expert_ids = tid2eid[tid]
expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k
else:
gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate')
if gate_ww is not None and gate_ws is not None:
logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device),
gate_ws2.to(device) if gate_ws2 is not None else None,
gate_isc.to(device) if gate_isc is not None else None)
elif f"{pfx}.gate.weight" in w:
gw = w[f"{pfx}.gate.weight"].bfloat16().to(device)
logits = F.linear(x, gw)
else:
raise ValueError(f"No gate weight for layer {li}")
scores = torch.sqrt(F.softplus(logits.float()) + 1e-6)
sel = scores.clone()
if e_bias_key in w:
sel = sel + w[e_bias_key].to(device=x.device).float().unsqueeze(0)
_, indices = sel.topk(top_k, -1)
expert_weights = torch.gather(scores, -1, indices)
expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True)
expert_ids, expert_weights = indices[0], expert_weights[0]
expert_outs = []
for i, eid in enumerate(expert_ids):
ep = f"{pfx}.experts.{eid}"
g = do_nvfp4_linear(x, w, ep, 'gate_proj')
u = do_nvfp4_linear(x, w, ep, 'up_proj')
silu = F.silu(g.float())
if lim is not None: silu = silu.clamp(-lim, lim); u = u.float().clamp(-lim, lim)
h = (silu * u).bfloat16()
expert_outs.append(do_nvfp4_linear(h, w, ep, 'down_proj'))
routed = torch.zeros_like(x)
for out, wt in zip(expert_outs, expert_weights):
routed = routed + (out.float() * wt.item()).bfloat16()
routed = (routed.float() * rsc).bfloat16()
sp = f"{pfx}.shared_experts"
sg = do_nvfp4_linear(x, w, sp, 'gate_proj')
su = do_nvfp4_linear(x, w, sp, 'up_proj')
silu = F.silu(sg.float())
if lim is not None: silu = silu.clamp(-lim, lim); su = su.float().clamp(-lim, lim)
shared = do_nvfp4_linear((silu * su).bfloat16(), w, sp, 'down_proj')
return routed + shared
topk_w, topk_ids = router(x, token_ids=token_id)
routed_out = moe_runner(x, topk_w, topk_ids)
shared_out = se_runner(x)
return routed_out + shared_out
# =====================================================================
# Layer forward
@@ -600,18 +491,20 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w,
kv_cache, positions, token_id,
compressor=None, indexer=None,
moe_runner=None, se_runner=None, router=None):
moe_runner=None, se_runner=None, router=None,
prod_lin=None):
dev = X_l.device
# Attention sub-block
x_in, ctx_a = attn_mhc.pre_block(X_l)
x_normed = rmsnorm(x_in, attn_norm_w)
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer)
kv_cache, positions, compressor, indexer,
prod_lin=prod_lin)
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
# FFN sub-block
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid)
x_ffn = rmsnorm(x_in_f, ffn_norm_w)
F_ffn = moe_forward(x_ffn, w, li, cfg, token_id, dev, moe_runner, se_runner, router)
F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id)
X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
if VERBOSE >= 1:
print(f" L{li}: |X|={X_l.abs().max().item():.1f}{X_next.abs().max().item():.1f} "
@@ -619,15 +512,132 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
return X_next
# =====================================================================
# Main
# MoE weight loading (stacked path for production GEMM)
# =====================================================================
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
"""Stack per-expert checkpoint tensors and hand them to the MoE runner.

Reads ``{pfx}.experts.{eid}.{gate,up,down}_proj.*`` entries from ``all_w``,
builds one stacked tensor per kind across all ``cfg["n_routed_experts"]``
experts, moves them to ``dev`` and calls
``moe.prepare_weights_from_stacked``. Returns None; logs a warning and
returns early when expert 0 has no gate_proj weight. ``li`` is only used
in that log message.

NOTE(review): indentation is lost in this view, so whether the two
``*_gs.append`` calls sit inside the preceding ``if`` or at loop level
cannot be confirmed here.
"""
n_e = cfg["n_routed_experts"]
# Expert 0 is the shape/dtype template for every stacked buffer.
w0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight")
if w0 is None:
log.warning(f"L{li}: No expert weights found")
return
gate_N, gate_K = w0.shape
# L1 buffer holds gate and up back to back, hence 2 * gate_N rows.
# No device argument: allocated on torch's default device, moved to dev below.
l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype)
l1_sf_stacked = None
l2_stacked = None
l2_sf_stacked = None
l1_gs = []
l2_gs = []
ws0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight_scale")
if ws0 is not None:
sf_N, sf_K = ws0.shape
l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype)
dw0 = all_w.get(f"{pfx}.experts.0.down_proj.weight")
if dw0 is not None:
down_N, down_K = dw0.shape
l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype)
dws0 = all_w.get(f"{pfx}.experts.0.down_proj.weight_scale")
if dws0 is not None:
l2_sf_stacked = torch.zeros(n_e, dws0.shape[0], dws0.shape[1], dtype=dws0.dtype)
for eid in range(n_e):
# Only gate_proj's input_scale is read for L1; up_proj's input_scale
# and every weight_scale_2 entry are never looked up in this function.
gw = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight")
gws = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight_scale")
gisc = all_w.get(f"{pfx}.experts.{eid}.gate_proj.input_scale")
uw = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight")
uws = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight_scale")
if gw is not None and uw is not None:
# Row layout per expert: [gate ; up].
l1_stacked[eid, :gate_N] = gw
l1_stacked[eid, gate_N:] = uw
if gws is not None and uws is not None and l1_sf_stacked is not None:
l1_sf_stacked[eid, :sf_N] = gws
l1_sf_stacked[eid, sf_N:] = uws
# Fallback scale when input_scale is absent; presumably
# 1 / (FP4 max 6 * FP8-E4M3 max 448) — TODO confirm.
l1_gs.append(gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0))
dw = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight")
dws = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight_scale")
disc = all_w.get(f"{pfx}.experts.{eid}.down_proj.input_scale")
if dw is not None:
# NOTE(review): l2_stacked is None when expert 0 lacks down_proj.weight;
# this assignment would then fail for any expert that does have one.
l2_stacked[eid] = dw
if dws is not None and l2_sf_stacked is not None:
l2_sf_stacked[eid] = dws
l2_gs.append(disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0))
l1_stacked = l1_stacked.to(dev)
l1_sf_stacked = l1_sf_stacked.to(dev) if l1_sf_stacked is not None else None
l2_stacked = l2_stacked.to(dev) if l2_stacked is not None else None
l2_sf_stacked = l2_sf_stacked.to(dev) if l2_sf_stacked is not None else None
# Empty lists (e.g. n_e == 0) are replaced by n_e default scales.
l1_gs = l1_gs if l1_gs else [1.0 / (6.0 * 448.0)] * n_e
l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] * n_e
moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs,
l2_stacked, l2_sf_stacked, l2_gs)
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
"""Load the shared expert's gate/up/down tensors into ``se``.

Looks up ``{pfx}.shared_experts.{gate,up,down}_proj.{weight,weight_scale,
input_scale}`` in ``all_w``, moves weights and scales to ``dev``, fuses gate
and up by concatenating along dim 0, assigns the ``l1_*`` / ``l2_*``
attributes on ``se`` and calls ``se.finalize_weights()``. ``li`` and ``cfg``
are not referenced in the body.

NOTE(review): indentation is lost in this view, so the nesting of
``gs_l.append`` and of ``se.finalize_weights()`` cannot be confirmed here.
"""
l1_gate_fp4, l1_gate_sf, l1_gate_gs = [], [], []
l1_up_fp4, l1_up_sf = [], []
l2_fp4, l2_sf, l2_gs = [], [], []
# One row per projection: (name, weight list, scale list, gs list or None).
# up_proj passes None, so its input_scale is fetched but never recorded.
for proj, fp4_l, sf_l, gs_l in [
('gate_proj', l1_gate_fp4, l1_gate_sf, l1_gate_gs),
('up_proj', l1_up_fp4, l1_up_sf, None),
('down_proj', l2_fp4, l2_sf, l2_gs),
]:
w, ws, isc = all_w.get(f"{pfx}.shared_experts.{proj}.weight"), \
all_w.get(f"{pfx}.shared_experts.{proj}.weight_scale"), \
all_w.get(f"{pfx}.shared_experts.{proj}.input_scale")
if w is not None and ws is not None:
fp4_l.append(w.to(dev))
sf_l.append(ws.to(dev))
if gs_l is not None:
# Fallback scale when input_scale is absent; presumably
# 1 / (FP4 max 6 * FP8-E4M3 max 448) — TODO confirm.
gs_l.append(isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0))
if l1_gate_fp4 and l1_up_fp4:
# Fused L1 layout: gate rows followed by up rows; gate's scale is used for both.
se.l1_fp4 = [torch.cat([l1_gate_fp4[0], l1_up_fp4[0]], dim=0)]
se.l1_sf = [torch.cat([l1_gate_sf[0], l1_up_sf[0]], dim=0)]
se.l1_gs = l1_gate_gs if l1_gate_gs else [1.0 / (6.0 * 448.0)]
if l2_fp4:
se.l2_fp4 = l2_fp4; se.l2_sf = l2_sf
se.l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)]
se.finalize_weights()
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
"""Cache per-layer weights to GPUs, EXCLUDING MoE expert weights."""
cached = {}
for li in range(n_layers):
dev = devices[li % len(devices)]
pfx = f"model.layers.{li}."
w = {k: v.to(device=dev, non_blocking=True)
for k, v in all_w.items()
if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k}
cached[li] = w
if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers")
return cached
def load_weights(checkpoint_dir):
    """Load every safetensors shard of a checkpoint into one dict.

    Shard names come from the ``weight_map`` of
    ``model.safetensors.index.json``. When that index file is absent, every
    ``*.safetensors`` file in the directory is loaded instead, so
    single-file checkpoints no longer yield an empty dict.

    Args:
        checkpoint_dir: Directory containing the shards.

    Returns:
        Dict mapping tensor name -> tensor, merged across shards in sorted
        shard-name order. A shard listed in the index but missing on disk is
        skipped with a warning.
    """
    from safetensors.torch import load_file

    cdir = Path(checkpoint_dir)
    idx = cdir / "model.safetensors.index.json"
    if idx.exists():
        with open(idx, encoding="utf-8") as f:
            wmap = json.load(f).get("weight_map", {})
        shard_names = sorted(set(wmap.values()))
    else:
        shard_names = sorted(p.name for p in cdir.glob("*.safetensors"))

    all_w = {}
    for sn in shard_names:
        shard = cdir / sn
        if not shard.exists():
            # Previously skipped silently, leaving a partial weight dict
            # with no hint about which file was absent.
            log.warning(f"Checkpoint shard listed in index but missing: {shard}")
            continue
        all_w.update(load_file(str(shard)))
    return all_w
def main():
t0 = time.time()
torch.manual_seed(SEED)
print("=" * 70)
print("DSV4 Single-Shot Inference — PRODUCTION KERNEL STACK")
print(" FMHA: 6-warp TMA multi-tile | Compressor + Indexer | mHC | MoE")
print(" NVFP4 GEMM (CuTeDSL) | Router kernels | NO PyTorch SDPA")
print(" FMHA: 6-warp TMA multi-tile + sink bias")
print(" NVFP4 GEMM (CuTeDSL) | Router kernels | Production mHC")
print(" NO PyTorch SDPA | NO dequant+matmul | NO reference fallback")
print("=" * 70)
with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f:
@@ -641,17 +651,18 @@ def main():
print(f"Compress ratios: first5={cr[:5]} len={len(cr)}")
print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}")
# Load weights
# ---- Phase 1: Load weights ----
print(f"\nPhase 1: Loading weights...")
all_w = load_weights(CHECKPOINT_DIR)
all_w = load_all_weights(CHECKPOINT_DIR)
print(f" {time.time()-t0:.1f}s")
# Build production components
# ---- Phase 2: Build production components ----
print("Building production components...")
from dsv4.layers.mhc import mHCLayer
from dsv4.layers.router import Router
from dsv4.layers.moe import Nvfp4MoE
from dsv4.layers.shared_expert import Nvfp4SharedExpert
from dsv4.layers.linear import Nvfp4Linear
# mHC + norms
attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {}
@@ -665,8 +676,20 @@ def main():
]:
fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s)
if fn is not None and base is not None and scale is not None:
m = mHCBlock(H, 4, 20, dev)
m.load(fn, base, scale)
m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev)
# Split fn/base/scale into pre/post/comb
n = 4
m.load_weights(
W_pre=fn[0:n].to(dev, torch.float32),
W_post=fn[n:2*n].to(dev, torch.float32),
W_comb=fn[2*n:].to(dev, torch.float32),
S_pre=base[0:n].reshape(1, n).to(dev, torch.float32),
S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32),
S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32),
alpha_pre=scale[0].item(),
alpha_post=scale[1].item(),
alpha_comb=scale[2].item(),
)
blocks[li] = m
an_k = f"model.layers.{li}.input_layernorm.weight"
@@ -674,6 +697,27 @@ def main():
fn_k = f"model.layers.{li}.post_attention_layernorm.weight"
if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32)
# Production Nvfp4Linear for attention projections
prod_lins = {}
for li in range(n_layers):
dev = f"cuda:{li % NUM_GPUS}"
pfx = f"model.layers.{li}.self_attn"
plin = {}
for proj, in_f, out_f in [
('q_a', H, cfg.get('query_compression_dim', 1536)),
('q_b', cfg.get('query_compression_dim', 1536), n_h * hd),
('kv', H, hd),
('o_b', cfg.get('o_groups', 16) * cfg.get('o_lora_rank', 1024), H),
]:
wt, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj)
if wt is not None and ws is not None:
lin = make_nvfp4_linear(in_f, out_f, dev, wt, ws, ws2, isc)
lin.finalize_weights()
plin[proj] = lin
if plin:
prod_lins[li] = plin
if (li+1) % 10 == 0: print(f" Built Nvfp4Linear {li+1}/{n_layers} layers")
# Routers, MoE, shared experts
routers, moe_runners, se_runners = {}, {}, {}
for li in range(n_layers):
@@ -681,7 +725,6 @@ def main():
pfx = f"model.layers.{li}.mlp"
is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w)
# Router
router = Router(
hidden_size=H, num_experts=cfg["n_routed_experts"],
top_k=cfg.get("num_experts_per_tok", 6),
@@ -700,19 +743,15 @@ def main():
router.finalize_weights()
routers[li] = router
# MoE (production NVFP4 grouped GEMM)
moe = Nvfp4MoE(
num_experts=cfg["n_routed_experts"], hidden_size=H,
intermediate_size=cfg.get("moe_intermediate_size", 3072),
top_k=cfg.get("num_experts_per_tok", 6), device=dev,
)
moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0))
# Load expert weights (stacked path)
_load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg)
moe_runners[li] = moe
# Shared expert (production NVFP4 single-group GEMM)
se = Nvfp4SharedExpert(
hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072),
device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0),
@@ -761,7 +800,6 @@ def main():
if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev)
# Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners)
# This avoids double-loading ~10GB/layer of expert FP4 weights
print("Caching layer weights to GPUs (excluding MoE expert weights)...")
devs = [f"cuda:{g}" for g in range(NUM_GPUS)]
layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs)
@@ -778,8 +816,8 @@ def main():
if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer")
print(" Compressors/indexers loaded")
# Phase 2: Inference
print(f"\nPhase 2: Inference")
# ---- Phase 3: Inference ----
print(f"\nPhase 3: Inference")
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
@@ -796,7 +834,7 @@ def main():
t1 = time.time()
tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0')
pos = torch.tensor([pi], dtype=torch.long, device='cuda:0')
X = mHCBlock.init_state(embed(tid))
X = mHCLayer.init_state(embed(tid))
for li in range(n_layers):
gpu = li % NUM_GPUS
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
@@ -806,7 +844,8 @@ def main():
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], pos, tid,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li))
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li))
X = X.to('cuda:0'); torch.cuda.set_device(0)
if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True)
print(f" Prefill done ({time.time()-t0:.1f}s)")
@@ -822,7 +861,7 @@ def main():
t1 = time.time()
tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0')
dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0')
X = mHCBlock.init_state(embed(tid))
X = mHCLayer.init_state(embed(tid))
for li in range(n_layers):
gpu = li % NUM_GPUS
if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}")
@@ -832,7 +871,8 @@ def main():
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], dec_pos, tid,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li))
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li))
X = X.to('cuda:0'); torch.cuda.set_device(0)
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
@@ -858,157 +898,6 @@ def main():
print(f"Total: {time.time()-t0:.1f}s")
print(f"{'='*70}")
# =====================================================================
# MoE weight loading helpers (stacked path for production GEMM)
# =====================================================================
def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg):
    """Stack per-expert NVFP4 weights and hand them to ``moe``.

    For every expert, gate_proj and up_proj are written one after the other
    along dim 1 of a single L1 tensor, and down_proj is written into an L2
    tensor. The stacked tensors are allocated with ``torch.zeros`` (no device
    argument), filled expert by expert, moved to ``dev`` once, and passed to
    ``moe.prepare_weights_from_stacked``.

    Shapes and dtypes are taken from expert 0. Slots of experts whose
    tensors are missing from ``all_w`` stay zero-filled, and their global
    scale stays at the default ``1 / (6 * 448)``.

    Args:
        all_w: Mapping from checkpoint key to tensor.
        li: Layer index; used only in the warning message.
        pfx: Key prefix of the layer's MLP (keys are read as
            ``{pfx}.experts.{eid}.<proj>.<field>``).
        dev: Target device for the stacked tensors.
        moe: Object exposing ``prepare_weights_from_stacked``.
        cfg: Model config; ``n_routed_experts`` is required.

    Returns:
        None. Returns early (after a warning) when expert 0 has no
        gate_proj weight.
    """
    default_gs = 1.0 / (6.0 * 448.0)
    n_e = cfg["n_routed_experts"]

    # Expert 0 defines the per-expert shapes for the whole stack.
    w0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight")
    if w0 is None:
        log.warning(f"L{li}: No expert weights found")
        return
    gate_N, gate_K = w0.shape
    l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype)

    l1_sf_stacked = None
    sf_N = 0
    ws0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight_scale")
    if ws0 is not None:
        sf_N, sf_K = ws0.shape
        l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype)

    l2_stacked = None
    dw0 = all_w.get(f"{pfx}.experts.0.down_proj.weight")
    if dw0 is not None:
        down_N, down_K = dw0.shape
        l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype)

    l2_sf_stacked = None
    dws0 = all_w.get(f"{pfx}.experts.0.down_proj.weight_scale")
    if dws0 is not None:
        dsf_N, dsf_K = dws0.shape
        l2_sf_stacked = torch.zeros(n_e, dsf_N, dsf_K, dtype=dws0.dtype)

    # One scale per expert, indexed by expert id. Pre-filling (instead of
    # appending) keeps entry ``eid`` tied to expert ``eid`` even when some
    # experts are absent from the checkpoint.
    l1_gs = [default_gs] * n_e
    l2_gs = [default_gs] * n_e

    for eid in range(n_e):
        base = f"{pfx}.experts.{eid}"

        # L1: gate rows first, then up rows.
        gw = all_w.get(f"{base}.gate_proj.weight")
        gws = all_w.get(f"{base}.gate_proj.weight_scale")
        gisc = all_w.get(f"{base}.gate_proj.input_scale")
        uw = all_w.get(f"{base}.up_proj.weight")
        uws = all_w.get(f"{base}.up_proj.weight_scale")
        if gw is not None and uw is not None:
            l1_stacked[eid, :gate_N] = gw
            l1_stacked[eid, gate_N:] = uw
            if gws is not None and uws is not None and l1_sf_stacked is not None:
                l1_sf_stacked[eid, :sf_N] = gws
                l1_sf_stacked[eid, sf_N:] = uws
        if gisc is not None:
            l1_gs[eid] = gisc.float().item()

        # L2: down projection.
        dw = all_w.get(f"{base}.down_proj.weight")
        dws = all_w.get(f"{base}.down_proj.weight_scale")
        disc = all_w.get(f"{base}.down_proj.input_scale")
        # l2_stacked is None when expert 0 had no down_proj; there is no
        # destination to write into in that case.
        if dw is not None and l2_stacked is not None:
            l2_stacked[eid] = dw
            if dws is not None and l2_sf_stacked is not None:
                l2_sf_stacked[eid] = dws
        if disc is not None:
            l2_gs[eid] = disc.float().item()

    # Single host-to-device transfer per stacked tensor.
    l1_stacked = l1_stacked.to(dev)
    if l1_sf_stacked is not None:
        l1_sf_stacked = l1_sf_stacked.to(dev)
    if l2_stacked is not None:
        l2_stacked = l2_stacked.to(dev)
    if l2_sf_stacked is not None:
        l2_sf_stacked = l2_sf_stacked.to(dev)

    moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs,
                                     l2_stacked, l2_sf_stacked, l2_gs)
def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg):
    """Populate the shared-expert runner ``se`` from checkpoint tensors.

    Reads ``{pfx}.shared_experts.<proj>.{weight,weight_scale,input_scale}``
    for gate_proj, up_proj and down_proj. A projection is used only when both
    its weight and weight_scale are present. gate and up are concatenated
    along dim 0 into the single-entry ``se.l1_fp4`` / ``se.l1_sf`` lists;
    down goes into ``se.l2_fp4`` / ``se.l2_sf``. The global scales come from
    the gate and down ``input_scale`` (default ``1 / (6 * 448)`` when absent).
    ``se.finalize_weights()`` is called at the end.

    ``li`` and ``cfg`` are accepted but not read here.
    """
    fallback_gs = 1.0 / (6.0 * 448.0)

    found = {}
    for name in ('gate_proj', 'up_proj', 'down_proj'):
        stem = f"{pfx}.shared_experts.{name}"
        weight = all_w.get(f"{stem}.weight")
        weight_sf = all_w.get(f"{stem}.weight_scale")
        in_scale = all_w.get(f"{stem}.input_scale")
        if weight is None or weight_sf is None:
            continue
        entry = {"fp4": weight.to(dev), "sf": weight_sf.to(dev)}
        # up_proj carries no separate global scale; L1 uses the gate's.
        if name != 'up_proj':
            if in_scale is not None:
                entry["gs"] = in_scale.float().item()
            else:
                entry["gs"] = fallback_gs
        found[name] = entry

    gate = found.get('gate_proj')
    up = found.get('up_proj')
    down = found.get('down_proj')

    if gate is not None and up is not None:
        se.l1_fp4 = [torch.cat([gate["fp4"], up["fp4"]], dim=0)]
        se.l1_sf = [torch.cat([gate["sf"], up["sf"]], dim=0)]
        se.l1_gs = [gate["gs"]]
    if down is not None:
        se.l2_fp4 = [down["fp4"]]
        se.l2_sf = [down["sf"]]
        se.l2_gs = [down["gs"]]
    se.finalize_weights()
def _cache_layer_weights_no_experts(all_w, n_layers, devices):
    """Copy each layer's non-expert tensors to that layer's device.

    Layer ``i`` is assigned ``devices[i % len(devices)]``. A key belongs to
    layer ``i`` when it starts with ``model.layers.{i}.``; keys containing
    ``.experts.`` or ``.shared_experts.`` are left out so those tensors are
    not copied here as well as by whatever loads the expert weights.

    Returns:
        Dict mapping layer index to a ``{key: tensor_on_device}`` dict.
    """
    def _wanted(key, layer_prefix):
        # Non-expert key of this layer.
        if not key.startswith(layer_prefix):
            return False
        return '.experts.' not in key and '.shared_experts.' not in key

    per_layer = {}
    for layer_idx in range(n_layers):
        target = devices[layer_idx % len(devices)]
        layer_prefix = f"model.layers.{layer_idx}."
        kept = {}
        for key, tensor in all_w.items():
            if _wanted(key, layer_prefix):
                kept[key] = tensor.to(device=target, non_blocking=True)
        per_layer[layer_idx] = kept
        # Progress line every 10 layers.
        if (layer_idx + 1) % 10 == 0:
            print(f" Cached {layer_idx+1}/{n_layers} layers")
    return per_layer
def load_weights(checkpoint_dir):
    """Load all safetensors shards of a checkpoint into one flat dict.

    When ``model.safetensors.index.json`` exists, the shard file names are
    the distinct values of its ``weight_map``. When there is no index file
    (e.g. a single-file checkpoint), every ``*.safetensors`` file in the
    directory is loaded instead. Shards are read in sorted name order and
    later shards overwrite duplicate keys. Shard names listed in the index
    but missing on disk are skipped.

    Args:
        checkpoint_dir: Directory holding the checkpoint files.

    Returns:
        Dict mapping tensor name to tensor; empty when nothing was found.
    """
    cdir = Path(checkpoint_dir)
    index_path = cdir / "model.safetensors.index.json"
    if index_path.exists():
        with open(index_path, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_names = set(weight_map.values())
    else:
        # No index: fall back to whatever shard files are present.
        shard_names = {p.name for p in cdir.glob("*.safetensors")}

    shard_paths = [cdir / name for name in sorted(shard_names)
                   if (cdir / name).exists()]
    if not shard_paths:
        return {}

    # Imported only once there is something to load.
    from safetensors.torch import load_file

    all_w = {}
    for path in shard_paths:
        all_w.update(load_file(str(path)))
    return all_w
if __name__ == "__main__":
    # Run the inference script exactly once.
    main()

View File

@@ -0,0 +1,88 @@
#!/usr/bin/env python3
"""Test FMHA kernel with attention sink bias.
Validates that the kernel's sink bias correction matches PyTorch reference:
softmax([QK^T * scale, sink_bias])[:N] @ V
Tests HD=64,128,256,512 with and without sinks.
"""
import torch
import math
import sys
def reference_fmha_with_sink(q, k, v, scale, sink_bias=None):
    """PyTorch reference for attention with an optional per-head sink logit.

    Computes ``softmax([Q K^T * scale, sink_bias])[..., :N] @ V``. The sink
    is one extra softmax column per head: it enlarges the denominator and is
    then dropped, so it has no matching V row.

    Args:
        q: Queries, (n_h, T, hd).
        k: Keys, (1, N, hd); broadcast over heads by ``matmul``.
        v: Values, (1, N, hd); broadcast over heads by ``matmul``.
        scale: Multiplier applied to Q K^T.
        sink_bias: Optional (n_h,) sink logits, appended without scaling.

    Returns:
        (n_h, T, hd) tensor; BF16 when ``v`` is BF16.
    """
    # NOTE(review): the commit description says the kernel folds in
    # ``sink_bias * scale``; here the sink is appended unscaled. Confirm
    # which convention is intended.
    n_h, T, _ = q.shape
    N = k.shape[1]

    # Logits are formed in FP32 so distinct logits are not merged by BF16
    # rounding before the softmax.
    scores = torch.matmul(q.float(), k.float().transpose(-1, -2)) * scale  # (n_h, T, N)

    if sink_bias is not None:
        # Sink as an extra column: (n_h, T, N+1).
        sink_col = sink_bias.float().reshape(n_h, 1, 1).expand(-1, T, 1)
        scores = torch.cat([scores, sink_col], dim=-1)

    # Slicing to :N drops the sink column (no-op when there is no sink).
    attn = torch.softmax(scores, dim=-1)[:, :, :N]
    return torch.matmul(attn.bfloat16(), v)  # (n_h, T, hd)
def test_fmha_sink():
    """Compare ``dsv4_attention`` against the PyTorch reference on CUDA.

    Sweeps hd in {64, 128, 256, 512}, N in {9, 32, 128, 256}, with and
    without a sink, using n_h=4 and T=1. A case passes when both hold:

    * cosine similarity of the flattened outputs is above 0.999, and
    * relative L2 error of the flattened outputs is below 0.05.

    The second check exists because cosine similarity ignores overall
    magnitude, while the sink only changes the softmax denominator, i.e. it
    rescales output rows. A kernel exception counts as a failure for that
    case and the sweep continues.

    Returns:
        True when no case failed.
    """
    from dsv4.kernels.attention.production import dsv4_attention

    cos_min = 0.999
    rel_err_max = 0.05

    torch.manual_seed(42)
    device = 'cuda'
    passed = 0
    failed = 0
    n_h = 4  # small for speed
    T = 1

    for hd in [64, 128, 256, 512]:
        scale = 1.0 / math.sqrt(hd)
        for N in [9, 32, 128, 256]:
            for use_sink in [False, True]:
                q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device=device)
                k = torch.randn(1, N, hd, dtype=torch.bfloat16, device=device)
                v = torch.randn(1, N, hd, dtype=torch.bfloat16, device=device)
                sink = torch.randn(n_h, dtype=torch.float32, device=device) * 2 if use_sink else None

                # Production kernel; any error is reported as a failed case.
                try:
                    o_kernel = dsv4_attention(q, k, v, scale=scale, sink_bias=sink)
                except Exception as e:
                    print(f" FAIL hd={hd} N={N} sink={use_sink}: kernel error: {e}")
                    failed += 1
                    continue

                # PyTorch reference
                o_ref = reference_fmha_with_sink(q, k, v, scale, sink)

                # Compare in FP32.
                o_kf = o_kernel.float()
                o_rf = o_ref.float()
                k_flat = o_kf.flatten()
                r_flat = o_rf.flatten()
                cos = torch.nn.functional.cosine_similarity(k_flat.unsqueeze(0),
                                                            r_flat.unsqueeze(0)).item()
                max_diff = (o_kf - o_rf).abs().max().item()
                # Magnitude-sensitive metric; clamp guards a zero reference.
                rel_err = ((k_flat - r_flat).norm() / r_flat.norm().clamp_min(1e-12)).item()

                ok = cos > cos_min and rel_err < rel_err_max
                status = "PASS" if ok else "FAIL"
                if ok:
                    passed += 1
                else:
                    failed += 1
                print(f" {status} hd={hd} N={N} sink={use_sink} cos={cos:.6f} "
                      f"max_diff={max_diff:.6f} rel_err={rel_err:.6f}")

    print(f"\n{'='*60}")
    print(f"Results: {passed} PASSED, {failed} FAILED")
    print(f"{'='*60}")
    return failed == 0
if __name__ == "__main__":
    # Exit status 0 only when every configuration passed.
    sys.exit(0 if test_fmha_sink() else 1)