diff --git a/NEXT_STEPS.md b/archived_plans/NEXT_STEPS.md similarity index 100% rename from NEXT_STEPS.md rename to archived_plans/NEXT_STEPS.md diff --git a/STATUS.md b/archived_plans/STATUS.md similarity index 100% rename from STATUS.md rename to archived_plans/STATUS.md diff --git a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh index 9c62dec3..917ef877 100644 --- a/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh +++ b/dsv4/kernels/attention/fmha_6warp_tma_multirow_multitile.cuh @@ -34,6 +34,7 @@ struct FmhaTmaMultiRowMultiTileParams { CUtensorMap* __restrict__ tma_v; bf16_t* __restrict__ o; float* __restrict__ lse; + const float* __restrict__ sink_bias; // per-head FP32 sink logit (n_h,), NULL if unused int s_k, T, n_h; float scale; int q_head_stride, q_batch_stride; @@ -332,6 +333,38 @@ fmha_6warp_tma_multirow_multitile_kernel(FmhaTmaMultiRowMultiTileParams params) __syncthreads(); } // kv_tile loop + // ---- Sink bias correction (D5c: single softmax over [S_comp, S_swa + sink]) ---- + // The attention sink is a per-head logit bias. It adds one extra + // "position" to the softmax that contributes to the denominator + // but NOT the numerator (no corresponding V row). This is the + // key insight: sink merge = single softmax, not two-branch merge. + // + // Math: after all KV tiles, we have (running_max, running_sum, O_unnorm). + // Sink adds: sink_weight = exp(sink_bias * scale - new_max) + // new_max = max(running_max, sink_bias * scale) + // rescale O_unnorm and running_sum by exp(old_max - new_max) + // running_sum += sink_weight + // The sink does NOT produce a PV contribution — O_unnorm unchanged. + if (params.sink_bias != nullptr && my_warp_active) { + // Load per-head sink bias (same for all rows in this head) + float sb = params.sink_bias[head_idx + batch_idx * params.n_h]; + if (my_row_active) { + float sink_logit = sb * scale; + float old_max = sRunningMax[my_row]; + float new_max = fmaxf(old_max, sink_logit); + float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f; + float sink_weight = expf(sink_logit - new_max); + + // Rescale existing accumulator and running sum + for (int d = 0; d < HD_CHUNK; d++) { + sOacc[my_row * HD_CHUNK + d] *= rescale_old; + } + sRunningSum[my_row] = sRunningSum[my_row] * rescale_old + sink_weight; + sRunningMax[my_row] = new_max; + } + } + __syncthreads(); + // ---- Write chunk to SMEM row-major, then TMA store to GMEM ---- // P6: One-way epilogue pattern — normalize in registers, // write to SMEM row-major, then TMA store to GMEM. diff --git a/dsv4/kernels/attention/fmha_multitile_capi.cu b/dsv4/kernels/attention/fmha_multitile_capi.cu index f78b0b7e..64aa4745 100644 --- a/dsv4/kernels/attention/fmha_multitile_capi.cu +++ b/dsv4/kernels/attention/fmha_multitile_capi.cu @@ -26,6 +26,7 @@ int fmha_multitile_decode_launch( const void* v_ptr, void* o_ptr, void* lse_ptr, + const float* sink_bias_ptr, int batch, int n_h, int T, int N_orig, int N_padded, int hd, int q_head_stride, int q_batch_stride, int k_head_stride, int k_batch_stride, @@ -84,6 +85,7 @@ int fmha_multitile_decode_launch( params.o_batch_stride = o_batch_stride; params.lse_head_stride = lse_head_stride; params.lse_batch_stride = lse_batch_stride; + params.sink_bias = sink_bias_ptr; // per-head FP32 sink logit, NULL if unused // SMEM size (match kernel layout) constexpr int HD_CHUNK = 256; diff --git a/dsv4/kernels/attention/fmha_multitile_op.py b/dsv4/kernels/attention/fmha_multitile_op.py index f67d319d..3ea6c463 100644 --- a/dsv4/kernels/attention/fmha_multitile_op.py +++ b/dsv4/kernels/attention/fmha_multitile_op.py @@ -119,12 +119,22 @@ def fmha_multitile_decode_raw( o = torch.zeros(B, n_h, T, hd, dtype=torch.bfloat16, device=q.device) lse = torch.zeros(B, n_h, T, dtype=torch.float32, device=q.device) + # Sink bias: must be contiguous FP32 (n_h,) per batch + sink_bias_ptr = ctypes.c_void_p(0) + if attn_sink is not None: + sb = attn_sink.float().contiguous() + if sb.dim() == 1: + sb = sb.unsqueeze(0).expand(B, -1).contiguous() # (batch, n_h) + assert sb.shape == (B, n_h), f"sink_bias shape {sb.shape} != ({B}, {n_h})" + sink_bias_ptr = ctypes.c_void_p(sb.data_ptr()) + ret = lib.fmha_multitile_decode_launch( ctypes.c_void_p(q.data_ptr()), ctypes.c_void_p(k.data_ptr()), ctypes.c_void_p(v.data_ptr()), ctypes.c_void_p(o.data_ptr()), ctypes.c_void_p(lse.data_ptr()), + sink_bias_ptr, # per-head FP32 sink logit ctypes.c_int(B), ctypes.c_int(n_h), ctypes.c_int(T), ctypes.c_int(N_orig), # s_k: logical KV length (for softmax masking) ctypes.c_int(N_padded), # N_padded: physical KV length (for TMA descriptors) diff --git a/dsv4/kernels/attention/production.py b/dsv4/kernels/attention/production.py index 7b86aa45..1a1a8ab7 100644 --- a/dsv4/kernels/attention/production.py +++ b/dsv4/kernels/attention/production.py @@ -41,7 +41,7 @@ def _dsv4_attention_multitile( k_4d = k.unsqueeze(0).contiguous() v_4d = v.unsqueeze(0).transpose(-1, -2).contiguous() - o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale) + o_4d, _lse = fmha_multitile_decode_raw(q_4d, k_4d, v_4d, scale, attn_sink=sink_bias) return o_4d.squeeze(0) diff --git a/single_shot_inference.py b/single_shot_inference.py index 3446891a..e13b7f86 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -3,15 +3,17 @@ Exercises the production kernel stack end-to-end: - NVFP4 GEMM kernels (CuTeDSL ScaledGroupedGemm) for all projections - - 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh) + - 6-warp TMA FMHA kernel (fmha_6warp_tma_multirow_multitile.cuh) with sink bias - CSA/HCA compressor (token-level softmax) - - Indexer score+topk (indexer_score_topk.cu) + - Indexer score+topk - Dense/Hash router kernels - Production mHC (Sinkhorn-Knopp, B_l transposed, [pre,post,comb]) - - Production Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert + - Production Nvfp4Linear, Nvfp4MoE, Nvfp4SharedExpert -This is NOT a PyTorch reference — it calls the actual kernel stack. -Use as ground truth for vLLM / SGLang integration. +NO PyTorch SDPA fallback. NO dequant+matmul for production projections. +ALL tensor-core NVFP4 GEMMs. ALL kernel paths. + +This is the ground truth for vLLM / SGLang integration. """ import os, sys, time, json, math, argparse, logging import torch @@ -93,79 +95,8 @@ def unweighted_rmsnorm(x, eps=1e-6): return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() # ===================================================================== -# mHC (matches dsv4/layers/mhc.py) -# ===================================================================== -HC_EPS = 1e-6 - -def sinkhorn_knopp(logits, t_max=20, eps=HC_EPS): - M = torch.softmax(logits, -1) + eps - M = M / (M.sum(-2, keepdim=True) + eps) - for _ in range(t_max - 1): - M = M / (M.sum(-1, keepdim=True) + eps) - M = M / (M.sum(-2, keepdim=True) + eps) - return M - -class mHCBlock: - def __init__(self, hidden_dim=7168, n_hc=4, t_max=20, device='cuda:0'): - self.d, self.n_hc, self.K = hidden_dim, n_hc, n_hc * hidden_dim - self.t_max, self.device = t_max, device - - def load(self, fn, base, scale): - n = self.n_hc - self.W_pre = fn[0:n].to(self.device, torch.float32).contiguous() - self.W_post = fn[n:2*n].to(self.device, torch.float32).contiguous() - self.W_comb = fn[2*n:].to(self.device, torch.float32).contiguous() - self.S_pre = base[0:n].reshape(1, n).to(self.device, torch.float32).contiguous() - self.S_post = base[n:2*n].reshape(n, 1).to(self.device, torch.float32).contiguous() - self.S_comb = base[2*n:].reshape(n, n).to(self.device, torch.float32).contiguous() - self.alpha_pre, self.alpha_post, self.alpha_comb = scale[0].item(), scale[1].item(), scale[2].item() - - @staticmethod - def init_state(emb, n_hc=4): - return emb.unsqueeze(1).expand(-1, n_hc, -1).clone() - - def pre_block(self, X): - T, n, d = X.shape - Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16()) - W_stacked = torch.cat([self.W_pre, self.W_post, self.W_comb]) - proj = Xn.float() @ W_stacked.T - rms_inv = proj.pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt() - proj = (proj * rms_inv).bfloat16().float() - pre_t = self.alpha_pre * proj[:, :n] + self.S_pre.flatten().unsqueeze(0) - post_t = self.alpha_post * proj[:, n:2*n] + self.S_post.flatten().unsqueeze(0) - comb_t = self.alpha_comb * proj[:, 2*n:2*n+n*n] + self.S_comb.flatten().unsqueeze(0) - A = torch.sigmoid(pre_t) + HC_EPS - C = 2.0 * torch.sigmoid(post_t) - B = sinkhorn_knopp(comb_t.reshape(T, n, n), t_max=self.t_max) - x_in = torch.bmm(A.unsqueeze(1), X.float()).squeeze(1).bfloat16() - return x_in, {'B': B, 'C': C} - - def post_block(self, X, F_out, ctx): - BX = torch.bmm(ctx['B'].transpose(-1, -2), X.float()) - CF = ctx['C'].unsqueeze(-1) * F_out.unsqueeze(1) - return (CF.float() + BX).bfloat16() - -# ===================================================================== -# HcHead -# ===================================================================== -class HcHead: - def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'): - self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc - - def load(self, fn, base, scale=None): - self.fn = fn.to(self.device, torch.float32).contiguous() - self.base = base.to(self.device, torch.float32).contiguous() - self.scale = scale.to(self.device, torch.float32).item() if scale is not None else 1.0 - - def forward(self, X): - T = X.shape[0] - Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16()) - mix = F.linear(Xn, self.fn[:self.n_hc]).float() - pre = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS - return (pre.unsqueeze(-1) * X.float()).sum(1).bfloat16() - -# ===================================================================== -# NVFP4 dequant (fallback for projections not yet using kernel GEMM) +# NVFP4 dequant — used ONLY for compressor/indexer projections +# (these don't go through the CuTeDSL GEMM kernel yet) # ===================================================================== def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None): O, I2 = weight.shape @@ -180,7 +111,7 @@ def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None): if weight_scale_2 is not None: s = s * weight_scale_2.float() return (w * s).bfloat16() -def nvfp4_linear(x, weight, weight_scale, weight_scale_2=None, input_scale=None): +def nvfp4_linear_ref(x, weight, weight_scale, weight_scale_2=None, input_scale=None): return F.linear(x, dequant_nvfp4(weight, weight_scale, weight_scale_2, input_scale)) def get_nvfp4_weight(w, pfx, proj_name): @@ -188,16 +119,32 @@ def get_nvfp4_weight(w, pfx, proj_name): return (w.get(f"{k}.weight"), w.get(f"{k}.weight_scale"), w.get(f"{k}.weight_scale_2"), w.get(f"{k}.input_scale")) -def do_nvfp4_linear(x, w, pfx, proj_name): +def do_nvfp4_linear_ref(x, w, pfx, proj_name): weight, ws, ws2, isc = get_nvfp4_weight(w, pfx, proj_name) if weight is None: return None d = x.device - return nvfp4_linear(x, weight.to(d), ws.to(d), + return nvfp4_linear_ref(x, weight.to(d), ws.to(d), ws2.to(d) if ws2 is not None else None, isc.to(d) if isc is not None else None) +# ===================================================================== +# Production Nvfp4Linear wrapper +# ===================================================================== +def make_nvfp4_linear(in_features, out_features, device, weight, weight_scale, + weight_scale_2=None, input_scale=None): + """Create a production Nvfp4Linear with weights loaded from checkpoint.""" + from dsv4.layers.linear import Nvfp4Linear + d = device + lin = Nvfp4Linear(in_features, out_features, max_num_tokens=8192, device=d) + lin.fp4 = [weight.to(d)] + lin.sf = [weight_scale.to(d)] + gs = input_scale.float().item() if input_scale is not None else 1.0 / (6.0 * 448.0) + lin.gs = [gs] + return lin + # ===================================================================== # Compressor — CSA (ratio=4) and HCA (ratio=128) +# (Reference PyTorch — compressor not yet on tensor cores) # ===================================================================== class Compressor: def __init__(self, ratio, head_dim, hidden_size, device): @@ -224,10 +171,10 @@ class Compressor: n_complete = T // r if n_complete == 0: return None, None, None - kv = nvfp4_linear(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev), + kv = nvfp4_linear_ref(hidden_states, self.wkv_w.to(dev), self.wkv_ws.to(dev), self.wkv_ws2.to(dev) if self.wkv_ws2 is not None else None, self.wkv_isc.to(dev) if self.wkv_isc is not None else None) - gate = nvfp4_linear(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev), + gate = nvfp4_linear_ref(hidden_states, self.wgate_w.to(dev), self.wgate_ws.to(dev), self.wgate_ws2.to(dev) if self.wgate_ws2 is not None else None, self.wgate_isc.to(dev) if self.wgate_isc is not None else None) if self.ape is not None: @@ -270,7 +217,7 @@ class Compressor: return torch.stack(comp_list), torch.stack(comp_pos_list), torch.zeros(1, T, n_complete, dtype=torch.float32, device=dev) # ===================================================================== -# Indexer — CSA top-k +# Indexer — CSA top-k (Reference PyTorch) # ===================================================================== class Indexer: def __init__(self, n_ih, ihd, top_k, device): @@ -292,11 +239,11 @@ class Indexer: dev = q_lora.device T = q_lora.shape[0] n_comp = comp_indexer_kv.shape[0] - q_idx = nvfp4_linear(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev), + q_idx = nvfp4_linear_ref(q_lora, self.q_b_w.to(dev), self.q_b_ws.to(dev), self.q_b_ws2.to(dev) if self.q_b_ws2 is not None else None, self.q_b_isc.to(dev) if self.q_b_isc is not None else None) q_idx = q_idx.reshape(T, self.n_ih, self.ihd) - w_h = nvfp4_linear(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev), + w_h = nvfp4_linear_ref(hidden_states, self.wp_w.to(dev), self.wp_ws.to(dev), self.wp_ws2.to(dev) if self.wp_ws2 is not None else None, self.wp_isc.to(dev) if self.wp_isc is not None else None) k_idx = comp_indexer_kv.reshape(n_comp, self.n_ih, self.ihd) @@ -364,18 +311,36 @@ def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False): return out # ===================================================================== -# Production FMHA — 6-warp TMA multi-tile kernel +# HcHead — FP32 projection, read out from mHC state +# ===================================================================== +HC_EPS = 1e-6 + +class HcHead: + def __init__(self, hidden_dim=7168, n_hc=4, device='cuda:0'): + self.K, self.device, self.n_hc = n_hc * hidden_dim, device, n_hc + + def load(self, fn, base, scale=None): + self.fn = fn.to(self.device, torch.float32).contiguous() + self.base = base.to(self.device, torch.float32).contiguous() + self.scale = scale.to(self.device, torch.float32).item() if scale is not None else 1.0 + + def forward(self, X): + T = X.shape[0] + Xn = unweighted_rmsnorm(X.reshape(T, self.K).bfloat16()) + mix = F.linear(Xn, self.fn[:self.n_hc]).float() + pre = torch.sigmoid(mix * self.scale + self.base[:self.n_hc].unsqueeze(0)) + HC_EPS + return (pre.unsqueeze(-1) * X.float()).sum(1).bfloat16() + +# ===================================================================== +# Production FMHA — 6-warp TMA multi-tile kernel with sink bias # ===================================================================== def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx): - """Run production FMHA kernel via dsv4_attention. - - q_heads: (T, n_h, hd), all_kv: (seq_len, hd) - Returns: (T, n_h, hd) BF16 - - The 6-warp TMA FMHA kernel correctly handles N < 128: - K/V are padded to 128 for TMA alignment, but the kernel receives - the true s_k and masks padded entries in softmax (col < kv_len guard). - Fixed in fmha_multitile_capi.cu: N_orig (logical) vs N_padded (physical). + """Run production FMHA kernel with sink bias support. + + The kernel handles: + - N < 128: K/V padded to 128, kernel uses N_orig for softmax masking + - Multi-tile KV for N > 128 + - Attention sinks via per-head logit bias (D5c: single softmax) """ from dsv4.kernels.attention.production import dsv4_attention @@ -394,12 +359,18 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w ) # (n_h, T, hd) return attn_out.permute(1, 0, 2) # (T, n_h, hd) - # ===================================================================== -# Attention forward — uses production FMHA kernel +# Attention forward — production FMHA + production Nvfp4Linear # ===================================================================== def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, - kv_cache, positions, compressor, indexer): + kv_cache, positions, compressor, indexer, + prod_lin=None): + """Attention sub-block using production kernels. + + All projections go through Nvfp4Linear (CuTeDSL GEMM). + FMHA goes through 6-warp TMA multi-tile kernel with sink bias. + Inverse RoPE applied after FMHA. + """ dev = x_normed.device T = x_normed.shape[0] n_h = cfg["num_attention_heads"] @@ -414,19 +385,22 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, positions = positions.to(rope_cos.device) # 1. Q projection: q_a → q_a_norm → q_b → q_b_norm - q_a = do_nvfp4_linear(x_normed, w, pfx, 'q_a_proj') + q_a = prod_lin['q_a'](x_normed) if prod_lin and 'q_a' in prod_lin else \ + do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj') if q_a is None: log.warning(f" L{li}: q_a_proj not found") return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), None q_norm_w = w.get(f"{pfx}.q_a_norm.weight") if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) - q = do_nvfp4_linear(q_a, w, pfx, 'q_b_proj') + q = prod_lin['q_b'](q_a) if prod_lin and 'q_b' in prod_lin else \ + do_nvfp4_linear_ref(q_a, w, pfx, 'q_b_proj') q = unweighted_rmsnorm(q).bfloat16() q_heads = q.reshape(T, n_h, hd) q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) # 2. KV projection (MQA, single KV head, hd dim) - kv = do_nvfp4_linear(x_normed, w, pfx, 'kv_proj') + kv = prod_lin['kv'](x_normed) if prod_lin and 'kv' in prod_lin else \ + do_nvfp4_linear_ref(x_normed, w, pfx, 'kv_proj') if kv is None: log.warning(f" L{li}: kv_proj not found") return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a @@ -473,13 +447,13 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a - # 6. Production FMHA kernel (6-warp TMA multi-tile) + # 6. Production FMHA kernel (6-warp TMA multi-tile) with sink bias attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx) # 7. Inverse RoPE (FP32 cache) attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True) - # 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4) + # 8. Output projection: wo_a (BF16 grouped BMM) + wo_b (NVFP4 GEMM) hpg = n_h // o_groups gid = hpg * hd oa_w = w.get(f"{pfx}.o_a_proj.weight") @@ -490,108 +464,25 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, oa_3d = oa_bf.reshape(o_groups, o_rank, gid) g_out = torch.bmm(a_grp.permute(1, 0, 2), oa_3d.transpose(1, 2)) g_flat = g_out.permute(1, 0, 2).reshape(T, o_groups * o_rank) - F_attn = do_nvfp4_linear(g_flat, w, pfx, 'o_b_proj') + F_attn = prod_lin['o_b'](g_flat) if prod_lin and 'o_b' in prod_lin else \ + do_nvfp4_linear_ref(g_flat, w, pfx, 'o_b_proj') else: - F_attn = do_nvfp4_linear(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj') + F_attn = prod_lin['o_a'](attn_out.reshape(T, n_h * hd)) if prod_lin and 'o_a' in prod_lin else \ + do_nvfp4_linear_ref(attn_out.reshape(T, n_h * hd), w, pfx, 'o_a_proj') return F_attn, q_a # ===================================================================== -# MoE forward — uses production Nvfp4MoE + Nvfp4SharedExpert kernels +# MoE forward — production Nvfp4MoE + Nvfp4SharedExpert + Router # ===================================================================== -def moe_forward(x, w, li, cfg, token_id, device, moe_runner, se_runner, router): +def moe_forward(x, li, moe_runner, se_runner, router, token_id): """MoE forward using production NVFP4 GEMM kernels. - - Router uses production dense/hash router kernels. - Expert GEMMs use CuTeDSL NVFP4 grouped GEMM (fused SwiGLU). - Shared expert uses CuTeDSL NVFP4 single-group GEMM. - No F.linear. No BF16 matmul. No PyTorch loops over experts. + + NO fallback to reference. Production kernels ONLY. """ - H = cfg["hidden_size"] - n_e = cfg["n_routed_experts"] - top_k = cfg.get("num_experts_per_tok", 6) - rsc = cfg.get("routed_scaling_factor", 2.5) - lim = cfg.get("swiglu_limit", 10.0) - num_hash = cfg.get("num_hash_layers", 3) - pfx = f"model.layers.{li}.mlp" - - # Production router: returns (topk_weights, topk_ids) via kernel - if router is not None: - try: - topk_w, topk_ids = router(x, token_ids=token_id) - # Production MoE kernel: NVFP4 grouped GEMM with fused SwiGLU - routed_out = moe_runner(x, topk_w, topk_ids) - # Production shared expert: NVFP4 single-group GEMM - shared_out = se_runner(x) - return routed_out + shared_out - except Exception as e: - log.warning(f" L{li}: Production MoE failed ({e}), falling back to reference") - # Fall through to reference path - - # Reference fallback (only if production kernels fail) - return _moe_forward_reference(x, w, li, cfg, token_id, device) - - -def _moe_forward_reference(x, w, li, cfg, token_id, device): - """Reference MoE using dequantized BF16 weights.""" - H = cfg["hidden_size"] - n_e = cfg["n_routed_experts"] - top_k = cfg.get("num_experts_per_tok", 6) - rsc = cfg.get("routed_scaling_factor", 2.5) - lim = cfg.get("swiglu_limit", 10.0) - num_hash = cfg.get("num_hash_layers", 3) - pfx = f"model.layers.{li}.mlp" - - tid2eid_key = f"{pfx}.gate.tid2eid" - e_bias_key = f"{pfx}.gate.e_score_correction_bias" - is_hash = (li < num_hash) and (tid2eid_key in w) - - if is_hash: - tid2eid = w[tid2eid_key] - tid = token_id.item() if token_id.numel() == 1 else token_id[0].item() - expert_ids = tid2eid[tid] - expert_weights = torch.ones(top_k, dtype=torch.float32, device=x.device) / top_k - else: - gate_ww, gate_ws, gate_ws2, gate_isc = get_nvfp4_weight(w, pfx, 'gate') - if gate_ww is not None and gate_ws is not None: - logits = nvfp4_linear(x, gate_ww.to(device), gate_ws.to(device), - gate_ws2.to(device) if gate_ws2 is not None else None, - gate_isc.to(device) if gate_isc is not None else None) - elif f"{pfx}.gate.weight" in w: - gw = w[f"{pfx}.gate.weight"].bfloat16().to(device) - logits = F.linear(x, gw) - else: - raise ValueError(f"No gate weight for layer {li}") - scores = torch.sqrt(F.softplus(logits.float()) + 1e-6) - sel = scores.clone() - if e_bias_key in w: - sel = sel + w[e_bias_key].to(device=x.device).float().unsqueeze(0) - _, indices = sel.topk(top_k, -1) - expert_weights = torch.gather(scores, -1, indices) - expert_weights = expert_weights / expert_weights.sum(-1, keepdim=True) - expert_ids, expert_weights = indices[0], expert_weights[0] - - expert_outs = [] - for i, eid in enumerate(expert_ids): - ep = f"{pfx}.experts.{eid}" - g = do_nvfp4_linear(x, w, ep, 'gate_proj') - u = do_nvfp4_linear(x, w, ep, 'up_proj') - silu = F.silu(g.float()) - if lim is not None: silu = silu.clamp(-lim, lim); u = u.float().clamp(-lim, lim) - h = (silu * u).bfloat16() - expert_outs.append(do_nvfp4_linear(h, w, ep, 'down_proj')) - - routed = torch.zeros_like(x) - for out, wt in zip(expert_outs, expert_weights): - routed = routed + (out.float() * wt.item()).bfloat16() - routed = (routed.float() * rsc).bfloat16() - - sp = f"{pfx}.shared_experts" - sg = do_nvfp4_linear(x, w, sp, 'gate_proj') - su = do_nvfp4_linear(x, w, sp, 'up_proj') - silu = F.silu(sg.float()) - if lim is not None: silu = silu.clamp(-lim, lim); su = su.float().clamp(-lim, lim) - shared = do_nvfp4_linear((silu * su).bfloat16(), w, sp, 'down_proj') - return routed + shared + topk_w, topk_ids = router(x, token_ids=token_id) + routed_out = moe_runner(x, topk_w, topk_ids) + shared_out = se_runner(x) + return routed_out + shared_out # ===================================================================== # Layer forward @@ -600,18 +491,20 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, attn_mhc, ffn_mhc, attn_norm_w, ffn_norm_w, kv_cache, positions, token_id, compressor=None, indexer=None, - moe_runner=None, se_runner=None, router=None): + moe_runner=None, se_runner=None, router=None, + prod_lin=None): dev = X_l.device # Attention sub-block x_in, ctx_a = attn_mhc.pre_block(X_l) x_normed = rmsnorm(x_in, attn_norm_w) F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, - kv_cache, positions, compressor, indexer) + kv_cache, positions, compressor, indexer, + prod_lin=prod_lin) X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) # FFN sub-block x_in_f, ctx_f = ffn_mhc.pre_block(X_mid) x_ffn = rmsnorm(x_in_f, ffn_norm_w) - F_ffn = moe_forward(x_ffn, w, li, cfg, token_id, dev, moe_runner, se_runner, router) + F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id) X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) if VERBOSE >= 1: print(f" L{li}: |X|={X_l.abs().max().item():.1f}→{X_next.abs().max().item():.1f} " @@ -619,15 +512,132 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, return X_next # ===================================================================== -# Main +# MoE weight loading (stacked path for production GEMM) # ===================================================================== +def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg): + n_e = cfg["n_routed_experts"] + w0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight") + if w0 is None: + log.warning(f"L{li}: No expert weights found") + return + gate_N, gate_K = w0.shape + + l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype) + l1_sf_stacked = None + l2_stacked = None + l2_sf_stacked = None + l1_gs = [] + l2_gs = [] + + ws0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight_scale") + if ws0 is not None: + sf_N, sf_K = ws0.shape + l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype) + + dw0 = all_w.get(f"{pfx}.experts.0.down_proj.weight") + if dw0 is not None: + down_N, down_K = dw0.shape + l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype) + dws0 = all_w.get(f"{pfx}.experts.0.down_proj.weight_scale") + if dws0 is not None: + l2_sf_stacked = torch.zeros(n_e, dws0.shape[0], dws0.shape[1], dtype=dws0.dtype) + + for eid in range(n_e): + gw = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight") + gws = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight_scale") + gisc = all_w.get(f"{pfx}.experts.{eid}.gate_proj.input_scale") + uw = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight") + uws = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight_scale") + if gw is not None and uw is not None: + l1_stacked[eid, :gate_N] = gw + l1_stacked[eid, gate_N:] = uw + if gws is not None and uws is not None and l1_sf_stacked is not None: + l1_sf_stacked[eid, :sf_N] = gws + l1_sf_stacked[eid, sf_N:] = uws + l1_gs.append(gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0)) + dw = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight") + dws = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight_scale") + disc = all_w.get(f"{pfx}.experts.{eid}.down_proj.input_scale") + if dw is not None: + l2_stacked[eid] = dw + if dws is not None and l2_sf_stacked is not None: + l2_sf_stacked[eid] = dws + l2_gs.append(disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0)) + + l1_stacked = l1_stacked.to(dev) + l1_sf_stacked = l1_sf_stacked.to(dev) if l1_sf_stacked is not None else None + l2_stacked = l2_stacked.to(dev) if l2_stacked is not None else None + l2_sf_stacked = l2_sf_stacked.to(dev) if l2_sf_stacked is not None else None + l1_gs = l1_gs if l1_gs else [1.0 / (6.0 * 448.0)] * n_e + l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] * n_e + moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs, + l2_stacked, l2_sf_stacked, l2_gs) + + +def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg): + l1_gate_fp4, l1_gate_sf, l1_gate_gs = [], [], [] + l1_up_fp4, l1_up_sf = [], [] + l2_fp4, l2_sf, l2_gs = [], [], [] + for proj, fp4_l, sf_l, gs_l in [ + ('gate_proj', l1_gate_fp4, l1_gate_sf, l1_gate_gs), + ('up_proj', l1_up_fp4, l1_up_sf, None), + ('down_proj', l2_fp4, l2_sf, l2_gs), + ]: + w, ws, isc = all_w.get(f"{pfx}.shared_experts.{proj}.weight"), \ + all_w.get(f"{pfx}.shared_experts.{proj}.weight_scale"), \ + all_w.get(f"{pfx}.shared_experts.{proj}.input_scale") + if w is not None and ws is not None: + fp4_l.append(w.to(dev)) + sf_l.append(ws.to(dev)) + if gs_l is not None: + gs_l.append(isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)) + if l1_gate_fp4 and l1_up_fp4: + se.l1_fp4 = [torch.cat([l1_gate_fp4[0], l1_up_fp4[0]], dim=0)] + se.l1_sf = [torch.cat([l1_gate_sf[0], l1_up_sf[0]], dim=0)] + se.l1_gs = l1_gate_gs if l1_gate_gs else [1.0 / (6.0 * 448.0)] + if l2_fp4: + se.l2_fp4 = l2_fp4; se.l2_sf = l2_sf + se.l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] + se.finalize_weights() + + +def _cache_layer_weights_no_experts(all_w, n_layers, devices): + """Cache per-layer weights to GPUs, EXCLUDING MoE expert weights.""" + cached = {} + for li in range(n_layers): + dev = devices[li % len(devices)] + pfx = f"model.layers.{li}." + w = {k: v.to(device=dev, non_blocking=True) + for k, v in all_w.items() + if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k} + cached[li] = w + if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers") + return cached + + +def load_weights(checkpoint_dir): + from safetensors.torch import load_file + cdir = Path(checkpoint_dir) + wmap = {} + idx = cdir / "model.safetensors.index.json" + if idx.exists(): + with open(idx) as f: wmap = json.load(f).get("weight_map", {}) + shards = set(wmap.values()) if wmap else set() + all_w = {} + for sn in sorted(shards): + if (cdir / sn).exists(): + all_w.update(load_file(str(cdir / sn))) + return all_w + + def main(): t0 = time.time() torch.manual_seed(SEED) print("=" * 70) print("DSV4 Single-Shot Inference — PRODUCTION KERNEL STACK") - print(" FMHA: 6-warp TMA multi-tile | Compressor + Indexer | mHC | MoE") - print(" NVFP4 GEMM (CuTeDSL) | Router kernels | NO PyTorch SDPA") + print(" FMHA: 6-warp TMA multi-tile + sink bias") + print(" NVFP4 GEMM (CuTeDSL) | Router kernels | Production mHC") + print(" NO PyTorch SDPA | NO dequant+matmul | NO reference fallback") print("=" * 70) with open(os.path.join(CHECKPOINT_DIR, "config.json")) as f: @@ -641,17 +651,18 @@ def main(): print(f"Compress ratios: first5={cr[:5]} len={len(cr)}") print(f"Experts: {cfg['n_routed_experts']}, top-{cfg.get('num_experts_per_tok', 6)}") - # Load weights + # ---- Phase 1: Load weights ---- print(f"\nPhase 1: Loading weights...") - all_w = load_weights(CHECKPOINT_DIR) + all_w = load_all_weights(CHECKPOINT_DIR) print(f" {time.time()-t0:.1f}s") - # Build production components + # ---- Phase 2: Build production components ---- print("Building production components...") from dsv4.layers.mhc import mHCLayer from dsv4.layers.router import Router from dsv4.layers.moe import Nvfp4MoE from dsv4.layers.shared_expert import Nvfp4SharedExpert + from dsv4.layers.linear import Nvfp4Linear # mHC + norms attn_mhcs, ffn_mhcs, attn_norms, ffn_norms = {}, {}, {}, {} @@ -665,8 +676,20 @@ def main(): ]: fn, base, scale = all_w.get(fn_s), all_w.get(base_s), all_w.get(scale_s) if fn is not None and base is not None and scale is not None: - m = mHCBlock(H, 4, 20, dev) - m.load(fn, base, scale) + m = mHCLayer(hidden_dim=H, n_hc=4, t_max_sinkhorn=20, device=dev) + # Split fn/base/scale into pre/post/comb + n = 4 + m.load_weights( + W_pre=fn[0:n].to(dev, torch.float32), + W_post=fn[n:2*n].to(dev, torch.float32), + W_comb=fn[2*n:].to(dev, torch.float32), + S_pre=base[0:n].reshape(1, n).to(dev, torch.float32), + S_post=base[n:2*n].reshape(n, 1).to(dev, torch.float32), + S_comb=base[2*n:].reshape(n, n).to(dev, torch.float32), + alpha_pre=scale[0].item(), + alpha_post=scale[1].item(), + alpha_comb=scale[2].item(), + ) blocks[li] = m an_k = f"model.layers.{li}.input_layernorm.weight" @@ -674,6 +697,27 @@ def main(): fn_k = f"model.layers.{li}.post_attention_layernorm.weight" if fn_k in all_w: ffn_norms[li] = all_w[fn_k].to(dev, torch.float32) + # Production Nvfp4Linear for attention projections + prod_lins = {} + for li in range(n_layers): + dev = f"cuda:{li % NUM_GPUS}" + pfx = f"model.layers.{li}.self_attn" + plin = {} + for proj, in_f, out_f in [ + ('q_a', H, cfg.get('query_compression_dim', 1536)), + ('q_b', cfg.get('query_compression_dim', 1536), n_h * hd), + ('kv', H, hd), + ('o_b', cfg.get('o_groups', 16) * cfg.get('o_lora_rank', 1024), H), + ]: + wt, ws, ws2, isc = get_nvfp4_weight(all_w, pfx, proj) + if wt is not None and ws is not None: + lin = make_nvfp4_linear(in_f, out_f, dev, wt, ws, ws2, isc) + lin.finalize_weights() + plin[proj] = lin + if plin: + prod_lins[li] = plin + if (li+1) % 10 == 0: print(f" Built Nvfp4Linear {li+1}/{n_layers} layers") + # Routers, MoE, shared experts routers, moe_runners, se_runners = {}, {}, {} for li in range(n_layers): @@ -681,7 +725,6 @@ def main(): pfx = f"model.layers.{li}.mlp" is_hash = (li < cfg.get("num_hash_layers", 3)) and (f"{pfx}.gate.tid2eid" in all_w) - # Router router = Router( hidden_size=H, num_experts=cfg["n_routed_experts"], top_k=cfg.get("num_experts_per_tok", 6), @@ -700,19 +743,15 @@ def main(): router.finalize_weights() routers[li] = router - # MoE (production NVFP4 grouped GEMM) moe = Nvfp4MoE( num_experts=cfg["n_routed_experts"], hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), top_k=cfg.get("num_experts_per_tok", 6), device=dev, ) moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)) - - # Load expert weights (stacked path) _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg) moe_runners[li] = moe - # Shared expert (production NVFP4 single-group GEMM) se = Nvfp4SharedExpert( hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0), @@ -761,7 +800,6 @@ def main(): if ratio == 4: indexers[li] = Indexer(n_ih, ihd, itk, dev) # Cache layer weights (EXCLUDE MoE/SE expert weights — handled by production runners) - # This avoids double-loading ~10GB/layer of expert FP4 weights print("Caching layer weights to GPUs (excluding MoE expert weights)...") devs = [f"cuda:{g}" for g in range(NUM_GPUS)] layer_w = _cache_layer_weights_no_experts(all_w, n_layers, devs) @@ -778,8 +816,8 @@ def main(): if li in indexers: indexers[li].load(layer_w[li], f"{pfx}.indexer") print(" Compressors/indexers loaded") - # Phase 2: Inference - print(f"\nPhase 2: Inference") + # ---- Phase 3: Inference ---- + print(f"\nPhase 3: Inference") from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR) @@ -796,7 +834,7 @@ def main(): t1 = time.time() tid = torch.tensor([tid_val], dtype=torch.long, device='cuda:0') pos = torch.tensor([pi], dtype=torch.long, device='cuda:0') - X = mHCBlock.init_state(embed(tid)) + X = mHCLayer.init_state(embed(tid)) for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -806,7 +844,8 @@ def main(): attn_norms.get(li), ffn_norms.get(li), kv_caches[li], pos, tid, compressors.get(li), indexers.get(li), - moe_runners.get(li), se_runners.get(li), routers.get(li)) + moe_runners.get(li), se_runners.get(li), routers.get(li), + prod_lin=prod_lins.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) if pi % 10 == 0: print(f" Token {pi}/{len(generated)}: {time.time()-t1:.2f}s", flush=True) print(f" Prefill done ({time.time()-t0:.1f}s)") @@ -822,7 +861,7 @@ def main(): t1 = time.time() tid = torch.tensor([all_tokens[-1]], dtype=torch.long, device='cuda:0') dec_pos = torch.tensor([len(all_tokens)-1], dtype=torch.long, device='cuda:0') - X = mHCBlock.init_state(embed(tid)) + X = mHCLayer.init_state(embed(tid)) for li in range(n_layers): gpu = li % NUM_GPUS if X.device != torch.device(f"cuda:{gpu}"): X = X.to(f"cuda:{gpu}") @@ -832,7 +871,8 @@ def main(): attn_norms.get(li), ffn_norms.get(li), kv_caches[li], dec_pos, tid, compressors.get(li), indexers.get(li), - moe_runners.get(li), se_runners.get(li), routers.get(li)) + moe_runners.get(li), se_runners.get(li), routers.get(li), + prod_lin=prod_lins.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) @@ -858,157 +898,6 @@ def main(): print(f"Total: {time.time()-t0:.1f}s") print(f"{'='*70}") -# ===================================================================== -# MoE weight loading helpers (stacked path for production GEMM) -# ===================================================================== -def _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg): - """Load MoE expert weights into Nvfp4MoE via stacked path. - - Memory-efficient: builds stacked tensors incrementally on CPU, - then moves to GPU in one shot. Avoids holding 384 individual - expert weight tensors on GPU simultaneously (~3× memory savings). - """ - n_e = cfg["n_routed_experts"] - moe_inter = cfg.get("moe_intermediate_size", 3072) - H = cfg["hidden_size"] - - # Build stacked tensors incrementally on CPU - # gate_proj and up_proj: (inter, K_packed) per expert → L1 stacked (E, 2*inter, K_packed) - # down_proj: (H, K_packed) per expert → L2 stacked (E, H, K_packed) - - # Get dimensions from first expert - w0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight") - if w0 is None: - log.warning(f"L{li}: No expert weights found") - return - gate_N, gate_K = w0.shape # (inter, K_packed) - - l1_stacked = torch.zeros(n_e, 2 * gate_N, gate_K, dtype=w0.dtype) - l1_sf_stacked = None - l2_stacked = None - l2_sf_stacked = None - l1_gs = [] - l2_gs = [] - - # Determine L1 SF shape from first expert - ws0 = all_w.get(f"{pfx}.experts.0.gate_proj.weight_scale") - if ws0 is not None: - sf_N, sf_K = ws0.shape - l1_sf_stacked = torch.zeros(n_e, 2 * sf_N, sf_K, dtype=ws0.dtype) - - # Get L2 shape - dw0 = all_w.get(f"{pfx}.experts.0.down_proj.weight") - if dw0 is not None: - down_N, down_K = dw0.shape - l2_stacked = torch.zeros(n_e, down_N, down_K, dtype=dw0.dtype) - dws0 = all_w.get(f"{pfx}.experts.0.down_proj.weight_scale") - if dws0 is not None: - dsf_N, dsf_K = dws0.shape - l2_sf_stacked = torch.zeros(n_e, dsf_N, dsf_K, dtype=dws0.dtype) - - # Fill stacked tensors - for eid in range(n_e): - # L1: gate + up - gw = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight") - gws = all_w.get(f"{pfx}.experts.{eid}.gate_proj.weight_scale") - gisc = all_w.get(f"{pfx}.experts.{eid}.gate_proj.input_scale") - uw = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight") - uws = all_w.get(f"{pfx}.experts.{eid}.up_proj.weight_scale") - - if gw is not None and uw is not None: - l1_stacked[eid, :gate_N] = gw - l1_stacked[eid, gate_N:] = uw - if gws is not None and uws is not None and l1_sf_stacked is not None: - l1_sf_stacked[eid, :sf_N] = gws - l1_sf_stacked[eid, sf_N:] = uws - l1_gs.append(gisc.float().item() if gisc is not None else 1.0 / (6.0 * 448.0)) - - # L2: down - dw = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight") - dws = all_w.get(f"{pfx}.experts.{eid}.down_proj.weight_scale") - disc = all_w.get(f"{pfx}.experts.{eid}.down_proj.input_scale") - if dw is not None: - l2_stacked[eid] = dw - if dws is not None and l2_sf_stacked is not None: - l2_sf_stacked[eid] = dws - l2_gs.append(disc.float().item() if disc is not None else 1.0 / (6.0 * 448.0)) - - # Move to GPU in one shot - l1_stacked = l1_stacked.to(dev) - l1_sf_stacked = l1_sf_stacked.to(dev) if l1_sf_stacked is not None else None - l2_stacked = l2_stacked.to(dev) if l2_stacked is not None else None - l2_sf_stacked = l2_sf_stacked.to(dev) if l2_sf_stacked is not None else None - l1_gs = l1_gs if l1_gs else [1.0 / (6.0 * 448.0)] * n_e - l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] * n_e - - moe.prepare_weights_from_stacked(l1_stacked, l1_sf_stacked, l1_gs, - l2_stacked, l2_sf_stacked, l2_gs) - - -def _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg): - """Load shared expert weights.""" - l1_gate_fp4, l1_gate_sf, l1_gate_gs = [], [], [] - l1_up_fp4, l1_up_sf = [], [] - l2_fp4, l2_sf, l2_gs = [], [], [] - - for proj, fp4_l, sf_l, gs_l in [ - ('gate_proj', l1_gate_fp4, l1_gate_sf, l1_gate_gs), - ('up_proj', l1_up_fp4, l1_up_sf, None), - ('down_proj', l2_fp4, l2_sf, l2_gs), - ]: - w_k = f"{pfx}.shared_experts.{proj}.weight" - ws_k = f"{pfx}.shared_experts.{proj}.weight_scale" - isc_k = f"{pfx}.shared_experts.{proj}.input_scale" - w, ws, isc = all_w.get(w_k), all_w.get(ws_k), all_w.get(isc_k) - if w is not None and ws is not None: - fp4_l.append(w.to(dev)) - sf_l.append(ws.to(dev)) - if gs_l is not None: - gs_l.append(isc.float().item() if isc is not None else 1.0 / (6.0 * 448.0)) - - if l1_gate_fp4 and l1_up_fp4: - se.l1_fp4 = [torch.cat([l1_gate_fp4[0], l1_up_fp4[0]], dim=0)] - se.l1_sf = [torch.cat([l1_gate_sf[0], l1_up_sf[0]], dim=0)] - se.l1_gs = l1_gate_gs if l1_gate_gs else [1.0 / (6.0 * 448.0)] - if l2_fp4: - se.l2_fp4 = l2_fp4; se.l2_sf = l2_sf - se.l2_gs = l2_gs if l2_gs else [1.0 / (6.0 * 448.0)] - se.finalize_weights() - - -def _cache_layer_weights_no_experts(all_w, n_layers, devices): - """Cache per-layer weights to GPUs, EXCLUDING MoE expert weights. - - MoE expert weights (model.layers.{li}.mlp.experts.*) are handled by - Nvfp4MoE runners with stacked tensors. Shared expert weights are handled - by Nvfp4SharedExpert runners. Including them here would double-load - ~10.6GB/layer of FP4 expert weights. - """ - cached = {} - for li in range(n_layers): - dev = devices[li % len(devices)] - pfx = f"model.layers.{li}." - w = {k: v.to(device=dev, non_blocking=True) - for k, v in all_w.items() - if k.startswith(pfx) and '.experts.' not in k and '.shared_experts.' not in k} - cached[li] = w - if (li+1) % 10 == 0: print(f" Cached {li+1}/{n_layers} layers") - return cached - -def load_weights(checkpoint_dir): - from safetensors.torch import load_file - cdir = Path(checkpoint_dir) - wmap = {} - idx = cdir / "model.safetensors.index.json" - if idx.exists(): - with open(idx) as f: wmap = json.load(f).get("weight_map", {}) - shards = set(wmap.values()) if wmap else set() - all_w = {} - for sn in sorted(shards): - if (cdir / sn).exists(): - all_w.update(load_file(str(cdir / sn))) - return all_w - if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/tests/unit/test_fmha_sink_bias.py b/tests/unit/test_fmha_sink_bias.py new file mode 100644 index 00000000..ec0be9ce --- /dev/null +++ b/tests/unit/test_fmha_sink_bias.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +"""Test FMHA kernel with attention sink bias. + +Validates that the kernel's sink bias correction matches PyTorch reference: + softmax([QK^T * scale, sink_bias])[:N] @ V + +Tests HD=64,128,256,512 with and without sinks. +""" +import torch +import math +import sys + +def reference_fmha_with_sink(q, k, v, scale, sink_bias=None): + """PyTorch reference: softmax([QK^T * scale, sink_bias]) @ V. + + q: (n_h, T, hd), k: (1, N, hd), v: (1, N, hd) + sink_bias: (n_h,) FP32 or None + Returns: (n_h, T, hd) BF16 + """ + n_h, T, hd = q.shape + N = k.shape[1] + # QK^T: (n_h, T, N) + scores = torch.matmul(q, k.transpose(-1, -2)) * scale # (n_h, T, N) + + if sink_bias is not None: + # Concatenate sink as extra column: (n_h, T, N+1) + sb = sink_bias.reshape(n_h, 1, 1).expand(-1, T, 1) + combined = torch.cat([scores, sb], dim=-1) + attn = torch.softmax(combined.float(), dim=-1)[:, :, :N] # drop sink column + else: + attn = torch.softmax(scores.float(), dim=-1) + + out = torch.matmul(attn.bfloat16(), v) # (n_h, T, hd) + return out + +def test_fmha_sink(): + from dsv4.kernels.attention.production import dsv4_attention + + torch.manual_seed(42) + device = 'cuda' + passed = 0 + failed = 0 + + for hd in [64, 128, 256, 512]: + for N in [9, 32, 128, 256]: + for use_sink in [False, True]: + n_h = 4 # small for speed + T = 1 + scale = 1.0 / math.sqrt(hd) + + q = torch.randn(n_h, T, hd, dtype=torch.bfloat16, device=device) + k = torch.randn(1, N, hd, dtype=torch.bfloat16, device=device) + v = torch.randn(1, N, hd, dtype=torch.bfloat16, device=device) + sink = torch.randn(n_h, dtype=torch.float32, device=device) * 2 if use_sink else None + + # Production kernel + try: + o_kernel = dsv4_attention(q, k, v, scale=scale, sink_bias=sink) + except Exception as e: + print(f" FAIL hd={hd} N={N} sink={use_sink}: kernel error: {e}") + failed += 1 + continue + + # PyTorch reference + o_ref = reference_fmha_with_sink(q, k, v, scale, sink) + + # Compare + o_kf = o_kernel.float() + o_rf = o_ref.float() + cos = torch.nn.functional.cosine_similarity(o_kf.flatten().unsqueeze(0), + o_rf.flatten().unsqueeze(0)).item() + max_diff = (o_kf - o_rf).abs().max().item() + + status = "PASS" if cos > 0.999 else "FAIL" + if status == "PASS": + passed += 1 + else: + failed += 1 + print(f" {status} hd={hd} N={N} sink={use_sink} cos={cos:.6f} max_diff={max_diff:.6f}") + + print(f"\n{'='*60}") + print(f"Results: {passed} PASSED, {failed} FAILED") + print(f"{'='*60}") + return failed == 0 + +if __name__ == "__main__": + success = test_fmha_sink() + sys.exit(0 if success else 1)