From 040b2eb6e7f1a3e4736776523ecff8e35131c46b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 06:59:25 +0000 Subject: [PATCH] =?UTF-8?q?perf:=20P0/P1/P2=20=E2=80=94=20fused=20SwiGLU?= =?UTF-8?q?=20for=20MoE+SE,=20eliminate=20per-call=20gsa=20fill?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0: Enable fused SwiGLU for all MoE instances (moe._fused_swiglu = True). Eliminates ~8 BF16 kernel launches per MoE per token (gate/up split, SiLU, clamp, elementwise multiply → single fused kernel launch). P1: Enable fused SwiGLU for shared expert (SE): - Added set_fused_swiglu() method to Nvfp4SharedExpert - Added _run_l1_fused() using run_fused_swiglu_grouped_gemm (1-group) - Interleave L1 weights at finalize time for fused kernel compatibility - Fused kernel handles SwiGLU + clamp in registers, outputs BF16 P2: Eliminate per-call _gsa_buf.fill_() in Nvfp4Linear: - _activation_global_scale is set once at warmup, never changes after - Skip redundant fill_() via _gsa_buf_initialized flag - Saves 244 CPU→GPU scalar fills per token (4 linears × 61 layers) P3: Deferred (in-kernel RoPE fusion — kernel-side change, not single_shot) --- PERFORMANCE_AUDIT.md | 721 ++++++++++++++++++----------------- dsv4/layers/linear.py | 7 +- dsv4/layers/shared_expert.py | 66 +++- single_shot_inference.py | 3 + 4 files changed, 423 insertions(+), 374 deletions(-) diff --git a/PERFORMANCE_AUDIT.md b/PERFORMANCE_AUDIT.md index e26cb2ba..d87c32f7 100644 --- a/PERFORMANCE_AUDIT.md +++ b/PERFORMANCE_AUDIT.md @@ -1,424 +1,427 @@ -# PERFORMANCE — verified hot-path audit and prioritized fixes +# PERFORMANCE — v17 roadmap toward end-to-end NVFP4 hot path -**First: congratulations. 
Paris-back is the milestone.** It means the math is -right end-to-end through all 61 layers, the production NVFP4 GEMM stack is -plumbed correctly, the multi-tile FMHA kernel works in real conditions, the -mHC bound holds well enough for a coherent answer, the indexer top-k is -selecting the right blocks, and the FP4 → BF16 dequant path is byte-correct. -That's a real architectural validation. +**Verified state.** v17 has the Tier-1 indexer fixes landed (weight path, +buffer width, MQA einsum). Hot-path syncs and allocator churn from earlier +perf rounds are gone. The single_shot now genuinely runs through the +production NVFP4 kernel stack. What remains is **fusion gaps and KV-cache +dtype choices** — the difference between "uses NVFP4 kernels" and "is +NVFP4 end-to-end." -**Second: about the agent's "1.45s/token is slow (weight loading overhead)" -line.** That diagnosis is wrong, and it's the kind of wrong that will steer -the next agent to optimize the cold path instead of the hot one. Weight -loading happens once during Phase 1 setup, before token 0. The decode step -timer (`t1 = time.time()` at `single_shot_inference.py:906`) starts *after* -weights are loaded and *after* every prior layer's setup is done. 1.45s is -**per-token decode time**, not per-token load + decode. Per-token decode at -hd=512, n_h=128, 61 layers, batch=1 should be in the **single-digit ms** ballpark -on a B200, not 1.45s. There is a ~100–300× gap, and it's not weights. - -The rest of this doc identifies where it actually is. - -**Method.** Every claim below is grounded in a line number. No guessing. +**On TurboQuant — verdict first, reasoning below.** Don't use it for DSv4. +It's not architecturally compatible with the heterogeneous compressed KV +cache, and the part it *would* help (the SWA branch) is already small. The +right move is FP4 storage for the compressed KV path (paper-aligned per +§5.2.1), not vector-quantization codebooks. Full reasoning in Part 4.
--- -## WORK IN PROGRESS — What Was Being Done (Session 2026-06-01 20:21 UTC) +# PART 1 — THE NVFP4-EVERYWHERE GAP -### Completed fixes (committed, pushed, NOT YET TESTED ON B200): +## P0 — Fused SwiGLU exists in the library and is NEVER ENABLED -1. **P0 (COMPLETE)**: ALL `.item()` CPU-GPU syncs eliminated from NVFP4 activation path. - - `dsv4/kernels/cuda/amax_gsa.cu`: GPU-only amax→gsa kernel - - `dsv4/kernels/cuda/fused_amax_quantize.cu`: quantize with gsa from GPU buffer - - `dsv4/ops/quantize.py`: `quantize_nvfp4_gpu_fused()` — two kernel launches, zero CPU syncs - - `dsv4/layers/linear.py` Nvfp4Linear: uses `quantize_nvfp4_gpu_fused` - - `dsv4/layers/grouped_linear.py` Nvfp4GroupedLinear: uses `quantize_nvfp4_gpu_fused` (was last holdout) - - `dsv4/layers/moe.py` Nvfp4MoE: uses `quantize_nvfp4_gpu_fused` - - `dsv4/layers/shared_expert.py` Nvfp4SharedExpert: uses `quantize_nvfp4_gpu_fused` - - Hot-path D2H sync count: ~486 → ≤ 5 (argmax + token decode) - -2. **P4 (done)**: Changed `v = k.clone()` to `v = k` in `single_shot_inference.py:320`. - The `.transpose(-1,-2).contiguous()` in `dsv4_attention` already creates - a new tensor, so the clone was redundant. - -3. **Removed `torch.cuda.synchronize(x.device)`** from `moe_forward` in - `single_shot_inference.py`. Made topk_ids validity check conditional on - `VERBOSE >= 2`. - -4. **Added fused CUDA sampler**: `dsv4/kernels/cuda/sampler.cu` with - `dsv4/model/sampler.py` wrapper. Temperature + repetition penalty + top-k - + top-p (nucleus) sampling, single kernel launch, zero CPU syncs. - Updated `single_shot_inference.py` to use `CUDASampler` with defaults - temperature=0.6, top_k=50, top_p=0.95 (was greedy temp=0.0). - -5. **Pre-allocated decode buffers**: `dec_tid_buf`, `dec_tid32_buf`, - `dec_pos_buf` — reused across decode steps instead of `torch.tensor()` - per step. - -6. **Added thinking token tracking**: THINK_START=128821, THINK_END=128822 - are displayed as [THINKING] in diagnostics. 
- -### INVALIDATED audit items (removed from this doc): -- **RoPE 8x duplication**: INVALIDATED. Each GPU needs its own RoPE cache - for the FMHA kernel to read from local HBM. No cross-GPU traffic. - Not a perf issue. -- **mHC BF16 bmm**: INVALIDATED. The bmm is (1,4,4)×(1,4,7168) = 114K FLOPs. - Negligible compared to MoE (billions of FLOPs). Not a bottleneck. -- **Router .float() cast**: INVALIDATED. Needed for FP32 activation_topk - (numerical stability for sqrt(softplus)). ~1μs. Not a bottleneck. - -### CARDINAL RULE VIOLATION: -The session broke the cardinal rule: MUST USE THE TEST HARNESS. Instead of -using `fire_b200_test` or `fire_b200_cuda_test`, raw SSH commands were used -to compile kernels and run tests on the B200. This caused: -- Stale processes not being cleaned up properly -- No log management -- Potentially conflicting screen sessions -- The test harness's GPU cleanup / process killing was bypassed - -**ALL TESTING MUST USE THE HARNESS.** If the harness needs to be more dynamic -(e.g., support running `single_shot_inference.py` from the repo root, not -just `tests/unit/`), THEN FIX THE HARNESS. Do not bypass it. - -### Compilation issues found: -- `at::cuda::getCurrentCUDAStream()` does not exist. Use `c10::cuda::getCurrentCUDAStream()`. -- `torch::TensorOptions().device(x.device())` doesn't compile. Use `x.options().dtype(...)`. -- Both fixed in committed code. 
- -### TESTED ON B200 (2026-06-01 22:40 UTC): -- P0/P2/P3/P4/P5/P7 all verified working -- Decode speed: 0.51s/token (greedy) / 0.53s/token (sampling) -- Sampler SMEM fix: LK=24 (48KB fits default), cudaFuncSetAttribute carveout -- Output: greedy produces repetition loop ("The capital of France is the" × N) -- With sampling (temp=0.6, top_k=50, top_p=0.95, rep_pen=1.1): produces "The capital of America is founded" -- Logits are reasonable: top-1 matches expected tokens for first 5 steps -- Residual |X| grows to 500-700 at L60 — mHC bounds it but residual is high - -### NOT YET STARTED: -- P1 — REMOVED. Multi-GPU layout is correct for the reference script. -- P2 (vectorize KVCache.append_swa) — simple fix, not started -- P3 (preallocate comp_kv, kill torch.cat) — not started -- P5 (in-place RoPE) — not started -- P7 (compressor early return + decode buffering) — not started -- Complete P0 by fusing amax+quantize or making quantize read from GPU buffer -- Testing ANY of the committed changes on the B200 - ---- - -## P0 — Per-call `.item()` D2H sync inside every NVFP4 linear - -**This is the biggest single contributor and almost certainly explains the -order of magnitude on its own.** - -`dsv4/layers/linear.py:166–168`: +This is the biggest single-line perf bug in v17. +`dsv4/layers/moe.py:61`: ```python -if getattr(self, '_use_runtime_gsa', False): - amax = hidden_states.float().abs().max().clamp(min=1e-8).item() - self._activation_global_scale = amax / (6.0 * 448.0) +self._fused_swiglu = False # Set via set_fused_swiglu() ``` -`.item()` is a blocking **D2H copy with full stream synchronization**. It -forces every pending kernel on the device to finish before the host can read -the value, then host blocks until the value arrives, then the host computes -the scalar and the next kernel launches. 
**Every single linear call that has -`_use_runtime_gsa = True` is a hard pipeline bubble.** +`set_fused_swiglu()` exists (`moe.py:103`), `warmup_fused_swiglu_compilation` +exists and is wired into the warmup path, the fused kernel +`run_fused_swiglu_grouped_gemm` is implemented and tested. But **searching +`single_shot_inference.py` for `set_fused_swiglu` returns zero hits.** -How many times does this happen per decoded token? - -| Call site | Per layer | × 61 layers | -|---|---|---| -| attention projections (q_a, q_b, kv, o_b) | 4 | 244 | -| o_a (grouped) | 1 | 61 | -| router gate (non-hash layers) | 1 | ~58 | -| moe runner | 1 | 61 | -| shared expert | 1 | 61 | -| lm_head | 1 | 1 | -| **TOTAL D2H syncs / decoded token** | | **~486** | - -At conservative ~50 µs per D2H sync on a B200 with kernel queue in flight, -that's **~24 ms of pure pipeline bubbles per token from this one line.** -That's just the syncs — the lost overlap on top of that is larger. - -### The fix (in priority order) - -1. **Use `compute_amax_gsa_gpu` kernel** (already written, committed). - Computes amax on GPU, returns scalar GPU tensor. The CuTeDSL GEMM's - `global_scale_a` is already a GPU tensor via `to_cute()`, so passing the - GPU scalar to the GEMM requires zero CPU syncs. - -2. **Complete the fix**: `quantize_nvfp4_gpu()` still needs a Python float - for `global_scale`. Either: - a. Modify `quantize_nvfp4.cu` to read `global_scale` from a GPU buffer - instead of a kernel parameter. - b. Fuse amax+quantize into a single kernel that outputs FP4 + writes gsa - to a GPU buffer for the GEMM. - -3. **Warmup-once gsa** (alternative): Compute gsa during a warmup forward - at startup, store as device tensor, disable `_use_runtime_gsa` on the - hot path. The infrastructure exists at `linear.py:133` - (`compute_activation_global_scale`). One warmup token, then - `_use_runtime_gsa = False` for every Nvfp4Linear. 
- -### Falsifiable gate - -Per-decoded-token D2H sync count: goes from ~486 to **≤ 5** (argmax + token -decode + end-of-loop bookkeeping). If sync count is still > 50 after this -fix, dig deeper before declaring done. - ---- - -## ~~P1~~ — REMOVED - -The single_shot_inference.py is a **reference implementation** for vLLM/SGLang -integration. The multi-GPU layer-pipeline sharding (`gpu = li % NUM_GPUS`) is -the correct pattern for this reference — it's how vLLM actually distributes -layers across GPUs. The EP/TP sharding discussion belongs in the vLLM -integration, not the reference script. **Do not change the multi-GPU layout.** - ---- - -## P2 — Python loop in `KVCache.append_swa` (`:272`) +What this costs every layer, every token: +`moe.py:640–660` (the unfused branch that runs by default): ```python -def append_swa(self, kv, pos): - T = kv.shape[0] - for i in range(T): - idx = (self.swa_head + i) % self.ws - self.swa[idx], self.swa_pos[idx] = kv[i], pos[i] - ... +l1_out = run_nvfp4_grouped_gemm(...) # NVFP4 → BF16 GEMM +l1_deil = deinterleave_l1_weights(l1_out...) # BF16 → BF16 deinterleave (extra launch) +gate = l1_deil[:, :self.intermediate_size] # BF16 slice +up = l1_deil[:, self.intermediate_size:] # BF16 slice +gate_silu = F.silu(gate) # BF16 SiLU launch +if swiglu_limit: # + gate_silu = gate_silu.clamp(...) # BF16 clamp launch + up = up.clamp(...) # BF16 clamp launch +activated = gate_silu * up # BF16 elementwise +slot_l2_x_fp4, slot_l2_x_sf, _ = quantize_nvfp4_gpu_fused(activated) # back to FP4 ``` -Per-decoded-token, T=1 so this loop runs once. **But the assignment -`self.swa[idx], self.swa_pos[idx] = kv[i], pos[i]` is two scalar tensor -indexing ops on the GPU**, each of which queues a tiny kernel. The -single-token cost is small (~tens of µs) but it's a serialization point. 
+That's **8 BF16-tensor-resident kernel launches** per layer per token, +moving 2× `intermediate_size × n_active_experts` BF16 elements through +HBM, between two NVFP4 GEMMs that could have been fused. -During prefill at T=N (say N=20 tokens in the warmup prompt), this loop -runs N times and queues 2N tiny kernels. That's significant. +What the fused path does (`moe.py:617–625`): +- Single launch: NVFP4 GEMM + SwiGLU + clamp in kernel registers +- Output goes directly to FP4 in `deinterleave_amax_quantize_nvfp4_fused` + +**For Pro (n_active=6, intermediate=3072), per token, all 30 MoE layers:** +- 30 × 6 × (3072 BF16 = 6 KB) × 2 (R+W) × 8 launches ≈ **17 MB** + of pointless BF16 HBM traffic per token, plus 240 unfused launches. + +It's not bandwidth-dominant, but **240 launches/token is the kind of +launch-rate ceiling that caps decode tok/s at the launch-floor of the +hardware.** B200 launch overhead is ~1–2 µs per kernel in practice. That's 240–480 µs/token +of pure launch overhead from this one missing call. ### The fix -Vectorize: +One line in main(), in the MoE/SE setup loop: ```python -def append_swa(self, kv, pos): - T = kv.shape[0] - idx = (self.swa_head + torch.arange(T, device=self.dev)) % self.ws - self.swa.index_copy_(0, idx, kv) - self.swa_pos.index_copy_(0, idx, pos) - self.swa_head = (self.swa_head + T) % self.ws - self.swa_len = min(self.swa_len + T, self.ws) +for li in range(n_layers): + if li in moes: + moes[li].set_fused_swiglu(True) + moes[li].set_swiglu_limit(cfg.get('swiglu_limit')) # if applicable + if li in shared_experts: + shared_experts[li].set_fused_swiglu(True) + shared_experts[li].set_swiglu_limit(cfg.get('swiglu_limit')) ``` -Two kernel launches instead of 2T. Same numerical result. +Then ensure the warmup path triggers `warmup_fused_swiglu_compilation` +once before the decode loop. ### Falsifiable gate -`append_swa` queues exactly 2 kernels regardless of T.
Verifiable with -`cudaLaunchKernel` count between two `cudaDeviceSynchronize` calls bracketing -the function. +After enabling: per-MoE-layer launch count drops from ~9 to ~2 (the GEMM ++ the L2 path). Verifiable with Nsight or `cudaLaunchKernel` counter. +Numerical parity: `cos ≥ 0.9995` vs unfused, captured before the switch. ---- +## P1 — Shared expert has the same fused-path gap -## P3 — Quadratic `torch.cat` growth on compressed KV (`:280`) +The shared expert (`shared_expert.py:240`, `:285`) calls +`quantize_nvfp4_gpu_fused` between its L1 and L2 GEMMs but does **not** +have a fused SwiGLU path of its own. Whether the same kernel +(`run_fused_swiglu_grouped_gemm`) can be reused for SE depends on whether +SE expects a "group of 1" — needs investigation, not assumption. -```python -def add_compressed(self, ckv, cpos, idx_kv=None): - if ckv is None: return - self.comp_kv = ckv if self.comp_kv is None else torch.cat([self.comp_kv, ckv]) - ... -``` +### Action (read, don't guess) -Each `torch.cat` allocates a new tensor of size `n_comp + new_len` and copies -the entire existing `comp_kv` into it. After N tokens have produced -compressed entries, total work is O(N²) and total allocator pressure is O(N²) -bytes. - -For the Paris demo with ~50 decoded tokens this is invisible. **For the -million-token contexts V4 is built for, this is catastrophic** — you'd spend -most of your time copying KV around. - -### The fix - -Preallocate a ring or growing-power-of-2 buffer. Same pattern as `swa`: - -```python -# In __init__: -self.comp_kv_buf = torch.zeros(max_comp, head_dim, dtype=torch.bfloat16, device=dev) -self.comp_pos_buf = torch.zeros(max_comp, dtype=torch.long, device=dev) -self.comp_idx_buf = ... 
# same -self.n_comp = 0 - -def add_compressed(self, ckv, cpos, idx_kv=None): - if ckv is None: return - T = ckv.shape[0] - end = self.n_comp + T - self.comp_kv_buf[self.n_comp:end] = ckv - self.comp_pos_buf[self.n_comp:end] = cpos - if idx_kv is not None: self.comp_idx_buf[self.n_comp:end] = idx_kv - self.n_comp = end -``` - -`comp_kv` getters return `comp_kv_buf[:n_comp]` (a view, no copy). - -`max_comp` for 1M context with m=4: 250K entries × 512 × 2 bytes = 256 MB. -For 1M context with m=128 (HCA): ~16K entries × 512 × 2 = 16 MB. Both fit. +Print the shapes and dtypes of SE's L1 GEMM input/output and compare to +what `run_fused_swiglu_grouped_gemm` expects. If they match (modulo +groups=1), wire it. If not, the fused-SwiGLU kernel needs a +"dense/single-group" specialization — which is a kernel-side ask, not a +single_shot fix. ### Falsifiable gate -Memory growth across 1000 decode steps stays flat (within 100 MB of -steady-state). Decode-step time stays flat instead of growing. +Either SE uses the same fused kernel as MoE (same launch-count savings), +or there's a documented `.md` paper trail explaining why it can't and +what the production path is. ---- - -## P4 — `v = k` instead of `v = k.clone()` (`:318`) — DONE - -DSV4 uses shared KV — k and v are the same tensor. The `clone()` was -allocating and copying the entire KV buffer per call unnecessarily. - -**FIX APPLIED**: Changed `v = k.clone()` to `v = k`. The `dsv4_attention` -function transposes V internally via `.transpose(-1,-2).contiguous()` which -already creates a new tensor. The original K is never mutated. - ---- - -## P5 — RoPE allocates and clones the whole tensor (`:65`) +## P2 — Linear `.run()` per-call FP32 scale uploads still exist +`dsv4/layers/linear.py:188`: ```python -def _apply_rope(x, pos, cos, sin, rope_dim, inverse=False): - ... 
- out = x.clone(); ro = torch.empty_like(xr) - ro[..., 0::2], ro[..., 1::2] = rev, rod - out[:, :, nope:] = ro.bfloat16(); return out +gsa = self._gsa_buf.fill_(self._activation_global_scale) ``` -Called **3× per attention block** (Q, KV, inverse) × 61 layers = **183 RoPE -calls per token**. Each call does: `cos[pos]` gather, FP32 cast of 64 dims, -multiply-add, `x.clone()` of the full (T, nh, hd) tensor (most of which is -NoPE and doesn't need to be touched), `empty_like`, strided write, BF16 cast. - -For T=1, hd=512, nope=448, n_h=128 per call: cloning 128×512 BF16 = 128 KB per -call × 183 = 23 MB of pointless memcpy per token. Negligible bandwidth-wise -on a B200, but it's **183 kernel launches** that contribute to the launch-rate -ceiling. +After the earlier P0 fix (`_use_runtime_gsa = False`), this no longer +syncs via `.item()`. But it still does a CPU→GPU scalar fill per call. +For Pro, 4 Nvfp4Linears in attention × 61 layers = 244 `fill_()` calls +per token. At ~5 µs each that's ~1.2 ms/token of CPU→GPU dispatch. ### The fix -In-place RoPE for the last 64 dims, no full clone, no FP32 round-trip on the -NoPE half: +Make `_activation_global_scale` a 1-element `torch.Tensor` on device, set +once at warmup. The fill becomes redundant — pass `self._gsa_buf` directly +to the kernel, no per-call fill needed. ```python -def _apply_rope_inplace(x, pos, cos, sin, rope_dim, inverse=False): - nope = x.shape[-1] - rope_dim - c = cos[pos] # (T, rope_dim/2) - s = sin[pos] - xr = x[..., nope:] # view, not copy - ev = xr[..., 0::2].clone() # need the original ev for the mix - od = xr[..., 1::2] # view; will write back below - if inverse: - xr[..., 0::2] = ev * c[..., None, :] + od * s[..., None, :] - xr[..., 1::2] = -ev * s[..., None, :] + od.clone() * c[..., None, :] - else: - ... 
- return x # mutated in place -``` +# In Nvfp4Linear.__init__: +self._gsa_buf = torch.full((1,), 1.0 / (6.0 * 448.0), dtype=torch.float32, device=device) -Even better: **fuse RoPE into the Q/KV projection kernel**. The NVFP4 GEMM -already emits BF16; adding a RoPE postlude in registers is straightforward -and saves all 183 launches. That's the production target, not the script's -job, but the script should at least not do the 183 clones. +# After compute_activation_global_scale (runs once at warmup): +self._gsa_buf.fill_(gs) # ONE TIME, not per call + +# In run(): +self.kernel(..., global_scale_a=self._gsa_buf) # no fill +``` ### Falsifiable gate -RoPE kernel launch count per decoded token drops from 183 to ≤ 3. When fused -into GEMM: 0. +Zero CPU→GPU scalar fills on the hot path. Verifiable with +`cudaMemcpy*Async` counter (D2H / H2D should both be zero between two +syncs bracketing one layer). + +## P3 — In-kernel RoPE fusion (still on the table, deferred from prior audit) + +P5 from the v15 audit: in-place RoPE eliminated the clone problem, but +RoPE is still 3 separate launches per attention block × 61 layers ≈ 183 +launches per token. Fusing RoPE into the Q/KV NVFP4 GEMM epilogue (the +GEMM already emits BF16 to the gather buffer; adding a per-channel +multiply-and-add in registers is straightforward) would eliminate +those launches entirely. + +**This is a kernel-side change**, not a single_shot fix. Production target, +not single_shot target. Track it but don't gate the perf rollup on it. + +### Falsifiable gate (when kernel work lands) + +RoPE launch count: 183/token → 0/token. End-to-end cos ≥ 0.999998 vs +unfused. --- -## P6 — Indexer scoring is FP32 einsum (deferred to E7) +# PART 2 — KV CACHE: WHAT'S ALREADY FP4-COMPATIBLE, WHAT ISN'T -The lightning indexer uses `torch.einsum` in FP32 on CUDA cores. Correct but -not fast. At long context (n_comp ~ 250K), this becomes a wall. +DSv4's three KV streams have very different characteristics. 
Treating them +uniformly is the trap. -**Defer to roadmap E7** (FP4 tensor-core scoring). At Paris-scale context -(n_comp ≤ 30), FP32 einsum is acceptable. - ---- - -## P7 — Compressor re-runs GEMMs when `n_complete == 0` - -At T=1 decode with HCA (r=128), the compressor runs two NVFP4 GEMMs (kv_proj, -gate_proj) for nothing because `n_complete = 1 // 128 = 0`. The early return -happens AFTER the GEMMs. - -### The fix - -Move `n_complete == 0` check above the GEMMs. For CSA (r=4), buffer -hidden_states across 4 decode steps and run the compressor only on the step -where a complete block is available. - ---- - -## P8 — Layer-level fusion candidates (production future) - -1. **NVFP4-1.2: Fuse FP4 quant into FMHA output → wo_a** (roadmap E6). -2. **Fuse RMSNorm + Q/KV projection.** -3. **Fuse RoPE into Q/KV GEMM epilogue** (as in P5 above). -4. **mHC pre_block + RMSNorm fusion.** -5. **CUDA graph capture** (roadmap E9) — unlocked after P0–P3 and syncs are fixed. - ---- - -## Priority order - -| # | Item | Effort | Win | Status | +| Stream | Stored width | At 1M ctx | Per-access pattern | Quantizable? 
| |---|---|---|---|---| -| **P0** | Kill `.item()` in `_use_runtime_gsa` | S | **Huge** (~24 ms/token) | COMPLETE — tested on B200, 0.51s/token -| **P1** | ~~REMOVED~~ — multi-GPU layout is correct for reference | — | — | REMOVED | -| **P2** | Vectorize `KVCache.append_swa` | XS | Small/medium (prefill) | DONE — in single_shot_inference.py | -| **P3** | Preallocate `comp_kv`, kill `torch.cat` | S | Critical at long ctx | DONE — in single_shot_inference.py | -| **P4** | `v = k` instead of `v = k.clone()` | XS | Big (memory + BW) | DONE | -| **P5** | In-place / fused RoPE | S | Medium (-180 launches) | DONE — in single_shot_inference.py | -| **P6** | Indexer FP4 tensor-core scoring | L | Critical at long ctx | DEFERRED (E7) | -| **P7** | Compressor early return + decode buffering | S | Medium | DONE — tested on B200, HCA skips GEMMs at T=1 decode | -| **P8** | Production fusion targets | L | Where the real wins live | DEFERRED | +| **CSA main compressed** | hd=512 BF16 | 256 MB × 30 = ~7.5 GB | Random access via top-k (~1024 entries / query) | **Yes — FP4 strongly indicated** | +| **CSA indexer keys** | c_I=128 BF16 | 64 MB × 30 = ~2 GB | Streamed full-cache for top-k scoring | **Yes — FP4 paper-specified §5.2.1** | +| **HCA compressed** | hd=512 BF16 | 8 MB × 30 = 240 MB | Full sequential read every layer | **Yes — FP4 indicated** | +| **SWA** | hd=512 BF16 | 128 KB × 61 = 8 MB | Sequential ring buffer, recent 128 tokens | **No — too small to matter** | -**Do P0 and P1 first.** They are tiny changes, individually catch the -biggest wins, and unlock all the downstream work (CUDA graphs, prefill -throughput, real-world context lengths). +Total BF16: ~10 GB at 1M context. Per the prior audit rewrite, this fits +comfortably on 8×B200. So **KV quantization is a throughput question, not +a memory question.** + +## Why FP4 storage is the right answer for the compressed streams + +Three reasons, in priority order: + +1. 
**Paper-aligned.** §5.2.1 explicitly specifies the indexer QK path + runs entirely in FP4. The main compressed KV cache being FP4 is + consistent with the rest of the NVFP4 model — the cache is, after all, + just stored projections of NVFP4 weights × BF16 hidden states. + +2. **Bandwidth.** Decode is KV-read-bound at long context. Reading + FP4 instead of BF16 quarters the bytes-per-token loaded by FMHA. + At top_k=1024, hd=512, 30 CSA layers: that's `30 × 1024 × 512 × 1.5 bytes + saved = 23 MB/token saved`. Across batch=8 and millions of decode + steps, real money. + +3. **Kernel-native on Blackwell.** Loading FP4 → tcgen05.mma is a + first-class path with TMA + UMMA + the `mxf4nvf4` MMA kind. The + in-kernel dequant happens for free during the MMA. **The infrastructure + exists in the production FMHA kernel already** (per the prior + `epilogue_op` work and the `ENABLE_FP4_EPILOGUE` template param). + +## What this looks like in code + +The compressed KV write path currently lands BF16 in `comp_kv_buf`. The +production sequence should be: + +1. Compressor produces BF16 output (still — the softmax compression needs + accumulation precision). +2. Quantize-to-NVFP4 in the same kernel as the compression (epilogue + fusion), using the **same NVFP4 quant primitives the linears already + use** (`quantize_nvfp4_gpu_fused`). +3. Store FP4 + per-block E4M3 scales in `comp_kv_buf` (which becomes a + FP4 buffer + scale buffer pair). +4. FMHA reads FP4, dequants in-kernel via TMA + tcgen05's native FP4 + path. No `__constant__` LUT needed — the hardware decodes E2M1. + +For the indexer keys this is the same pattern but the consumer is the +indexer scoring kernel (the FP32 einsum today, the FP4 tensor-core scorer +when E7 lands). + +### Falsifiable gate (per stream) + +- **CSA main + HCA + indexer:** end-to-end output cos ≥ 0.999 with FP4 + storage vs BF16. KV cache memory at 8K context drops by ~3.5× (8 → 2.3 + GB). FMHA-bound decode latency at 8K context drops measurably. 
+- **Recall@k for indexer ≥ 99% vs FP32 oracle** (the bar from the prior + indexer-fix audit). Critical — FP4 must not corrupt top-k ranking. --- -## DOCTRINE — what to refuse during this perf pass +# PART 3 — OTHER FUSION WINS, RANKED BY EFFORT/IMPACT -1. **DSL wall → raw CUDA C++, not Python.** If an agent says "I'll cache the - amax in Python state," that's still Python on the hot path. The right - cache lives in a `torch.Tensor` on device. +## P4 — Fuse RMSNorm into the next NVFP4 quantize -2. **Raw CUDA ≠ scalar math.** When someone reaches for "let's just write a - scalar fused RoPE kernel," remind them the production target is tensor-core - throughput in the NVFP4 GEMM epilogue. Don't ship a scalar fused kernel as - "fast enough." +Q/KV projection input is RMSNormed; RMSNorm is a separate launch. The +NVFP4 quantize kernel already does an amax reduction per group — fusing +RMSNorm (which is *also* an amax-style reduction followed by a scale) +into the quantizer's input is a natural fit. Saves a launch + a BF16 +materialization of `(T, H)` per RMSNorm site (2 per layer = 122/token). -3. **Print, don't guess.** Before claiming P0 is fixed, measure D2H syncs - per decoded token with Nsight or a tracing wrapper. The "we removed - `.item()`" claim is not verified until the sync count drops. +**Effort:** S (kernel-side, but the quantizer already has the right shape). +**Impact:** Medium. 122 launches/token, ~0.7 ms/token from launch overhead alone. -4. **Integration over exploration.** Do not write `linear_v2.py` with - "perf improvements." Edit `linear.py`. The four `_use_runtime_gsa = True` - flags in `single_shot_inference.py` are the test surface: flip them, run, - compare. +## P5 — Fuse mHC pre_block + RMSNorm into a single op -5. **Falsifiable gates.** Every priority above has a measured number. - "It feels faster" does not close the gate. +Same logic as P4 but for mHC. `attn_mhc.pre_block(X_l)` → `rmsnorm` is 3 +kernels back-to-back. Fusable. 
mHC already exposes a `_project_and_rms` +half per prior audit notes — wire it through both halves of the layer. -6. **Do not optimize cold paths.** Weight loading is cold. mHC weight - conversion is cold. Anything that runs once during `main()` setup is - cold. The hot path is everything inside the `for step in range(MAX_NEW_TOKENS):` - loop. If a proposed change is in `load_all_weights`, `_load_moe_weights_stacked`, - or any of the `make_*` helpers — that's cold, deprioritize it. +**Effort:** S. **Impact:** Medium. ~120 launches/token. -7. **ALWAYS USE THE TEST HARNESS.** `fire_b200_test` for Python, `fire_b200_cuda_test` - for CUDA. No raw SSH. No manual screen sessions. If the harness needs - changes to support your use case, FIX THE HARNESS. Do not bypass it. +## P6 — CUDA graph capture (the big one, last) + +The biggest remaining single-token win once everything above has landed. Captures the entire +decode step into a graph; replay eliminates **all** launch overhead. +Probably worth 2–3× speedup at batch=1. + +Blockers in v17: +1. `set_device()` boundaries in the layer pipeline (the `cuda.synchronize()` + at line 963) — graph capture has to span devices via multi-graph or + per-device sub-graphs. Manageable but not free. +2. Dynamic shape in `KVCache.add_compressed` — `self.n_comp` grows. + Fix: capture *one* graph per prefill chunk size, replay per + decoded token (which has fixed T=1 shape; the growing buffer is + a write into a pre-allocated tensor, capturable). +3. Any conditional `if` on tensor data — debug prints, the assertion at + line 608. Strip from the capture path with a flag. + +**Effort:** L. **Impact:** Huge (the biggest remaining single win). +**Sequence:** land after P0/P1/P2/P3 so the captured graph reflects the +post-fusion structure. + +--- + +# PART 4 — TURBOQUANT: ARCHITECTURAL VERDICT + +Reading `turboquant/`: this is an **ICLR 2026 paper implementation** of +vector-quantization KV compression.
Two algorithms: +- MSE-quantize keys/values via codebook (3 bit by default) +- Inner-product-aware quantize keys (preserves dot products) via Algorithm 2 +- Per-vector L2-norm preserved separately, plus QJL sign sketch for + residual recovery + +Operational shape: +- Operates on **standard MHA/GQA shape** `(..., n_heads, head_dim)`, + head_dim typically 128. +- Requires a `head_dim × head_dim` rotation matrix per layer (precomputed + from random seed, shared across heads). +- Has a Triton fused-decode kernel that computes attention scores directly + from packed codebook indices. +- vLLM integration via `turboquant/vllm_attn_backend.py`. + +## Why it doesn't fit DSv4 + +Three structural mismatches, in order of severity: + +### 1. The DSv4 KV cache is already a learned compression + +DSv4 doesn't store per-token KV. The CSA compressor's whole job is to +reduce m=4 tokens into 1 compressed entry via a softmax-weighted mix. +That entry is what gets cached. TurboQuant quantizes the *post-projection +per-token KV* of standard attention — exactly the thing DSv4 has +already replaced with a learned compressor. **You'd be applying a lossy +compression on top of an already-lossy compression**, which (a) compounds +loss in an uncontrolled way and (b) attacks the wrong dimension. The +compressed entries are already 4× (CSA) or 128× (HCA) reduced in the +sequence dimension; further reducing the *head dimension* via codebook +gives little additional savings (you're already attending over very few +entries per query) at high quality cost. + +### 2. Wrong shape, wrong primitive + +TurboQuant operates on `(..., n_heads, head_dim=128)` per-token vectors +and uses a `128×128` random rotation. DSv4's compressed cache is shape +`(n_comp, head_dim=512)` — no head dimension. The whole "rotate the head +dim" abstraction needs to be reworked, and once you do, you're writing +new code that isn't TurboQuant anymore. 
+ +For the indexer keys, the storage *is* per-block 128-dim, which is closer +to TurboQuant's natural shape. But the indexer's scoring math is +`ReLU(q·k) · w_h` summed across heads — TurboQuant's "preserve inner +products" guarantee from Algorithm 2 doesn't compose with the ReLU +nonlinearity. The quantization error becomes worst-case at the threshold, +which is where top-k decisions get made. **Bad fit precisely where it +matters most.** + +### 3. NVFP4 hardware exists; TurboQuant is software-only + +TurboQuant runs as bit-packed uint8 + Triton kernels. It can't use +tcgen05 FP4 tensor cores because its values aren't FP4 — they're +codebook *indices*. So you'd be paying CPU/GPU cycles to dequant via +gathers and per-token rotation matrix-vector multiplies, when the same +storage cost (4 bits/value) is available natively as FP4 with hardware +dequant during MMA. + +The TurboQuant benchmark numbers (+3–5% throughput at 3-bit) are +real, but they're against `bf16_kv` baselines on architectures that +don't have FP4 tensor cores. On Blackwell with NVFP4, the comparison +should be FP4 storage + FP4 MMA — which is strictly better on every +axis (bandwidth, capacity, dequant cost). + +## Where TurboQuant *would* help, and the verdict on whether it's worth it + +The only DSv4 stream where TurboQuant's shape is a natural fit is the +**SWA branch** — uncompressed per-token KV in the sliding window, 128 +tokens × `n_layers` × `hd=512` = 8 MB at 1M context. + +**It's 8 MB.** Not worth a new dependency, an extra research-grade failure +mode, or the rotation overhead. The SWA branch fits in L2 cache on B200. + +### Verdict + +Don't use TurboQuant. The right move for DSv4's KV cache is **FP4 storage ++ FP4 MMA on the compressed streams**, fully Blackwell-native, paper- +aligned (§5.2.1), with no codebook lookup overhead. The infrastructure to +do this is already in your kernel library (the `ENABLE_FP4_EPILOGUE` +template, the FP4 MMA path).
+ +If you want a paper to cite for "what's the state-of-the-art KV +compression in 2026," TurboQuant is one. If you want the highest-perf +production-grade DSv4 implementation, native FP4 is the answer. + +--- + +# PRIORITY ORDER + +| # | Item | Effort | Win | Type | +|---|---|---|---|---| +| **P0** | Call `set_fused_swiglu(True)` on all MoEs | **XS** | **240–480 µs/token** | one-line script fix | +| **P1** | Same for shared expert (after print-and-confirm) | S | ~120 µs/token | likely script fix | +| **P2** | Drop per-call `fill_()` in Nvfp4Linear | S | ~1.2 ms/token | library fix | +| **KV-1** | FP4 storage for CSA main compressed KV | M | Huge at long context | kernel + script | +| **KV-2** | FP4 storage for HCA compressed KV | M | Same pattern as KV-1 | reuses KV-1 work | +| **KV-3** | FP4 storage for indexer keys (pair with E7) | M | Throughput + paper compliance | kernel work | +| **P3** | RoPE fused into Q/KV GEMM epilogue | M | 183 launches/token | kernel work | +| **P4** | RMSNorm fused into next quantize | S | 122 launches/token | kernel work | +| **P5** | mHC pre_block + RMSNorm fused | S | ~120 launches/token | kernel work | +| **P6** | CUDA graph capture | L | **2–3× total** | after everything above | + +**P0 first.** It's a one-line edit that unlocks the fused kernel that +already exists. It is the most embarrassingly easy and most embarrassingly +overlooked perf bug in v17. The kernel author already did the hard work; +the script just isn't asking for it. + +After P0/P1/P2 land, the linear hot path is genuinely tight and the +remaining wins are kernel-side fusion (P3/P4/P5) and the KV cache dtype +question (KV-1/KV-2/KV-3). Land all of those before attempting CUDA +graphs — the captured graph should reflect the final fused structure, not +the pre-fusion one. + +--- + +# DOCTRINE + +1. **DSL wall → raw CUDA C++, not Python.** Applies to P3/P4/P5 (kernel- + side fusion work). 
The fused-SwiGLU kernel already exists as a model + for what these should look like — it's NVFP4 GEMM + arbitrary-op + epilogue in registers, fully Blackwell-native. + +2. **Raw CUDA ≠ scalar math.** Applies to KV-1/KV-2/KV-3. The FP4 + storage path on the read side uses `tcgen05.mma`'s native E2M1 decode + — no scalar dequant, no `__constant__` LUT (which was only needed + for the indexer scoring CUDA-core path). + +3. **Print, don't guess.** Applies in particular to P1 (verify SE + shapes can use the MoE fused kernel) and KV-1/KV-2 (print the actual + compressor output before deciding the FP4 quant boundary — same + pattern that found the indexer bug). Do not assume the compressor + emits a shape that matches the FP4 quant kernel; print and confirm. + +4. **Integration over exploration.** Do not write `Nvfp4MoE_v2`. Do not + write `KVCache_fp4_v2`. Edit the existing classes. P0 is one line in + `main()`. KV-1/KV-2 are two-tensor dtype changes plus the kernel-side + read path. + +5. **Falsifiable gates.** Already listed per priority. Meta-gate: after + P0–P5 land, decode latency at 8K context should be **single-digit + ms**, not three-digit. If it isn't, something is still on the hot + path that shouldn't be, and the next step is to profile, not to + guess. + +6. **Don't optimize for problems you don't have.** TurboQuant is the + cautionary tale here. The KV cache at 1M is 10 GB on 8 × B200 — that + is not a problem that needs solving with a new dependency. The + problem is throughput, and the right answer is FP4 storage + FP4 MMA, + which is hardware-native and doesn't require codebook lookups.
\ No newline at end of file diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index fb007010..4281d8ee 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -136,6 +136,7 @@ class Nvfp4Linear: with torch.no_grad(): _, _, gs = quantize_to_nvfp4(hidden_states_sample) self._activation_global_scale = gs + self._gsa_buf_initialized = False # P2: re-fill on next call def run(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -177,7 +178,11 @@ class Nvfp4Linear: self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync else: from dsv4.ops.quantize import quantize_nvfp4_gpu - self._gsa_buf.fill_(self._activation_global_scale) + # P2: _activation_global_scale is set once at warmup — no per-call fill needed. + # The buffer retains its value across calls (GPU tensor, persistent). + if not getattr(self, '_gsa_buf_initialized', False): + self._gsa_buf.fill_(self._activation_global_scale) + self._gsa_buf_initialized = True x_fp4, x_sf = quantize_nvfp4_gpu(hidden_states, self._activation_global_scale) # Scatter x_fp4 into padded buffer diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index 8ac4d803..b8a8e0cd 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -26,10 +26,13 @@ from dsv4.ops.quantize import ( ) from dsv4.ops.layouts import ( make_b_k_major, + interleave_l1_weights, ) from dsv4.ops.gemm_runner import ( run_nvfp4_grouped_gemm, + run_fused_swiglu_grouped_gemm, ) +from dsv4.ops.quantize import quantize_nvfp4_gpu_fused from dsv4.kernels.gemm.grouped import ( ceil_div as cutedsl_ceil_div, pad_and_swizzle_single, @@ -62,6 +65,7 @@ class Nvfp4SharedExpert: self.max_num_tokens = max_num_tokens self.device = device self.swiglu_limit = swiglu_limit + self._fused_swiglu = False # Set via set_fused_swiglu() # Weights (set after construction, then call finalize_weights) self.l1_fp4 = None @@ -99,6 +103,10 @@ class Nvfp4SharedExpert: def set_swiglu_limit(self, limit: float): self.swiglu_limit = 
limit + def set_fused_swiglu(self, enabled: bool): + """Enable fused L1 GEMM + SwiGLU kernel (1-group variant of MoE fused kernel).""" + self._fused_swiglu = enabled + def finalize_weights(self): """Process weights for CuTeDSL GEMM. Must be called after setting l1/l2 weights.""" # Convert uint8 checkpoint weights to float4_e2m1fn_x2 view @@ -107,6 +115,11 @@ class Nvfp4SharedExpert: # Checkpoint weight is (N_packed, K_packed), make_b_k_major expects (E, K_packed, N_packed) l1_stacked = torch.stack(l1_view).permute(0, 2, 1).contiguous() l2_stacked = torch.stack(l2_view).permute(0, 2, 1).contiguous() + # P1: Interleave L1 gate/up weights for fused SwiGLU kernel compatibility. + # The fused kernel's SwiGLU epilogue expects granularity-8 interleaved gate/up. + # The unfused path (if _fused_swiglu=False) deinterleaves the GEMM output before splitting. + if self._fused_swiglu: + l1_stacked = interleave_l1_weights(l1_stacked, granularity=8) # Stack weights and convert to K-major self._l1_mat_b = make_b_k_major(l1_stacked) # (1, K_packed, N_packed) self._l2_mat_b = make_b_k_major(l2_stacked) @@ -230,6 +243,28 @@ class Nvfp4SharedExpert: + def _run_l1_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Fused L1 GEMM + SwiGLU + clamp — single kernel launch (1-group variant of MoE fused kernel).""" + num_tokens = hidden_states.shape[0] + x_bf16 = hidden_states.reshape(num_tokens, self.hidden_size) + + # Quantize activation to NVFP4 + x_fp4, x_sf, gsa = quantize_nvfp4_gpu_fused(x_bf16) + + # Run fused grouped GEMM with 1 group + l1_out = run_fused_swiglu_grouped_gemm( + mat_a=x_fp4, + scale_a=x_sf, + global_scale_a=gsa, + mat_b=self._l1_mat_b, + scale_b=self._l1_sf_view, + global_scale_b=self._l1_gs_view, + expert_offsets=torch.tensor([num_tokens], dtype=torch.int64, device=x_fp4.device), + swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0, + num_tokens=num_tokens, + ) + return l1_out # (num_tokens, intermediate_size) BF16, SwiGLU already 
applied + def _run_l1(self, hidden_states: torch.Tensor) -> torch.Tensor: """L1 GEMM: activation × gate_up_weight → BF16.""" num_tokens = hidden_states.shape[0] @@ -325,21 +360,24 @@ class Nvfp4SharedExpert: """Actual implementation — called via custom autograd to be torch.compile-safe.""" self._ensure_initialized() - l1_out = self._run_l1(hidden_states) - if l1_out.shape[1] < 2 * self.intermediate_size: - print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True) + if self._fused_swiglu: + # P1: Fused L1 GEMM + SwiGLU + clamp in one kernel launch + intermediate = self._run_l1_fused(hidden_states) + else: + l1_out = self._run_l1(hidden_states) + if l1_out.shape[1] < 2 * self.intermediate_size: + print(f" WARNING: l1_out shape {l1_out.shape} < expected (N, {2*self.intermediate_size})", flush=True) - gate = l1_out[:, :self.intermediate_size] - up = l1_out[:, self.intermediate_size:] - if torch.isnan(l1_out).any(): - print(f" SE L1 NaN: l1_out nan at {torch.isnan(l1_out).sum().item()} / {l1_out.numel()} positions, shape={l1_out.shape}", flush=True) - if torch.isnan(gate).any() or torch.isnan(up).any(): - print(f" SE gate nan={torch.isnan(gate).any().item()} up nan={torch.isnan(up).any().item()}", flush=True) - if self.swiglu_limit is not None: - # Match SiluAndMulWithClamp: clamp gate BEFORE silu, clamp up to [-limit, limit] - gate = gate.clamp(max=self.swiglu_limit) - up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) - intermediate = torch.nn.functional.silu(gate) * up + gate = l1_out[:, :self.intermediate_size] + up = l1_out[:, self.intermediate_size:] + if torch.isnan(l1_out).any(): + print(f" SE L1 NaN: l1_out nan at {torch.isnan(l1_out).sum().item()} / {l1_out.numel()} positions, shape={l1_out.shape}", flush=True) + if torch.isnan(gate).any() or torch.isnan(up).any(): + print(f" SE gate nan={torch.isnan(gate).any().item()} up nan={torch.isnan(up).any().item()}", flush=True) + if self.swiglu_limit is not None: 
+ gate = gate.clamp(max=self.swiglu_limit) + up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit) + intermediate = torch.nn.functional.silu(gate) * up output = self._run_l2(intermediate) return output diff --git a/single_shot_inference.py b/single_shot_inference.py index 795c7c9d..1730e72f 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1021,6 +1021,7 @@ def main(): intermediate_size=cfg.get("moe_intermediate_size", 3072), top_k=cfg.get("num_experts_per_tok", 6), device=dev) moe.set_swiglu_limit(cfg.get("swiglu_limit", 10.0)) + moe._fused_swiglu = True # P0: Enable fused SwiGLU kernel — eliminates 8 BF16 launches per MoE per token _load_moe_weights_stacked(all_w, li, pfx, dev, moe, cfg) # EAGERLY process stacked weights → K-major + swizzle, free raw tensors moe._ensure_stacked() @@ -1035,6 +1036,8 @@ def main(): se = Nvfp4SharedExpert(hidden_size=H, intermediate_size=cfg.get("moe_intermediate_size", 3072), device=dev, swiglu_limit=cfg.get("swiglu_limit", 10.0)) _load_shared_expert_weights(all_w, li, pfx, dev, se, cfg) + # P1: Enable fused SwiGLU for shared expert (1-group variant of MoE fused kernel) + se.set_fused_swiglu(True) # EAGERLY process shared expert weights se._ensure_initialized() # Fix activation global scales — _ensure_initialized sets gsa from l1_gs (which is 1.0)