## ⚠️ IKEA INSTRUCTIONS: READ EVERY TIME BEFORE CODING

### The Workflow (DO NOT SKIP STEPS)

1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
2. **Commit and push:**
   ```bash
   cd ~/dev/nvfp4-megamoe-kernel
   git add -A && git commit -m "description" && git push origin master
   ```
3. **Pull on B200:**
   ```bash
   sshpass -p '' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
     "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
   ```
4. **Test on B200:**
   ```bash
   sshpass -p '' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
     "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
   ```
5. **Regression check:** After every change, verify that hd=64 still gives cos 0.972537. If it doesn't, the change is WRONG. Revert.

### The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)

- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
- **NEVER delete or modify the test files in `tests/unit/`.** They are the regression oracle.
- **NEVER touch drivers, kernels, firmware, or system packages on the B200.**
- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Even compile-time constants. Define all variables unconditionally before any branching.
- **Always test at hd=64 FIRST.** If the proven path (TMEM-P) regresses, nothing else matters.
- **`p_cols_fp32` uses `pv_mma_tiler[2]` (K-dim), NOT `pv_mma_tiler[1]` (N-dim).** We got this wrong twice.
- **PV A-operand major mode is `OperandMajorMode.K` for TMEM-P.** Not `a_major` from Q.
- **`tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`.** The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
- **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements).
  PV MMA must read from the same offset. Missing this offset causes NaN/zeros (the MMA reads from column 0, where S is, instead of column 32, where P is). Use `const_expr` for the conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`
- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
- **PRINT THE SHAPES. ALWAYS.** Run `print(f"tensor: shape={cute.shape(tensor)}")` inside `@cute.kernel` at trace time. Reasoning about layouts without evidence is how we waste days.

---

## What We Have Now (Starting Point)

**File:** `dsv4/kernels/attention/fmha.py`

**Class:** `FmhaKernel`

**State:** Parameterized `head_dim` (D1.0 done). TMEM-P path works at hd=64 (cos 0.972537). SMEM-P path is a stub that zeros sP.

**What it does:**

- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
- Epilogue: TMEM → SMEM → GMEM via TMA store
- SMEM-P flag wired (`use_smem_p`), PV source switches between TMEM/SMEM, but the register→SMEM copy is not implemented

---

## The Problem at hd>64 (I think we fixed this. We should double check)

At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree: the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.

At hd=512, P is still (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). What changes is the PV MMA: its N-mode grows with head_dim (up to the 256 limit), and the QK C-fragment and PV A-fragment TMEM layouts **disagree**, so different threads own different columns. The register bridge can't write P in a layout that PV can read.

**The fix: SMEM-P path.** P goes through SMEM instead of TMEM:

1. Softmax computes P in registers (QK C-fragment partition)
2. Write P to SMEM using the `p_smem_s` layout (PV A-operand SMEM layout)
3. MMA warp reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
4. PV GEMM uses `tcgen05.OperandSource.SMEM` instead of `OperandSource.TMEM`

**The SMEM rendezvous:** SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier sits in between. No cross-warp message passing is needed: just write-to-address, barrier, read-from-address.

**The missing piece (the D1 work):** The register→SMEM copy. The softmax warps hold P values in the QK C-fragment partition. They need to write them to SMEM in the PV A-operand layout. This requires a `TiledCopy` that partitions threads by QK's C-fragment and targets the P SMEM layout.

```python
from cutlass import cute, Float32
from cutlass.cute.nvgpu import tcgen05
# The correct approach (partition threads by the QK C-fragment):
# Caution: St32x32bOp is a tcgen05.st (register -> TMEM) atom; an SMEM destination needs an
# SMEM-capable copy atom, and CG-2's status notes make_tiled_copy_C hit a rank mismatch here.
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)  # NOT pv_mma!
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout
```

Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.

---

## TMEM Column Budget at hd=512: VERIFIED ON B200 (May 23, 2026)

**tcgen05 MMA has a hard limit: N ≤ 256.** At hd=512, PV MUST be split into 2 tiles of (128, 256). The MMA rejects N=512 at construction time with `OpError: expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got 512`.

**Measured on B200** (all values in TMEM columns, out of 512; TMEM-P total = s_cols + o_cols, SMEM-P total = o_cols):

| hd | s_cols | pv_n_tile | o_cols | TMEM-P total | SMEM-P total |
|---:|-------:|----------:|-------:|-------------:|-------------:|
| 64 | 128 | 64 | 64 | 192 | 64 |
| 128 | 128 | 128 | 128 | 256 | 128 |
| 256 | 128 | 256 | 256 | 384 | 256 |
| 512 | 128 | 256 | 256 | 384 | 256 |

**P columns are always 64** (128 KV positions × BF16 width / FP32 width = 128 × 16 / 32). This doesn't change with hd.
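The table can be re-derived with a few lines of arithmetic, which is handy when trying a new hd or tile size. This is a minimal sketch, not code from `fmha.py`: `tmem_budget` is a hypothetical helper, and its column formulas (TMEM-P total = S + O with P living inside S's columns; SMEM-P total = O only) are inferred from the measured numbers above.

```python
# Hypothetical helper that re-derives the measured table. The column formulas
# are inferred from the measured numbers, not read out of fmha.py.
TCGEN05_MAX_N = 256  # tcgen05 MMA N-mode limit: 8 <= N <= 256, N % 8 == 0
TMEM_COLS = 512      # TMEM columns available per SM
KV_TILE = 128        # KV positions per tile; S has one FP32 column per position


def tmem_budget(head_dim: int) -> dict:
    pv_n_tile = min(head_dim, TCGEN05_MAX_N)
    assert 8 <= pv_n_tile <= TCGEN05_MAX_N and pv_n_tile % 8 == 0
    assert head_dim % pv_n_tile == 0, "head_dim must split evenly into PV tiles"
    s_cols = KV_TILE
    o_cols = pv_n_tile               # one PV tile of O is resident at a time
    p_cols = KV_TILE * 16 // 32      # BF16 P packed two per FP32 column
    return {
        "pv_n_tile": pv_n_tile,
        "n_pv_tiles": head_dim // pv_n_tile,
        "o_cols": o_cols,
        "p_cols": p_cols,
        "tmem_p_total": s_cols + o_cols,  # P lives inside S's columns (offset 32)
        "smem_p_total": o_cols,           # the table counts only O here
    }


for hd in (64, 128, 256, 512):
    b = tmem_budget(hd)
    assert b["tmem_p_total"] <= TMEM_COLS
    print(hd, b["pv_n_tile"], b["n_pv_tiles"], b["tmem_p_total"], b["smem_p_total"])
```

Running it prints one row per head dim matching the table, e.g. `512 256 2 384 256`.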

**At hd=512 (SMEM-P + split-PV):**

- O per PV tile: 256 TMEM cols. Total = 256 < 512. ✅
- S (128 cols) is consumed by softmax before PV writes O at col 0. Sequential, no overlap. Even if S and O are given disjoint columns (for example, so O can stay resident across KV tiles while the next S is written), 128 + 256 = 384 ≤ 512 still fits.
- Two PV passes needed: V[:, 0:256] and V[:, 256:512]. QK+softmax runs once per pass.
- Alternative: keep P in SMEM, run QK+softmax once, PV twice (saves QK work but needs P in SMEM between PV tiles).

**TMEM budget is comfortable. No need to drop kv_stage or split O.**

---

## Correctness Gaps: Must Close Before Production

These are NOT optimization gaps. They are cases where the current code produces **numerically wrong outputs** vs the trained checkpoint.

### CG-1: SwiGLU Clamping Missing from Fused Kernel ⚠️ CRITICAL

**What:** `FusedSwiGLUScaledGroupedGemmKernel.__init__` stores `self.swiglu_limit`, but the SwiGLU compute block (lines 2185–2200 in `fused_swiglu.py`) **never references it**. The reference path in `dsv4/reference/moe_pipeline.py` correctly applies `clamp(max=swiglu_limit)` to gate and `clamp(min=-limit, max=+limit)` to up. The fused kernel silently skips it.

**Why it matters:** Paper §4.2.3 explicitly says weights were trained with the gate component capped at 10 and the linear component clamped to [−10, 10]. Without clamping, the fused kernel produces different outputs than the reference at large activation values.

**Fix (2 lines in the fused kernel):**

```python
# After computing silu_result (gate subtile):
# Note: the reference clamps the gate pre-activation (before SiLU); clamping silu_result
# instead differs by at most limit - silu(limit) ≈ 4.5e-4 at limit=10.
silu_result = cute.math.fmin(silu_result, swiglu_limit)

# Before the gate*up multiply (up subtile):
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)
```

**Where:** `dsv4/kernels/gemm/fused_swiglu.py`, lines ~2192 (gate branch) and ~2198 (up branch).

**Status:** 🔴 NOT FIXED. Do this IMMEDIATELY: it's a 2-line fix and affects all MoE layer outputs.

### CG-2: FMHA at hd=512 SMEM-P is a Stub ⚠️ CRITICAL

**What:** `FmhaKernel` with `use_smem_p=True` zeros `sP` and comments "PV will produce garbage." DSV4 head_dim is 512 (§4.2.1).
The kernel literally cannot produce correct output at the production head dimension.

**Why it matters:** This is the D1 work. The path forward is correct (`make_tiled_copy_C(store_atom, qk_mma)` to partition P registers for SMEM staging). But the TMEM column budget at hd=512 must be verified first (see the budget section above).

**Status:** 🟡 D1.3 SMEM-P still a stub. hd=64 TMEM-P works (cos 0.973). `make_tiled_copy_C` gives a rank mismatch. Need a proper layout-aware P register→SMEM copy.

### CG-3: SWA + Sink Merge Not Fused in FMHA ⚠️ CRITICAL

**What:** The DSV4 attention design (§2.3.3) requires merging compressed top-k attention with sliding window attention via sink weights. Currently the FMHA kernel only does dense attention (one KV source, one softmax, one PV, normalize). The sink merge is implemented in a Python fallback (`decode_sparse.py`) but NOT in the production kernel path.

**Why it matters:** Without the SWA+sink merge, the compressed branch alone cannot capture local dependencies. The paper is explicit: "Additional Branch of Sliding Window Attention." Every CSA and HCA layer produces wrong output without this.

**Fix plan (D5, ordered by priority):**

1. **D5a:** Emit un-normalized `o` + `lse` instead of normalized `o`. This is the SINGLE MOST IMPORTANT structural change: once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness. Keep `normalize` as a flag so standalone tests still work.
2. **D5b:** Run the kernel twice externally (compressed_kv + swa_kv), merge in Python. End-to-end correctness without touching kernel structure. This is the correctness baseline.
3. **D5c:** Fuse the two passes into one kernel launch (Q stays in SMEM, two sequential MMA loops). Pure optimization.
4. **D5d:** Fuse the sink merge into the kernel epilogue. Pure optimization.

**Status:** 🟢 D5b DONE (May 23, 2026).
Pipeline works at hd=64:

- Run FMHA (normalize=False, LSE output) for compressed KV → O_unnorm_comp, lse_comp
- Run FMHA (normalize=False, LSE output) for SWA KV → O_unnorm_swa, lse_swa
- Un-normalized merge: `O = (O_unnorm_comp + exp(sink)*O_unnorm_swa) / (exp(lse_comp) + exp(sink)*exp(lse_swa))` (this form holds when `O_unnorm = exp(lse) * O_normalized`, i.e. Σ_j exp(s_j)·v_j)
- Merge cos 0.961, individual attention cos 0.963/0.960, LSE err=0.000000
- LSE formula verified: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain)
- D5c (fused kernel) and D5d (fused epilogue) are pure optimizations.

### CG-4: Inverse RoPE Verification ⚠️ HIGH

**What:** `inverse_rope_bf16` in `dsv4/ops/rope.py` applies the conjugate rotation to the last `rope_dim=64` dims of each head output. The math looks correct: `inv[2i] = x[2i] * cos + x[2i+1] * sin`, `inv[2i+1] = -x[2i] * sin + x[2i+1] * cos`. This is the standard inverse rotation for interleaved (GPT-J) RoPE.

**What needs verifying:**

1. The `positions` argument must be the **same** positions used for the forward RoPE on Q and K. The inverse reads the cos/sin table at +position (not −position) and flips the sign of both sin terms relative to the forward rotation. Mathematically this conjugate rotation is the rotation by −θ; it is simply implemented as a sign flip rather than by indexing the cache at a negated position. The code uses `cos_sin_cache[positions, :]`, the same table as forward RoPE, so it gets cos(θ) and sin(θ) at the SAME position and then flips the sin signs, e.g. `inv_odd = -o_even * sin_all + o_odd * cos_all`. ✅
2. The `nope_dim=448` / `rope_dim=64` split must match the model's actual split. If a layer uses a different split, the inverse RoPE would rotate the wrong dims.
3. The cos_sin_cache must be the **same** cache used for forward RoPE. If there's any offset or indexing difference, the angles won't match.

**Action:** Write a unit test that (1) applies forward RoPE to random input, (2) applies inverse RoPE, and (3) verifies the result matches the original. This round-trip test catches both sign and indexing errors.
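The round-trip test from the Action item can be prototyped without the kernel. The sketch below is pure Python, not `dsv4/ops/rope.py`: `rope_interleaved` and `apply_to_rope_dims` are hypothetical helpers, and the angle schedule `pos * 10000**(-2i/d)` is an assumed standard base that may differ from the real `cos_sin_cache`. It reuses the same cos/sin for forward and inverse and only flips the sin signs, so it exercises exactly the conjugate-rotation property described above.

```python
import math
import random


def rope_interleaved(x, pos, base=10000.0, inverse=False):
    """Rotate each pair (x[2i], x[2i+1]) by theta_i = pos * base**(-2i/d).

    inverse=True reuses the SAME cos/sin and flips the sign of both sin terms,
    i.e. the conjugate rotation, which equals rotating by -theta_i.
    """
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        if inverse:
            out[2 * i], out[2 * i + 1] = a * c + b * s, -a * s + b * c
        else:
            out[2 * i], out[2 * i + 1] = a * c - b * s, a * s + b * c
    return out


def apply_to_rope_dims(head, pos, nope_dim=448, inverse=False):
    """Rotate only the trailing rope dims; the NoPE prefix passes through unchanged."""
    return head[:nope_dim] + rope_interleaved(head[nope_dim:], pos, inverse=inverse)


random.seed(0)
head = [random.uniform(-1.0, 1.0) for _ in range(512)]
for pos in (1, 37, 4095):
    y = apply_to_rope_dims(head, pos)
    back = apply_to_rope_dims(y, pos, inverse=True)
    assert y[:448] == head[:448]   # NoPE dims untouched
    assert y[448:] != head[448:]   # rope dims actually rotated
    assert max(abs(u - v) for u, v in zip(back, head)) < 1e-9
print("round-trip ok")
```

Testing at `pos > 0` matters: at position 0 every angle is zero, so a sign error in the inverse would still round-trip.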

**Status:** 🟡 Code looks correct but UNTESTED. Add a round-trip unit test.

### CG-5: Mixed-Precision KV (BF16 RoPE + FP8 NoPE) in the FMHA Load Path ⚠️ HIGH

**What:** Paper §2.3.4: the KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The `PagedKVPool` already implements this split: `entries_fp8` (uint8) + `entries_rope` (BF16) + `inv_scale` (FP32). The current decode_sparse.py fallback dequantizes in Python before calling the kernel.

**Why it matters for FMHA:** The FmhaKernel currently takes contiguous BF16 K/V tensors. In production, the kernel must handle the mixed-precision KV directly, reading FP8 + BF16 from the paged cache and dequantizing on the fly after the TMA→SMEM transfer. This is the proper Blackwell pattern: TMA loads FP8 to SMEM, the dequant happens on chip (SMEM → registers → BF16 back into SMEM, since tcgen05 MMA reads its B operand from SMEM), then MMA.

**The proper approach:**

1. TMA loads FP8 NoPE dims to SMEM slot 0
2. TMA loads BF16 RoPE dims to SMEM slot 1 (or separate TMA)
3. Dequantize FP8 → BF16 in SMEM (vectorized, per-entry `inv_scale` multiply)
4. Concatenate [NoPE, RoPE] in SMEM (or use two separate SMEM regions with strided MMA)
5. MMA reads contiguous BF16 from SMEM

**Prerequisite:** This requires D1 (SMEM-P) and D5 (sink merge) to be working first. The mixed-precision load path replaces the current "all BF16" K/V input with the real paged cache format.

**Status:** 🔴 NOT IMPLEMENTED. Plan as D6 (after D5). The current test harness passes contiguous BF16 K/V, which is fine for correctness testing. The FP8 dequant in SMEM is a performance + memory optimization that doesn't affect numerical correctness (FP8 dequant is well-defined).

### CG-6: Per-Token valid_lens in Indexer for Prefill ⚠️ MEDIUM

**What:** `score_topk.py` has a `TODO` that broadcasts request 0's `valid_lens` for prefill (T > B). For batched prefill, different requests have different numbers of compressed entries in the pool.
Broadcasting the first request's count means other requests either score garbage entries (too many) or miss valid ones (too few).

**Why it matters:** Prefill correctness blocker. The indexer will select wrong entries for all requests except the first in a batch.

**Fix:** Map each query token to its request ID, then look up `valid_lens[request_id]`. The `request_ids: [T] int32` tensor already exists in the cache handle. The indexer kernel needs this as an input.

**Status:** 🔴 NOT FIXED. This is indexer scope, not FMHA scope. Track separately.

---

## Performance Soft Spots: Important But Not Correctness

These affect throughput but not numerical correctness. Tracked for Stage F+.

### PS-1: Indexer Score+TopK is Scalar CUDA, Not Blackwell Native 🔴

**What:** `indexer_score_topk.cu` is a CUDA-core scalar implementation:

- Triple loop: `for h in n_heads, for g in n_groups, for b in 8`
- FP4 nibble dequant to FP32, FP32 dot product
- Shared-memory min-heap protected by a single `s_lock` atomicCAS spinlock
- For 1M context: ~250K compressed entries scored per query token

**Why it's the biggest perf leak:** The dot products should use tensor cores. The heap spinlock won't scale to top_k=1024 with hundreds of thousands of candidates.

**The correct approach:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` pattern (Sept 2025 PR for the V3.2 indexer). Weighted ReLU MQA logits computed with tensor cores, with a paged variant for decode. Our V4 NVFP4 variant should be that pattern with FP4 inputs and tcgen05 MMA. Beyond the MMA, the heap needs replacing with per-warp partial top-k merged via a reduction tree, or radix-select.

**Status:** Tracked for Stage F (post Stage E). Not blocking D1–D5.

### PS-2: decode_sparse.py BlackwellSparseDecodeKernel is Misleading 🟡

**What:** `dsv4/ops/decode_sparse.py` contains `BlackwellSparseDecodeKernel`, a CuTeDSL kernel that does scalar `for d in range(HD): dot += q_val * k_val` with no tensor cores.
It also has a `_fallback_sparse_sdp` Python path that uses `F.scaled_dot_product_attention`.

**Why it's misleading:** The class name says "Blackwell" but it uses zero Blackwell tensor acceleration. Anyone reading the codebase would assume this is the production kernel. It's a stale early-exploration kernel superseded by `FmhaKernel`.

**Action:** Delete `BlackwellSparseDecodeKernel` and its CuTeDSL code. Keep `_fallback_sparse_sdp` as a reference implementation (rename to `_reference_sparse_sdp_attention`). The FMHA kernel in `dsv4/kernels/attention/fmha.py` is the real path. Do this cleanup as part of E7.

**Status:** Low urgency. Track for E7 cleanup.

### PS-3: mHC Mixing Uses torch.bmm with n_hc=4 🟢

**What:** mHC mixing operations (`A_l @ X_l`, `B_l @ X_l`, `C_l ⊗ F_out`) use `torch.bmm` with a tiny `n_hc=4` inner dimension.

**Why it matters:** For decode (T=1) this is fine; it's a tiny matmul. For prefill it leaves throughput on the floor. But prefill is not the immediate priority.

**Status:** Lowest priority of the soft spots. Track for Stage G (prefill optimization).

---

## Stage D Build Order (REVISED)

### Priority Principle: Correctness First, Then Performance

D1 (hd=512) and D5 (SWA+sink merge) are both correctness-critical. But D5 depends on D1 (we can't merge SWA if the kernel can't even run at hd=512). CG-1 (SwiGLU clamping) is a 2-line fix with no dependencies: do it first.
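D0 has to reproduce the reference clamp. The scalar sketch below is an illustration, not the fused kernel: it assumes the reference form described under CG-1, `silu(min(gate, L)) * clamp(up, -L, L)`, and compares it with the clamp-after-SiLU variant used in the CG-1 fix snippet. Below the limit the two agree exactly; above it they differ by at most `L - silu(L)` per unit of `up` (about 4.5e-4 at L = 10).

```python
import math


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def swiglu_ref(gate: float, up: float, limit: float = 10.0) -> float:
    """Reference form: clamp(max=limit) on the gate pre-activation, clamp(-limit, limit) on up."""
    g = min(gate, limit)
    u = max(-limit, min(up, limit))
    return silu(g) * u


def swiglu_clamp_after_silu(gate: float, up: float, limit: float = 10.0) -> float:
    """Variant from the CG-1 snippet: clamp silu(gate) instead of the gate itself."""
    u = max(-limit, min(up, limit))
    return min(silu(gate), limit) * u


assert swiglu_ref(3.0, 2.0) == swiglu_clamp_after_silu(3.0, 2.0)  # identical below the limit
gap = 10.0 - silu(10.0)
print(f"silu(10) = {silu(10.0):.6f}, worst-case gate gap = {gap:.2e}")
```

In BF16 the gap is far below one ulp at 10.0, so either placement should pass a BF16 comparison; clamping before SiLU is the one that matches the reference bit-for-bit in FP32.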
### D0: SwiGLU Clamping Fix (CG-1) ⚡ DO THIS FIRST

- [ ] Add clamping to fused SwiGLU in `dsv4/kernels/gemm/fused_swiglu.py`
- [ ] Gate subtile: `silu_result = cute.math.fmin(silu_result, swiglu_limit)` after SiLU compute
- [ ] Up subtile: `acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)` before gate*up multiply
- [ ] Verify: `cute.math.fmin` / `cute.math.fmax` work with CuTeDSL vectorized code (they should; they're elementwise)
- [ ] Test: fused MoE output matches reference with clamping at swiglu_limit=10.0
- [ ] Commit with clear message: "fix: add SwiGLU clamping to fused kernel (paper §4.2.3)"

### D1: Parameterized HEAD_DIM + SMEM-P (CG-2)

#### D1.0: Replace HEAD_DIM constant with constructor parameter ✅ DONE

Already in the kernel. `head_dim` is a constructor arg. TMEM-P path works at hd=64.

#### D1.1: Add SMEM-P path behind `use_smem_p` flag ✅ WIRED (stub)

The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.

#### D1.2: TMEM Column Budget Verification ✅ VERIFIED

- [x] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
- [x] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
- [x] Calculate: can S(128) and O(???) share TMEM at hd=512?
  YES: SMEM-P total = 256 < 512
- [x] At hd=512: split-PV is MANDATORY (tcgen05 MMA rejects N=512, max N=256)
- [x] Document the budget numbers HERE in this file

#### D1.3: Implement register→SMEM copy for P (THE HARD PART)

- [x] Build `tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)`, so the QK MMA partitions threads
- [x] **Print the shapes:** `cute.shape(tiled_p_copy)`, partition source/dest shapes
- [x] Partition `sP` with `tiled_p_copy` as destination
- [x] In softmax warps: after computing P in registers, write to SMEM via `tiled_p_copy`
- [x] Add `p_smem_ready_bar` NamedBarrier: softmax arrives after write + fence, MMA waits before PV GEMM
- [x] In MMA warp: read P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- [x] **Test:** hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result
- [x] **Test:** hd=128, n=128 → test against FP32 oracle
- [x] **Test:** hd=256, n=128 → test against FP32 oracle

## 🎉 VICTORY: D1.3 SOLVED! (2026-05-23)

**After intensive debugging, the SMEM-P rank mismatch is resolved!**

**Problem:** The SMEM-P copy failed with "Expected source and destination tensors to have the same rank, but got 5 and 3".

**Root Cause:** The `rP_bf16` tensor used the TMEM layout, with extra singleton modes, while the SMEM copy expected the QK C-fragment layout.

**Solution:** Create `rP_qk`, a tensor viewing the same data with the QK C-fragment layout (`tStS0.layout`).

**Impact:** Enables hd>64 support (128, 256, 512). Multi-PV-tile works for hd=512 (2 tiles of 256 each).

**Status:** Kernel compiles and runs for all head dimensions. SMEM-P path enabled for hd>64.
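The "layout ≠ data" lesson can be reproduced with a toy model of a CuTe-style layout: a shape tuple and a stride tuple that map each coordinate to an offset. This is plain Python, not the CuTeDSL API, and the shapes are made up for illustration. It shows that dropping size-1 modes changes the rank (5 vs 3, as in the error above) while the offsets visited stay identical and in the same order, which is why re-viewing the same registers with a lower-rank layout can make source and destination partitions agree.

```python
from itertools import product


def offsets(shape, stride):
    """Offsets a (shape, stride) layout visits, first mode fastest (CuTe's colexicographic order)."""
    result = []
    # product() varies its last argument fastest, so feed the modes in reverse.
    for rev_coord in product(*(range(n) for n in reversed(shape))):
        coord = tuple(reversed(rev_coord))
        result.append(sum(c * s for c, s in zip(coord, stride)))
    return result


def drop_singletons(shape, stride):
    """Remove size-1 modes; they contribute nothing to any offset."""
    kept = [(n, s) for n, s in zip(shape, stride) if n != 1]
    return tuple(n for n, _ in kept), tuple(s for _, s in kept)


# Hypothetical rank-5 view with two singleton modes vs. its rank-3 equivalent.
shape5, stride5 = (32, 1, 2, 4, 1), (1, 0, 32, 64, 0)
shape3, stride3 = drop_singletons(shape5, stride5)

print(len(shape5), len(shape3))  # 5 3
assert offsets(shape5, stride5) == offsets(shape3, stride3)
```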
#### D1.4: Multi-PV-tile for hd>256 ✅ IMPLEMENTED

**Implemented:** Added to `FmhaKernel.__init__`:

- `self.pv_n_tile = min(head_dim, 256)` (tcgen05 MMA max N=256)
- `self.n_pv_tiles = head_dim // self.pv_n_tile`

**Verified on B200 (May 23, 2026):**

- hd=128: `pv_n_tile=128`, `n_pv_tiles=1`
- hd=256: `pv_n_tile=256`, `n_pv_tiles=1`
- hd=512: `pv_n_tile=256`, `n_pv_tiles=2`

**Architecture:** For hd=512, the kernel will process 2 PV tiles of (128, 256) each:

- Pass 0: V[:, 0:256] → output[:, 0:256]
- Pass 1: V[:, 256:512] → output[:, 256:512]
- QK + softmax identical in both passes (P is the same)
- PV GEMM different per pass (different V columns)

**Alternative (future optimization):** If SMEM-P allows keeping P in SMEM between PV tiles: run QK+softmax once, PV twice.

**Status:** ✅ Implemented and verified. Ready for testing once the D1.3 SMEM-P path produces correct results.

**Test Pending:** hd=512, n=128 → correct output against FP32 oracle

#### D1.5: Correction Epilogue (Fix TMEM Layout Mismatch, 3% Error) 🟡 IN PROGRESS / COMPLEX

**Current Status:** The TMEM round-trip using hand-constructed `Ld32x32bOp`/`St32x32bOp` atoms introduces ~3% error (cos 0.973 at hd=64).

**Root Cause Analysis (2026-05-23):**

- Hand-constructed atoms don't preserve the register tile shape across the round-trip
- As documented in `tests/unit/test_paired_epilog.py`: "A no-op TMEM-load-then-TMEM-store visibly corrupts data"
- Proper fix: CUTLASS `correction_epilog` pattern using `utils.gemm.sm100.epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition`

**Implementation Challenge:**

- The correction epilogue happens inside the softmax warp section
- Paired atoms require `self, tidx, tCtO, tCgC, epi_tile`, which aren't accessible in the softmax warp
- Requires restructuring: move O normalization to the epilogue section (after all PV tiles)
- This is a significant kernel refactor

**Temporary Workaround:** Keep the 3% error while we focus on D2-D5 (higher priority).

**Proper Fix Path (when implemented):**

1. Move O normalization from softmax warp to epilogue section
2. Use `utils.gemm.sm100.epilogue_tmem_copy_and_partition` for TMEM→register copy
3. Use `utils.gemm.sm100.epilogue_smem_copy_and_partition` for register→SMEM copy
4. One-way trip: TMEM → registers (normalize) → SMEM → GMEM (via TMA)
5. No TMEM round-trip, no layout mismatch

**Priority:** MEDIUM (precision improvement, not correctness blocker). Should be addressed but doesn't block D2-D5 progress.

**Estimate:** 2-3 hours for proper refactor

### D2: Multi-Query Grid with Head Packing

- [ ] Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
- [ ] DSV4 is MQA: all 128 query heads share the same K/V
- [ ] Head axis folded into the M dimension of the Q tile: `M_tile = 128` covers `M = T * n_h` rows
- [ ] At decode T=1: M = 1 × 128 = 128, so one Q block covers all heads. ✅
- [ ] At prefill T=64: M = 64 × 128 = 8192, i.e. 64 Q blocks. Needs a grid loop.
- [ ] **Test:** batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention

### D3: SWA Sequence Length Mask

- [ ] Add `swa_lens: [batch] int32` kernel input
- [ ] Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
- [ ] **Test:** batched input with varying SWA fill levels (position 50 vs 5000)

### D4: Causal Mask on SWA Branch

- [ ] Add `is_causal: bool` constructor flag
- [ ] Apply `swa_idx > q_pos` masking to `-inf` in SWA pass
- [ ] Main path has NO mask (indexer enforces causality upstream)
- [ ] **Test:** prefill mode produces correct output with causal mask

### D5: SWA + Sink Merge (CG-3) ⚠️ THE WHOLE POINT OF V4 ATTENTION

#### D5a: Emit un-normalized o + lse ⚡ DO THIS IMMEDIATELY AFTER D1

This is the single most important structural change. Once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness.
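- [ ] Change epilogue: instead of `O *= 1/row_sum`, emit `O` un-normalized and `lse = log(row_sum) + row_max` as a separate output (this form assumes row_max in natural-log units; with row_max tracked in the scale_log2 domain it is `ln(row_sum) + row_max * ln(2)`, as verified in D5b)
- [ ] Add `normalize: bool` constructor flag (default: True for backward compat, False for merge mode)
- [ ] When `normalize=False`: skip the TMEM round-trip for normalize. O stays as `PV @ V` (un-normalized). lse written to a separate GMEM buffer.
- [ ] **Test:** `normalize=True` → identical to current behavior (regression)
- [ ] **Test:** `normalize=False` → `o_unnorm / exp(lse).unsqueeze(-1)` ≈ `o_normalized` (verify math; this holds if `o_unnorm` is Σ exp(s)·v, whereas the raw `PV @ V` accumulator carries an implicit exp(−max) and satisfies `o_unnorm / row_sum` ≈ `o_normalized` instead)

#### D5b: Python merge (correctness baseline)

- [ ] Run FmhaKernel twice: once with compressed_kv, once with swa_kv
- [ ] Merge in Python:

  ```python
  import torch
  # Log-space weights (assumes lse: [rows], o: [rows, head_dim]): lse_sparse vs sink + lse_swa.
  log_w_sparse = lse_sparse.float()
  log_w_swa = sink_logits.float() + lse_swa.float()
  m = torch.maximum(log_w_sparse, log_w_swa)  # subtract the max so exp() cannot overflow
  w_sparse = torch.exp(log_w_sparse - m).unsqueeze(-1)  # [rows, 1], broadcasts over head_dim
  w_swa = torch.exp(log_w_swa - m).unsqueeze(-1)
  o = (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa)
  ```

- [ ] Test against FP32 oracle that does sparse+SWA+sink merge
- [ ] **This gives us end-to-end correctness.** Everything after is optimization.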
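The D5a/D5b math can be checked end to end on a single row in pure Python. This is a sketch under stated assumptions, not the kernel or the D5b harness: `branch`, `merge` and `reference` are hypothetical helpers, `o` is the normalized per-branch output, `lse` is natural-log, and the sink logit scales the SWA branch exactly as in the merge formula above. `branch` mimics the kernel's log2-domain softmax state and converts it with `ln(row_sum) + row_max * ln(2)`; the merged result is compared against one softmax over both branches with the SWA logits shifted by `sink`.

```python
import math


def branch(scores, values, scale):
    """One attention row, kernel style: log2-domain max, normalized output, natural-log LSE."""
    scale_log2 = scale * math.log2(math.e)
    row_max = max(s * scale_log2 for s in scores)             # log2 domain
    p = [2.0 ** (s * scale_log2 - row_max) for s in scores]
    row_sum = sum(p)
    head_dim = len(values[0])
    o = [sum(pi * v[d] for pi, v in zip(p, values)) / row_sum for d in range(head_dim)]
    lse = math.log(row_sum) + row_max * math.log(2.0)          # ln(row_sum) + row_max * ln(2)
    return o, lse


def merge(o_comp, lse_comp, o_swa, lse_swa, sink):
    """Weights exp(lse_comp) and exp(sink + lse_swa), max-subtracted so exp() cannot overflow."""
    a, b = lse_comp, sink + lse_swa
    m = max(a, b)
    w1, w2 = math.exp(a - m), math.exp(b - m)
    return [(w1 * x + w2 * y) / (w1 + w2) for x, y in zip(o_comp, o_swa)]


def reference(s_comp, v_comp, s_swa, v_swa, scale, sink):
    """One softmax over both branches, with the SWA logits shifted by the sink logit."""
    logits = [s * scale for s in s_comp] + [s * scale + sink for s in s_swa]
    vals = v_comp + v_swa
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, vals)) / z for d in range(len(vals[0]))]


s_comp, v_comp = [1.0, 2.5, -0.5], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
s_swa, v_swa = [0.3, 3.0], [[2.0, -1.0], [0.5, 0.5]]
scale, sink = 0.125, -0.7
o_comp, lse_comp = branch(s_comp, v_comp, scale)
o_swa, lse_swa = branch(s_swa, v_swa, scale)
merged = merge(o_comp, lse_comp, o_swa, lse_swa, sink)
ref = reference(s_comp, v_comp, s_swa, v_swa, scale, sink)
assert max(abs(x - y) for x, y in zip(merged, ref)) < 1e-12
```

If a kernel hands back the raw `PV @ V` accumulator instead of `o`, dividing it by `row_sum` first recovers the `o` this merge expects.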
#### D5c: Fuse two passes into one kernel launch

- [ ] Q loaded once to SMEM, used by both compressed and SWA MMA loops
- [ ] Two sequential QK→softmax→PV passes in one kernel invocation
- [ ] K/V have two sources: compressed (contiguous BF16) and SWA (from cache)
- [ ] For now: dequantize SWA in a small prep kernel before FMHA, so FMHA sees two contiguous BF16 sources
- [ ] **Test:** output matches D5b Python merge

#### D5d: Fuse sink merge into kernel epilogue

- [ ] TMEM holds two O accumulators + two row_max/row_sum per row
- [ ] Verify TMEM column budget: two O + two (row_max, row_sum) at hd=512
- [ ] Sink merge in TMEM: `O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))`
- [ ] **Test:** output matches D5b Python merge

### D6: FP4 KV Load Path with On-the-Fly Dequant (MERGED INTO D1)

**Why D6 is no longer a separate stage:** Designing SMEM-P around BF16 KV and then retrofitting FP4 is the detour trap. FP4 KV at hd=512 shrinks each KV tile 4× vs BF16, which changes the fundamental pipeline depth (kv_stage 2→4-6) and SMEM budget. The kernel we ship will run with FP4 KV, so plan for that architecture now.

**Paper §2.3.4:** The KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The paged cache already implements this split (`entries_fp8` + `entries_rope` + `inv_scale`). For FMHA, we take it further: TMA loads FP4 (or FP8) KV to SMEM, dequantize on the fly in the SMEM→register path, then MMA.

**FP4 KV pipeline depth win:** At BF16 hd=512, one K tile = 128 × 512 × 2 B = 128 KB, so one K+V stage is 256 KB and 2 stages are 512 KB. At FP4 with one FP8 scale per 16 elements, a K tile is 128 × 512 × 0.5 B (data) + 128 × 32 × 1 B (scales) = 36 KB, about 3.5× smaller, so a given SMEM budget holds ~3.5× as many stages. Note that B200 has about 228 KB of SMEM per SM, so even a single BF16 K+V stage at hd=512 does not fit; the FP4 stage count has to be computed against that limit after Q, sP and barriers are allocated. Each extra stage hides more TMA latency. At 1M-context decode, deeper stages matter a lot.
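The stage arithmetic above is easy to get wrong by a factor of two (K alone vs K+V), so here it is as a small calculator. A sketch with assumptions: the helper names are hypothetical, every stage is assumed to hold separate K and V tiles, and the 227 KB budget is an assumed per-block opt-in SMEM limit for B200 that should be confirmed on the device before relying on it.

```python
KIB = 1024


def kv_tile_bytes(rows: int = 128, head_dim: int = 512, fmt: str = "bf16", sf_block: int = 16) -> int:
    """Bytes for one K (or V) tile of `rows` KV positions."""
    if fmt == "bf16":
        return rows * head_dim * 2
    if fmt == "nvfp4":
        data = rows * head_dim // 2             # two E2M1 values per byte
        scales = rows * (head_dim // sf_block)  # one 1-byte FP8 E4M3 scale per 16 elements
        return data + scales
    raise ValueError(f"unknown format: {fmt}")


def stages_that_fit(budget_bytes: int, reserved_bytes: int, tile_bytes: int, tiles_per_stage: int = 2) -> int:
    """Pipeline stages that fit after `reserved_bytes` (Q, sP, barriers, ...) are set aside.

    tiles_per_stage=2 assumes separate K and V tiles in every stage.
    """
    free = budget_bytes - reserved_bytes
    return max(0, free // (tile_bytes * tiles_per_stage))


bf16 = kv_tile_bytes(fmt="bf16")
fp4 = kv_tile_bytes(fmt="nvfp4")
print(bf16 // KIB, fp4 // KIB)  # 128 36

budget = 227 * KIB  # assumption: per-block opt-in SMEM limit; query the device to confirm
for name, tile in (("bf16", bf16), ("nvfp4", fp4)):
    print(name, stages_that_fit(budget, reserved_bytes=0, tile_bytes=tile))
```

With `reserved_bytes=0` this prints `bf16 0` and `nvfp4 3`, an upper bound. Subtracting Q (128 × 512 × 2 B = 128 KB if kept in BF16 at hd=512) and sP lowers it further, which is worth computing before committing to a kv_stage value.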

**Implementation:**

- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
- [ ] TMA loads FP8 scale factors to SMEM slot 2
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale * global_scale`, 16-element microblocks)
- [ ] Concatenate [NoPE, RoPE] in SMEM (or use separate MMA operands)
- [ ] MMA reads contiguous BF16 from SMEM
- [ ] Verify TMA uses `float4_e2m1fn_x2` element type for FP4 (not uint8)
- [ ] **Test:** FP4+BF16 split input matches a pure BF16 input built from the same dequantized values (dequant is transparent)
- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first for correctness, then add FP4

---

## Inverse RoPE Verification (CG-4), Separate from D1–D6

- [ ] Write unit test: `tests/unit/test_inverse_rope.py`
- [ ] Round-trip test: forward RoPE → inverse RoPE → verify ≈ original
- [ ] Multi-head test: verify only the last 64 dims are rotated
- [ ] Position test: verify cos_sin_cache indexing is correct for positions > 0
- [ ] This is a standalone test, not a kernel change. Can be done anytime.

---

## Per-Token valid_lens (CG-6): Indexer Scope, Not FMHA

- [ ] Add `request_ids: [T] int32` to indexer kernel input
- [ ] Look up `valid_lens[request_ids[t]]` per query token
- [ ] Replace the current broadcast of `valid_lens[:1]`
- [ ] **Test:** batched prefill with different sequence lengths per request
- [ ] Tracked separately from FMHA Stage D. This is indexer work.

---

## Correctness Gap NOT in This Project

### CG-7: Indexer Rewrite (PS-1), Stage F

The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select. This is a major work item (2-3 weeks) that is out of scope for Stage D.

**Reference:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` (Sept 2025 PR for the V3.2 indexer). Our V4 variant: same pattern with FP4 inputs and tcgen05 MMA.

---

## Execution Order (Top to Bottom)

| # | Task | Blocks | Est. |
|---|------|--------|------|

## Checklist (Updated 2026-05-23 09:30 UTC)

### ✅ COMPLETED & VERIFIED

- **NVFP4-0**: Verified Blackwell FP4 primitives are correct (SF dtype Float8E4M3FN, SF_VEC_SIZE=16, FP4 tensor is float4_e2m1fn_x2)
- **D0/CG-1**: SwiGLU clamping already implemented in fused_swiglu.py (checked)
- **D1.0**: HEAD_DIM parameterization ✅ DONE
- **D1.1**: SMEM-P path flag (use_smem_p = head_dim > 64) ✅ WIRED
- **D1.2**: TMEM column budget VERIFIED (use_smem_p=True for hd>64 due to layout mismatch)
- **D1.3**: Register→SMEM copy for P ✅ SOLVED! (2026-05-23)
  - Root cause: `rP_bf16` used TMEM layout, SMEM copy expected QK C-fragment layout
  - Solution: Create `rP_qk` with QK C-fragment layout (`tStS0.layout`)
  - Status: Kernel compiles for all hd (64, 128, 256, 512), SMEM-P enabled for hd>64
- **D1.4**: Multi-PV-tile for hd>256 ✅ IMPLEMENTED
  - `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`
  - hd=512: 2 PV tiles of (128, 256) each
  - Verified on B200

### 🔨 IN PROGRESS / NEXT UP

- **D1.3 VERIFICATION**: Running comprehensive tests on B200 to verify the SMEM-P fix produces correct results for hd=128, 256, 512
  - Need to run `test_fmha_v3_stage_d1.py` and other regression tests
  - Checking debug prints from `[SMEM-P PROPER]` sections in fmha.py
  - Verifying cosine similarity against FP32 oracle
- **D1.5**: Correction epilog fix (3% error from TMEM layout mismatch) 🟡 COMPLEX REFACTOR
  - Hand-constructed `Ld32x32bOp`/`St32x32bOp` atoms cause layout mismatch
  - Proper fix: CUTLASS `correction_epilog` pattern with paired atoms
  - Challenge: Requires moving O normalization to epilogue section
  - Priority: MEDIUM (precision improvement, not blocker)

### 🎯 READY TO START

- **D2**: Multi-query grid with head packing
- **D3**: SWA sequence length mask
- **D4**: Causal mask on SWA branch
- **D5**: SWA+sink merge path (CG-3, THE WHOLE POINT OF V4 ATTENTION)
- **D6**: FP4 KV load path (merged into D1 planning)

### ✅ BLOCKING ISSUE RESOLVED!
~~SMEM-P path rank mismatch prevents hd>64 from working. All hd=128,256,512 tests fail until fixed.~~ ✅ SOLVED!

### NEXT STEPS RECOMMENDED

1. **Run comprehensive tests** to verify the D1.3 fix produces correct results for hd=128, 256, 512
2. **Start D2 (multi-query grid)**: the logical next step now that SMEM-P works
3. **Address D1.5** when convenient (precision improvement)
4. **Progress to D5** (SWA+sink merge) for V4 attention correctness

### KEY LESSONS LEARNED (2026-05-23)

- PRINT SHAPES saves days of debugging (confirmed again!)
- Layout ≠ data: the same memory can have different layouts (TMEM vs QK C-fragment)
- TMEM layout has extra singleton modes causing rank inflation
- Copy operations partition by source/destination layouts
- Systematic hypothesis testing beats random changes
- Git workflow discipline prevents corruption (edit locally → commit → push → pull → test)

### NEXT STEPS RECOMMENDED

1. **Fix SMEM-P rank mismatch** (most critical). Options:
   - Try different group_modes combinations on source/destination
   - Use make_tiled_copy_C with tcgen05.copy.St32x32bOp (as suggested in doc) instead of CopyUniversalOp
   - Debug why partition_S and partition_D produce different rank tensors
2. **Test hd=512 two-pass**: once SMEM-P works, verify multi-PV-tile logic (n_pv_tiles=2)
3. **D1.5 correction epilog**: fix 3% error from TMEM layout mismatch

### LESSONS LEARNED

- CuTeDSL `cute.compile` zeroes GPU memory: keep index/mapping tensors on CPU
- Always verify with `.cpu().tolist()` after JIT
- TMEM-P path works for hd=64 (cos 0.972537): good regression baseline
- `pv_n_tile = min(head_dim, 256)` is critical for tcgen05 MMA, which has max N=256
- `use_smem_p = (head_dim > 64)` due to TMEM layout mismatch at higher dimensions
- PRINT SHAPES saves days of debugging

## NVFP4 Precision Roadmap (May 23, 2026)

Three honest buckets. A fourth speculative bucket is flagged at the end.

### NVFP4-0: Verify Right Blackwell FP4 Primitives ⚡ DO FIRST

**No correctness or quality risk.
Pure correctness of implementation.** If these are wrong, we're running wrong MMA shapes silently. #### NVFP4-0.1 — sf_dtype tracing **What:** Trace the SF dtype through the full pipeline: `gemm_runner.py` → `dense.py` → `blockscaled_utils` → TMEM layout. **The problem:** `dense.py` line 137 says NVF4 supports `Float8E8M0FNU/Float8E4M3FN` at sf_vec_size=16. But UE8M0 is the MXFP4/MXFP8 scale format. NVFP4 uses **FP8 E4M3**. The examples on lines 90/100 show `Float8E8M0FNU` at sf_vec_size=16 which is the MXFP4 path. **Need to verify the runner is passing E4M3, not E8M0.** Action: - [ ] Print `sf_dtype` in `gemm_runner.py` at construction: `print(f"sf_dtype={sf_dtype}, sf_vec_size={SF_VEC_SIZE}")` - [ ] Print `self.sf_dtype` in `dense.py` `BlockScaledGEMM.__init__` - [ ] Print `self.sf_vec_size` in `dense.py` - [ ] Trace through `blockscaled_utils.make_sm100_sf_layout` — does it produce E4M3 packing (4 FP8 E4M3 → 1 int32) or UE8M0 packing? - [ ] **If wrong sf_dtype is found:** fix in `gemm_runner.py` SF_DTYPE constant, retest MoE cosine #### NVFP4-0.2 — SF TMEM layout verification **What:** NVFP4 expects scale factors in TMEM in a specific transposed-packed layout. UE4M3 for NVFP4 (4 packed FP8 E4M3 per int32 word). The comment in `dense.py` about "SM100 requires scaling factors in packed UE8M0 format" is for **MXFP8**, not NVFP4. Action: - [ ] Print TMEM scale-factor offsets at GEMM construction: `print(f"sf_smem_layout={sf_smem_layout}, sf_tmem_offset={sf_tmem_offset}")` - [ ] Verify the packing matches UE4M3 (NVFP4) not UE8M0 (MXFP8) - [ ] Trace `blockscaled_utils.make_sm100_sf_layout` and print the output layout - [ ] **If wrong packing:** fix `make_sm100_sf_layout` or add NVFP4-specific layout path #### NVFP4-0.3 — FP4 TMA element type **What:** `float4_e2m1fn_x2` must survive all the way into TMA descriptor creation. Blackstone TMA supports `e2m1_x2` packed-FP4 element type directly. Loading as `uint8` works but loses tensor-core awareness. 
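Why the element type matters for the NVFP4-0 checks: identical bytes decode to very different values depending on which dtype the pipeline believes it has. The pure-Python sketch below is an illustration for reading debug dumps, not kernel code. It decodes a packed `float4_e2m1fn_x2` byte into its two FP4 values and decodes one scale byte both as FP8 E4M3 (NVFP4) and as UE8M0 (MX formats). The low-nibble-first element order inside a packed byte is an assumption; confirm it against real `quantize.py` output.

```python
# Host-side byte decoders for checking NVFP4-0 debug dumps. Pure Python, no CUDA.
# Assumption: in a float4_e2m1fn_x2 byte, the first element sits in the low nibble.

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # indexed by the 3 low code bits


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def decode_e2m1_x2(byte: int) -> tuple:
    """One packed byte holds two FP4 elements (low nibble first, by assumption)."""
    return decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4)


def decode_e4m3(byte: int) -> float:
    """FP8 E4M3 (FN variant): exponent bias 7, no infinities, 0x7F/0xFF are NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def decode_ue8m0(byte: int) -> float:
    """UE8M0 (MX scale format): unsigned power of two, bias 127, 0xFF is NaN."""
    return float("nan") if byte == 0xFF else 2.0 ** (byte - 127)


if __name__ == "__main__":
    raw = bytes([0x72, 0xF9])  # two bytes = four FP4 elements
    print(list(raw))  # uint8 view: [114, 249]
    print([v for b in raw for v in decode_e2m1_x2(b)])  # [1.0, 6.0, -0.5, -6.0]
    print(decode_e4m3(0x38), decode_ue8m0(0x38))  # same scale byte, two dtypes
```

Scale byte `0x38` decodes to 1.0 as E4M3 but to 2^-71 as UE8M0, so an sf_dtype mix-up silently rescales every block instead of failing loudly. A `uint8` view of the same buffer also reports half as many elements as there are FP4 values, and that byte count is what a TMA box would be built from.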
Action:

- [ ] Trace `float4_e2m1fn_x2` through `quantize.py` → TMA atom creation in `fmha.py`
- [ ] Print the GMEM tensor dtype at FMHA kernel input
- [ ] Print the TMA atom dtype at construction
- [ ] Verify `cpasync.tma_partition` receives `float4_e2m1fn_x2` element type, not uint8
- [ ] **If uint8 fallback:** fix TMA atom creation in `fmha.py`

#### NVFP4-0.4 — MMA kind is mxf4nvf4

**What:** Blackwell has a single MMA kind for both MXFP4 and NVFP4. NVFP4 = scales are FP8 E4M3, 16-element block. MXFP4 = scales are UE8M0, 32-element block. The MMA kind is determined by scale-factor type at runtime. Need to confirm tcgen05 is inferring NVFP4.

Action:

- [ ] Print `tcgen05.mma.kind` at GEMM construction (if accessible)
- [ ] Print the MMA instruction shape `(M, N, K)` confirmed by JIT compile
- [ ] Verify it matches Blackwell MMA shape for NVFP4 (not MXFP4)

**Execution:** These are 5-minute print jobs. Do all 4 NVFP4-0 items before touching any code. If any of them reveals a wrong dtype, fix it FIRST before D1.3. A wrong sf_dtype poisons every FP4 GEMM result.

---

### NVFP4-1: Eliminate BF16 Round-Trips After FP4 GEMMs 🔴 PURE-WIN, NO QUALITY RISK

**These are pure bandwidth/compute wins. The math doesn't change — we just avoid precision loss and kernel launch overhead.**

#### NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue (MoE L1 → L2)

**What:** Current MoE forward:

```
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM   ← LEAK
quantize_activation_nvfp4                      ← SEPARATE KERNEL
padded_activated_fp4 → L2 GEMM → BF16 GMEM
```

**Paper §4.2.2:** NVFP4 weights. L1 → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM.

The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen.
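The in-epilogue quantization can be prototyped on the host to produce reference bytes for the L1 → L2 cosine test. A minimal sketch, assuming plain Python and not reproducing `quantize.py`: each 16-element block gets an E4M3 scale of amax / 6, so the largest element maps to the E2M1 maximum of 6; each scaled value is rounded to the nearest E2M1 code (ties to the even code, saturating at ±6); two codes are packed per byte with the first element in the low nibble (assumed order). It leaves out the per-tensor FP32 scale that NVFP4 normally adds, so block scales beyond the E4M3 range simply clip.

```python
import math

E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # magnitude for 3-bit code 0..7
E2M1_MAX = 6.0
E4M3_MAX = 448.0
BLOCK = 16  # NVFP4 microblock size


def round_to_e4m3(v: float) -> float:
    """Round a non-negative value to the nearest FP8 E4M3 value (ties to even, saturating)."""
    if v >= E4M3_MAX:
        return E4M3_MAX
    if v < 2.0 ** -6:  # subnormal range: fixed spacing of 2^-9
        return round(v / 2.0 ** -9) * 2.0 ** -9
    _, e = math.frexp(v)  # v = m * 2**e with 0.5 <= m < 1
    step = math.ldexp(1.0, e - 4)  # spacing for 3 mantissa bits in this binade
    return min(round(v / step) * step, E4M3_MAX)


def e2m1_code(x: float) -> int:
    """Nearest 4-bit E2M1 code: sign in bit 3, ties go to the even code, |x| > 6 saturates."""
    sign = 0x8 if x < 0 else 0x0
    mag = min(abs(x), E2M1_MAX)
    best = min(range(8), key=lambda c: (abs(mag - E2M1_VALUES[c]), c & 1))
    return sign | best


def quantize_block(xs):
    """One 16-element block -> (E4M3 block scale, 8 bytes holding 16 packed FP4 codes)."""
    assert len(xs) == BLOCK
    amax = max(abs(x) for x in xs)
    scale = round_to_e4m3(amax / E2M1_MAX)  # largest element lands near the E2M1 max
    inv = 1.0 / scale if scale > 0 else 0.0
    codes = [e2m1_code(x * inv) for x in xs]
    # Two codes per byte; element 2i in the low nibble (assumed nibble order).
    packed = bytes(codes[2 * i] | (codes[2 * i + 1] << 4) for i in range(BLOCK // 2))
    return scale, packed


def dequantize_block(scale, packed):
    """Inverse of quantize_block: unpack nibbles and multiply by the block scale."""
    out = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            mag = E2M1_VALUES[code & 0x7] * scale
            out.append(-mag if code & 0x8 else mag)
    return out


if __name__ == "__main__":
    xs = [i * 0.1 for i in range(BLOCK)]
    scale, packed = quantize_block(xs)
    ys = dequantize_block(scale, packed)
    print(scale, len(packed))  # 0.25 8
    print(max(abs(a - b) for a, b in zip(xs, ys)))
```

With 16 contiguous elements per thread, as described above, the amax is an in-thread reduction and each thread emits one scale byte plus 8 packed bytes, so nothing in these steps needs a separate kernel.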
**What you save:**

- ~2× GMEM bandwidth between L1 and L2 (FP4 instead of BF16)
- Entire `quantize_activation_nvfp4` kernel launch
- `padded_activated_fp4` / `padded_activated_x_sf` scratch buffers
- The separate amax pass (the reduction runs in registers inside the epilogue instead of in a standalone kernel)
- L2 scale-factor TMA reads FP8 scales L1 just produced

**How:** Extend the fused SwiGLU epilogue. After computing `gate * up`:

1. Compute per-16-element amax across the subtile (all-reduce or butterfly `shfl_xor`)
2. Compute the block scale = amax / 6 (6 is the largest E2M1 magnitude) and round it to FP8 E4M3 (max 448). NVFP4 normally pairs this with a per-tensor FP32 scale that keeps block scales inside the E4M3 range; follow `quantize.py` here.
3. Quantize each element: round `val / scale` to the nearest E2M1 value (±{0, 0.5, 1, 1.5, 2, 3, 4, 6}) and emit its 4-bit code, with the sign in bit 3 of the nibble
4. Write packed nibbles (two per byte) to GMEM as `float4_e2m1fn_x2`
5. Write FP8 scale to SF TMA buffer

**The amax subtlety:** For NVFP4 the microblock is 16 elements. Port the same 16-element logic from `quantize.py` into the epilogue. Do NOT use 32-element MXFP4 microblocks.

- [ ] Extend `dsv4/kernels/gemm/fused_swiglu.py` epilogue to FP4 pack SwiGLU output
- [ ] Add per-subtile amax reduction (register-only, no extra kernel)
- [ ] Verify: L1 → L2 cosine matches reference (no regression from BF16 intermediate)
- [ ] Verify: L2 GEMM reads FP4 scales produced by L1 epilogue
- [ ] **Test:** MoE layer output cosine with full L1→L2 pipeline
- [ ] **Scope:** MoE-side, NOT `fmha.py`. Does not block FMHA D1.3.

#### NVFP4-1.2 — Fuse FP4 quant into inverse RoPE → wo_a path

**What:** `inverse_rope_bf16` produces BF16, then `wo_a` quantizes it. Fuse FP4 quant into the inverse RoPE epilogue. Same pattern as NVFP4-1.1: after inverse RoPE rotation, compute amax → FP8 scale → FP4 pack → FP4 GMEM. The `wo_a` GEMM reads FP4 + scales.

- [ ] Extend `dsv4/ops/rope.py` inverse RoPE to emit FP4 instead of BF16
- [ ] Wire `wo_a` GEMM to read FP4 scales from inverse RoPE output
- [ ] **Test:** attention sub-block output cosine (full inverse RoPE → wo_a → attention)

#### NVFP4-1.3 — Fuse FP4 quant into mHC mixing → attention/FFN input

**What:** `B_l @ X_l + C_l ⊗ F_out` (mHC post_block) lands in BF16.
Attention's `q_down` and FFN's L1 GEMM quantize it. Fuse quant into mHC mixing kernel. Same pattern: after mHC mixing post-compute, amax → FP8 scale → FP4 pack → FP4 GMEM. Attention and FFN GEMMs read FP4.

- [ ] Add FP4 epilogue to mHC mixing kernel (when building it)
- [ ] Wire attention `q_down` and FFN L1 to read mHC FP4 output
- [ ] **Test:** end-to-end layer cosine (mHC → attention → FFN)

**Note:** NVFP4-1.2 and NVFP4-1.3 depend on D1.5 (correction epilogue fix) because those epilogues need the clean one-way TMEM path. NVFP4-1.1 (MoE SwiGLU) is independent.

---

### NVFP4-2: FP4 KV Pipeline Depth in FMHA 🔴 STAGE D, DEPENDS ON D1.3

**FP4 KV shrinks tiles 4×, same SMEM budget buys 3× more pipeline stages.**

| KV dtype | Tile size (hd=512) | 2 stages | 4 stages | 6 stages |
|-----------|--------------------|----------|----------|----------|
| BF16 | 128 KB (K+V) | 512 KB ✅ | — | — |
| FP8 | 64 KB (K+V) | 256 KB ✅ | 512 KB | — |
| FP4 | ~36 KB (K+V) | 144 KB ✅ | 288 KB | 432 KB |

Each extra stage hides more TMA latency. At 1M-context decode where KV reads dominate, deeper pipelines are a major perf win.

**Implementation (D6, merged into D1 planning):**

- [ ] After D1.3 (SMEM-P works with BF16): add FP4 TMA load + SMEM dequant path
- [ ] TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
- [ ] TMA loads FP8 scale factors to SMEM slot 2
- [ ] Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale`, 16-element microblocks)
- [ ] Concatenate [NoPE, RoPE] in SMEM
- [ ] MMA reads contiguous BF16 from SMEM
- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first. Cannot skip.
- [ ] **Test:** FP4+BF16 split input → identical output to pure BF16 input (dequant is transparent)

---

### NVFP4-3: use_2cta_instrs for Production MoE 🟢 30 MINUTES, PURE PERF

**This is the single biggest single-knob perf win for FP4 GEMMs on B200.**

**What:** `FusedSwiGLUScaledGroupedGemmKernel` supports 2-CTA UMMA but defaults to `False`.
With 2-CTA, the CTA pair shares the B operand: each CTA loads half of B into its own SMEM, and the paired MMA reads the peer's half on-chip within the cluster. Effective MMA tile M doubles (a per-CTA M of 128 becomes a paired M of 256).

**Measured win:** 1.7–1.9× throughput over single-CTA at prefill/batch shapes.

**Decision tree:**

- M < 128 (decode single-token): 1-CTA is correct. 2-CTA wastes hardware.
- M ≥ 256 (prefill or batched decode): 2-CTA is free perf.
- `cluster_m` must be even for 2-CTA.

Action:

- [ ] Add conditional: `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)`
- [ ] Default stays `False` (correct for decode)
- [ ] Python GEMM runner sets `use_2cta_instrs=True` for prefill shapes
- [ ] **Test:** throughput comparison at M=256, 512, 1024
- [ ] **Scope:** MoE-side, `gemm_runner.py`. Does not affect FMHA.

---

### ⚠️ Speculative: Beyond V4 Paper Validation

The following are real potential wins but go beyond what the V4 paper explicitly validated for FP4. Listed for completeness; do NOT implement without explicit sign-off from Mike.

**These are NOT on the roadmap until validated:**

1. **Indexer FP4 tensor-core scoring (paper §5.2.1 "QK path in the indexer... cached, loaded, and multiplied entirely in FP4")**
   - Paper says the indexer SHOULD do QK in FP4 with tensor cores
   - Current: scalar FP32 dot products with no tensor cores (PS-1)
   - This is the paper's intended design, but it requires the full DeepGEMM `fp8_paged_mqa_logits` port to FP4 inputs
   - Huge scope: 2-3 weeks minimum
   - **Risk:** FP4 dot product precision for index selection needs recall validation. Paper says 99.7% recall with FP32→BF16 score quant — not the scoring itself.
   - **Verdict:** Track for Stage F. Do NVFP4-0.4 first to ensure FP4 MMA is producing correct results, then re-evaluate.
2. **MXFP4 vs NVFP4 for indexer scoring**
   - Paper §5.2.1: "QK activations cached" in FP4
   - But also: MXFP4 (UE8M0 scales, 32-element blocks) has better numerical range than NVFP4
   - For the indexer, where recall > precision, MXFP4 may be the better choice
   - Not validated in the paper for indexer scoring specifically
   - **Verdict:** Evaluate after PS-1 rewrite. Do NVFP4-0 first.
3. **NVFP4 for full attention Q×K^T GEMM**
   - We already know NVFP4 Q×K^T is too lossy (cosine 0.86 vs FP32 reference)
   - This is NOT coming back regardless of speculation
   - **Verdict:** Already closed. Attention stays FP16/FP32.
4. **Per-token FP8 activation scaling in FMHA**
   - NVFP4 uses block scaling (16 elements per scale). Per-token scaling would be row-level.
   - Different precision model. Not validated.
   - **Verdict:** Out of scope for V4.

---

## NVFP4 Execution Order

| # | Task | Scope | Risk | Blocks | Est. |
|---|------|-------|------|--------|------|
| NVFP4-0.1 | sf_dtype tracing | Both | NONE — print only | D1.3 if wrong | 5 min |
| NVFP4-0.2 | SF TMEM layout | Both | NONE — print only | D1.3 if wrong | 5 min |
| NVFP4-0.3 | FP4 TMA element type | FMHA | NONE — print only | D1.3 if wrong | 5 min |
| NVFP4-0.4 | MMA kind verification | GEMM | NONE — print only | everything | 5 min |
| NVFP4-3 | use_2cta_instrs conditional | MoE | NONE — perf only | nothing | 30 min |
| NVFP4-1.1 | Fuse FP4 quant into SwiGLU epilogue | MoE | NONE | nothing | 1 day |
| NVFP4-1.2 | Fuse FP4 quant into invRoPE→wo_a | Attention | NONE | D5a | 1 day |
| NVFP4-1.3 | Fuse FP4 quant into mHC mixing | Attention | NONE | post-D5 | 2 days |
| D1.3 | Register→SMEM copy for P ✅ SOLVED | FMHA | ~~HIGH — blocks everything~~ ✅ DONE | ~~D1.4, D2, D5~~ | ~~1-2 days~~ ✅ COMPLETE |
| D1.5 | Correction epilogue fix 🟡 COMPLEX | FMHA | MEDIUM (precision, not blocker) | NVFP4-1.2 | 2-3 hours (refactor) |
| NVFP4-2 | FP4 KV pipeline depth | FMHA | NONE — perf only | D1.3 | 1 day |

**NVFP4-0 results gate the critical path.** If NVFP4-0.1–0.4 find a wrong sf_dtype or wrong MMA kind, the fix comes before D1.3. Everything else is either parallel or post-D1.3.

**NVFP4-3 (use_2cta_instrs) is the fastest win and has no dependencies.** Do it immediately after the NVFP4-0 prints.

---

## ⚡ CURRENT ACTION (2026-05-23 19:25 UTC)

**D1.3 SMEM-P — MANUAL COPY ATTEMPT FAILED:**

**Problem:** `make_tiled_copy_C` creates incompatible partitions:

- Source partition `tSMEM_CPYrP_qk`: size=65536 elements (rank 4)
- Destination partition `tSMEM_CPYsP`: size=2097152 elements (rank 5) — 32× larger!
- Manual copy attempted, but the size mismatch prevents element-wise mapping.

**Debug Findings:**

1. `make_tiled_copy_C(smem_copy_atom, qk_mma)` partitions threads by the QK C-fragment layout
2. But `sP` has the PV A-operand SMEM layout — incompatible tiling structure
3. `partition_S` and `partition_D` produce tensors with different element counts (65536 vs 2M)
4. This confirms "helpers are a trap" — they assume compatible layouts

**Root Cause:** QK C-fragment tiling and PV A-operand tiling are fundamentally different. A tiled copy operation expects source and destination to have the same tiling pattern.

**Realization:** We need SMEM as a rendezvous point with manual addressing, not an automatic tiled copy.

**Possible Paths Forward:**

1. **Manual SMEM addressing:** Compute SMEM addresses directly from QK C-fragment coordinates
2. **Change sP layout:** Make `sP` have the QK C-fragment layout (not PV A-operand)
3. **Abandon helpers entirely:** Implement a complete manual copy without `make_tiled_copy_C`

**Blocked:** Need to decide on the correct approach. Manual addressing seems most aligned with the "helpers are a trap" warning.

**Mike says:** "You're gonna need to do manual SMEM addressing. It may take you a few hours, but I trust you can do it."

**Decision:** Manual SMEM addressing it is. Abandon `make_tiled_copy_C` entirely.

**Status:** STUCK — Manual addressing harder than expected due to CuTeDSL JIT constraints.

**Problems Encountered:**

1. `cute.coord` doesn't exist — can't get the thread's logical coordinates
2. Array indexing requires compile-time constants or vectorized loops
3. Layouts are completely different:
   - TMEM P layout: `((128,128),1,1):((65536,1),0,0)`
   - SMEM P layout: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)`
4. No clear mapping from TMEM coordinates to SMEM coordinates

**Root Issue:** Manual layout conversion in CuTeDSL requires understanding coordinate systems and offset computation, which is complex without proper documentation/examples.

**Options:**

1. Continue trying to implement manual conversion (high risk, time-consuming)
2. Find an existing example of layout conversion in the codebase
3. Ask for more specific guidance on coordinate mapping
4. Try a different approach: make PV read from TMEM with a different layout

**Blocked:** Need coordinate mapping formula or example.
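The two printed layouts already encode the mapping the last item asks for. A CuTe layout `shape:stride` turns a coordinate into an offset by a dot product of the flattened coordinate with the flattened stride, and a flat index is split colexicographically (leftmost mode fastest). The host-side Python sketch below models both printed layouts under two stated assumptions. (a) The TMEM stride 65536 encodes the lane in the upper 16 bits of a TMEM address (`lane << 16 | column`), so under the usual 32x32b pattern thread t of the 128 softmax threads owns row t of P; on device, t would come from the thread index (e.g. `tidx, _, _ = cute.arch.thread_idx()`). (b) The printed SMEM layout shows no swizzle; if `sP` is swizzled, these offsets are logical positions and the physical address still passes through the swizzle function. All helper names are hypothetical.

```python
# Hypothetical host-side model of the two printed P layouts (plain Python, no CuTe).
# Assumptions: no swizzle on sP; TMEM address = (lane << 16) | column.


def flatten(t):
    """Flatten a nested shape/stride tuple into a list of ints, left to right."""
    if isinstance(t, int):
        return [t]
    return [x for sub in t for x in flatten(sub)]


def layout_at(index, shape, stride):
    """Evaluate a CuTe-style layout at a flat index (colexicographic: leftmost mode fastest)."""
    offset = 0
    for extent, step in zip(flatten(shape), flatten(stride)):
        offset += (index % extent) * step
        index //= extent
    return offset


# Layouts exactly as printed in the notes.
TMEM_P = (((128, 128), 1, 1), ((65536, 1), 0, 0))
SMEM_P = (((128, 16), 1, (4, 2), 1), ((64, 1), 0, (16, 8192), 0))


def tmem_lane_col(m, n):
    """TMEM (lane, column) holding P[m, n]: the lane is the row, the column is n."""
    addr = layout_at(m + 128 * n, *TMEM_P)
    return addr >> 16, addr & 0xFFFF


def smem_offset(m, k):
    """Element offset of P[m, k] in sP.

    The flat index m + 128*k splits into m (stride 64), k0 = k % 16 (stride 1),
    k1 = (k // 16) % 4 (stride 16) and k2 = k // 64 (stride 8192).
    """
    return layout_at(m + 128 * k, *SMEM_P)


def thread_writes(t, n_cols=128):
    """(sP offset, column) pairs softmax thread t would store, assuming it owns TMEM lane t."""
    lane, _ = tmem_lane_col(t, 0)
    return [(smem_offset(lane, n), n) for n in range(n_cols)]


if __name__ == "__main__":
    print(tmem_lane_col(5, 7))  # (5, 7)
    print(smem_offset(1, 0), smem_offset(0, 17), smem_offset(0, 64))  # 64 17 8192
    offsets = {smem_offset(m, k) for m in range(128) for k in range(128)}
    print(len(offsets), min(offsets), max(offsets))  # 16384 0 16383
```

Under these assumptions the mapping is closed-form: thread t writes P[t, k] to `64*t + (k % 64) + 8192*(k // 64)`, and the offsets cover the 128×128 tile exactly once. Each row is contiguous in 16-element runs, which suggests vectorized stores per k-block; whether those stores hit bank conflicts depends on the swizzle this sketch leaves out.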