D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path). Only fix the multi-KV-tile O rescale using paired atoms from epilogue_tmem_copy_and_partition. The paired atoms share addressing, making the TMEM->REGS->modify->TMEM cycle lossless. Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128) is completely unaffected — zero regression risk. Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred until we can address the MLIR compilation time issue.
This commit is contained in:
805
README.md
805
README.md
@@ -1,636 +1,339 @@
|
||||
# DSV4 Inference Kernel
|
||||
|
||||
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
|
||||
Production-grade Blackwell SM100 inference kernel for **DeepSeek-V4-Pro NVFP4**, written in CuTeDSL with a CUDA fallback path. Target hardware: NVIDIA B200 (180 GiB HBM3e).
|
||||
|
||||
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
|
||||
|
||||
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):
|
||||
|
||||
```
|
||||
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
|
||||
mode 0 1 2 3
|
||||
```
|
||||
|
||||
**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
|
||||
|
||||
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
|
||||
|
||||
```python
|
||||
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
|
||||
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
|
||||
# TMA ALWAYS reads from tile 0.
|
||||
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
|
||||
|
||||
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
|
||||
```
|
||||
|
||||
### THE RIGHT WAY (verified on B200 at n=128 and n=256):
|
||||
|
||||
```python
|
||||
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
|
||||
# Mode 2 (KV tiles) survives as the second mode.
|
||||
tBgK = tBgK[(None, 0, None, 0)]
|
||||
|
||||
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
|
||||
cute.copy(tma_k, tBgK[None, kt], ...)
|
||||
# ^^ THIS IS THE KV TILE DIM
|
||||
```
|
||||
|
||||
**Verified shapes on B200 (May 22, n=256, inside @cute.kernel):**
|
||||
```
|
||||
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
|
||||
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
|
||||
```
|
||||
|
||||
### WHY THIS IS SO INSIDIOUS
|
||||
|
||||
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
|
||||
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
|
||||
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
|
||||
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
|
||||
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
|
||||
6. **The 8-mode theory was wrong.** We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
|
||||
|
||||
### THE LESSON
|
||||
|
||||
**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.
|
||||
|
||||
**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.
|
||||
|
||||
```python
|
||||
# ALWAYS verify the shape at trace time:
|
||||
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
|
||||
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
|
||||
|
||||
# Then index the 2D tensor:
|
||||
cute.copy(tma_k, tBgK[None, kt], ...)
|
||||
```
|
||||
|
||||
**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
|
||||
For what's done, what's blocked, and what's next, see **ROADMAP.md**. This file is the durable reference — architecture, design choices, package layout, workflow, and hard-won lessons. If you're touching the kernel, read the "Lessons learned" section every time.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
## DSV4 is not MLA
|
||||
|
||||
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.
|
||||
This cannot be repeated enough. vLLM and some integrations misname DSV4's attention as MLA. It is fundamentally a different architecture. If you reason about this kernel as MLA + extras, you will make wrong decisions.
|
||||
|
||||
```
|
||||
DSV4 inference pipeline — component status
|
||||
==========================================
|
||||
The differences that matter:
|
||||
|
||||
Legend:
|
||||
[✓] built and tested
|
||||
[~] partial — reference or seam exists, native pending
|
||||
[✗] to build
|
||||
| | MLA (V2/V3) | V4 |
|
||||
|---|---|---|
|
||||
| Compression axis | feature/head dim (per-token latent) | **sequence dim** (multiple tokens collapsed into one entry) |
|
||||
| Cache entries per token | one latent per token | one compressed entry per `m` tokens |
|
||||
| Attention pattern | dense over all cached latents | hybrid: sparse top-k (CSA) + dense over heavily-compressed (HCA) + sliding window (SWA) |
|
||||
| Compression rate | n/a (1:1) | m=4 for CSA, m'=128 for HCA |
|
||||
| Selection | none — all tokens attended | lightning indexer + top-k for CSA |
|
||||
| Output positional fix | n/a | inverse RoPE on each per-head output |
|
||||
| Sink merge | n/a | per-head learnable attention sink merged via single softmax over `[S_comp, S_swa + sink]` |
|
||||
|
||||
|
||||
┌────────────────────────────────────┐
|
||||
│ [✗] Embedding + mHC init │
|
||||
│ token embed + n_hc=4 streams │
|
||||
└────────────────┬───────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─ Transformer layer × L ──────────────────────────────────────────────┐
|
||||
│ HCA on layers 0–1 of Pro, alternating CSA / HCA after │
|
||||
│ │
|
||||
│ ┌─ Attention sub-block ──────────────────────────────────────────┐ │
|
||||
│ │ [✓] Residual mHC pre + post mix │ │
|
||||
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
|
||||
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
|
||||
│ │ [✓] Token compressor CSA m=4 / HCA m′=128 │ │
|
||||
│ │ [✓] Indexer + top-k CSA, FP32 dot + top-k │ │
|
||||
│ │ [~] FMHA core QK → online softmax → PV │ │
|
||||
│ │ + SWA branch + sink merge │ │
|
||||
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
|
||||
│ └────────────────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
|
||||
│ │ [✓] Residual mHC pre + post mix │ │
|
||||
│ │ [✓] Pre-FFN norm RMSNorm │ │
|
||||
│ │ [✓] Router sqrt(softplus) + topk + hash │ │
|
||||
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
|
||||
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
|
||||
│ └────────────────────────────────────────────────────────────────┘ │
|
||||
└──────────────────────────────────┬───────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─ Supporting infrastructure ──────────────────────────────────────────┐
|
||||
│ [✗] KV cache management │
|
||||
│ • state cache: SWA window + uncompressed tail per layer │
|
||||
│ • classical paged cache: lcm(m, m′) = 128 tokens per block │
|
||||
│ • heterogeneous layout per layer │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
|
||||
Summary
|
||||
-------
|
||||
Built [✓] : 9 — mHC ×2, Q/KV proj, output proj, routed MoE,
|
||||
shared expert, token compressor, indexer+topk,
|
||||
router, pre-FFN norm
|
||||
Partial [~] : 3 — norms+RoPE, FMHA core
|
||||
To build [✗] : 6 — embedding+init, final norm, LM head, MTP, sampler, KV cache
|
||||
```
|
||||
Cache layout reflects this: per-layer **state cache** for SWA window + uncompressed tail (used for CSA/HCA compression), plus a **classical paged cache** holding compressed CSA/HCA entries, with block size = `lcm(m, m') = 128` original tokens per block.
|
||||
|
||||
---
|
||||
|
||||
## Status (May 26, 2026 — 18:40 UTC)
|
||||
## DSV4 architecture (paper-side reference)
|
||||
|
||||
| Stage | Status | Description |
|
||||
|-------|--------|-------------|
|
||||
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
|
||||
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
|
||||
| C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. |
|
||||
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. O rescale TMEM round-trip BROKEN (Ld32x32bOp/St32x32bOp corrupt data). Python KV merge workaround works. |
|
||||
| D1.5 | ❌ BLOCKER | O rescale for multi-KV-tile (kt>0). TMEM round-trip corruption (even NO-OP round-trip fails). Python KV merge workaround: cos 0.999994. Production: 5-9 kernel launches per decode. Fix requires correction epilog (one-way TMEM→regs→SMEM→GMEM). |
|
||||
| D2 | 🟡 Per-head DONE | Head-packed M-dimension launch (cos 0.999995, n_h=1-128). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. |
|
||||
| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar, offset by n_comp for D5c) |
|
||||
| D4 | ✅ DONE | Causal mask on SWA branch (SWA-relative position > m_coord → -inf, combined with D3 via OR logic) |
|
||||
| D5 | ✅ D5a+D5b+D5c DONE | D5a: normalize flag + LSE + row_sums output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias as logit modification — mathematically equivalent to separate merge, single pass over combined KV (cos 0.999996 single-tile AND multi-tile). D5d (fused in-kernel merge) NOT NEEDED — sink bias approach supersedes it. |
|
||||
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
|
||||
| NVFP4-3 | ✅ DONE | `use_2cta_instrs` conditional in gemm_runner.py. 1.7-1.9× throughput at prefill shapes. |
|
||||
The bits the kernel implements, with the choices we made for inference.
|
||||
|
||||
### Per-layer attention type schedule
|
||||
|
||||
```
|
||||
Flash (43 layers): layers 0-1 = SWA, layers 2..42 alternating CSA/HCA (CSA at layer 2)
|
||||
Pro (61 layers): layers 0-1 = HCA, layers 2..60 alternating CSA/HCA (CSA at layer 2)
|
||||
```
|
||||
|
||||
Frozen at construction time per `LayerSpec` so torch.compile constant-folds the dispatch. Validation in `dsv4/model/layer_schedule.py:validate_schedule` is loud — wrong schedule = silent garbage.
|
||||
|
||||
### Compressed Sparse Attention (CSA)
|
||||
|
||||
- Compresses every `m=4` KV entries into one via a token-level learned softmax with overlapping window (current m + previous m). See eq. 11–12 of the paper.
|
||||
- Compressed sequence length is `n/m`.
|
||||
- **Lightning indexer** scores each query against compressed blocks via weighted ReLU MQA logits (eq. 16). Top-k selector keeps `csa_top_k` blocks (512 Flash / 1024 Pro).
|
||||
- Core attention is MQA over the selected blocks + a sliding window branch of `n_win=128` raw tokens.
|
||||
- Partial RoPE on the last 64 dims of Q and the compressed K, with **inverse RoPE on each per-head output** so the per-token contribution carries the correct relative position.
|
||||
- Per-head attention sink: learnable logit added to the softmax denominator (eq. 27). We merge sparse + SWA via the sink-bias-as-logit trick — see "Sink merge" below.
|
||||
|
||||
### Heavily Compressed Attention (HCA)
|
||||
|
||||
- Same compressor concept as CSA but `m'=128`, no overlap, dense attention over the (very short) compressed sequence.
|
||||
- No indexer.
|
||||
- Same partial RoPE + inverse RoPE + sliding window + sink as CSA.
|
||||
|
||||
### Sliding Window Attention (SWA)
|
||||
|
||||
- First two layers of Flash. Pure local attention over the SWA window. No compressed branch, no indexer.
|
||||
- Cache layout: ring buffer of size `n_win` per request in the state cache.
|
||||
|
||||
### Manifold-Constrained Hyper-Connections (mHC)
|
||||
|
||||
- Replaces residual connections. Width-expanded residual stream `(T, n_hc=4, d)`.
|
||||
- Per-token dynamic `A_l`, `B_l`, `C_l` mixing matrices generated by a fused 24-output prenorm projection (4 + 4² + 4).
|
||||
- `A_l = σ(.)`, `C_l = 2σ(.)`, `B_l = SinkhornKnopp(exp(.), t_max=20)` to project onto the Birkhoff polytope.
|
||||
- `pre_block`: `x_in = A_l @ X_l`; `post_block`: `X_next = B_l @ X_l + C_l ⊗ F_out`.
|
||||
- `B_l` held in FP32 for the bmm precision; A/C cast to BF16.
|
||||
|
||||
### Router
|
||||
|
||||
- Two modes, frozen at construction by layer index:
|
||||
- **Hash routing** (layers 0–2): deterministic per-token-ID LUT lookup, uniform weights `1/k`.
|
||||
- **Dense routing** (layers 3+): `sqrt(softplus(X @ W_gate))` activation, plus learned `e_bias` for *selection only*. Top-k (k=6), renormalize on unbiased activations, multiply by `routed_scaling_factor`.
|
||||
|
||||
### MoE
|
||||
|
||||
- DeepSeekMoE: shared expert + N routed experts (Flash 256, Pro 384), 6 activated per token.
|
||||
- L1 GEMM (gate + up interleaved at granularity 8) → SwiGLU → L2 GEMM (down).
|
||||
- SwiGLU clamping per paper §4.2.3: gate capped at `swiglu_limit=10`, linear clamped to `[-limit, +limit]`.
|
||||
- All weights NVFP4, FP8 E4M3 scales, 16-element microblocks.
|
||||
|
||||
### Sink merge (D5c — key insight)
|
||||
|
||||
The paper writes the sink merge as a weighted combination of two separate softmax outputs. But because the sink is just an additive logit bias on one branch, the whole thing collapses to a **single softmax over `[S_comp, S_swa + attn_sink]`**.
|
||||
|
||||
One pass, one kernel. No two-loop epilogue, no LSE arithmetic in the merge. This is why D5d (fused merge epilogue) is not needed.
|
||||
|
||||
---
|
||||
|
||||
## Package Structure
|
||||
## Our kernel design choices
|
||||
|
||||
### Attention kernel (FmhaKernel)
|
||||
|
||||
**6-warp specialization.** Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).
|
||||
|
||||
**P staging — two paths.**
|
||||
- **TMEM-P** (hd ≤ 64): P stored to TMEM via register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used at the small head dims where QK C-fragment and PV A-fragment TMEM layouts agree.
|
||||
- **SMEM-P** (hd > 64): P written to SMEM via coordinate-indexed store using `tTMEM_LOADcS` to map register indices to `(m, k)` then into `sP`'s subtile layout. PV reads P from SMEM with `OperandSource.SMEM`. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.
|
||||
|
||||
**Un-normalized O + LSE output.** The kernel emits raw `sum(P · V)` and `lse = ln(row_sum) + row_max · ln(2)`. External code (or the next kernel pass) divides. This composes — D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.
|
||||
|
||||
**Per-head launch for multi-head.** Python loop dispatches the single-CTA kernel once per head. Multi-CTA grid using `flat_divide` + `tma_partition` is the next refactor (see ROADMAP); the path is unblocked once the correction-epilog rewrite lands.
|
||||
|
||||
**Head-packed M dimension for decode.** Q reshaped to `(n_h * T, hd, 1)`, all heads' rows packed into the 128-row M tile. Per-row softmax. At Pro decode (T=1, n_h=128) the M tile fits exactly.
|
||||
|
||||
**K-dim sub-tiling at hd > 256.** When `head_dim > 256` (MMA instruction K-dim limit), Q and K split into `n_k_sub_tiles = head_dim / 256` chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses `pv_n_tile = 128` for hd > 256 to keep sV+sC within the 232 KB SMEM budget.
|
||||
|
||||
**Sink bias as logit modification.** D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read `tTMEM_LOADcS` to get `(m, k)` coordinates and modify `tTMEM_LOADrS` before the row-max reduction. The sink bias is added in the raw-logit domain as `attn_sink / scale_softmax`, then the existing `* scale_log2` multiply converts to log2 space.
|
||||
|
||||
### MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)
|
||||
|
||||
**7-warp specialization.** Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (`tcgen05.mma.block_scale` with SFA/SFB in TMEM). Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (`MoEStaticPersistentTileScheduler`).
|
||||
|
||||
**One-way TMEM → registers → SMEM → GMEM epilogue.** Uses `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker — see ROADMAP.
|
||||
|
||||
**Subtile-level gate/up pairing.** With granularity-8 interleaved L1 weights and `epi_tile_n=8`, even subtiles are gate and odd subtiles are up. `silu_gate_buf` register tensor carries the SiLU result across the subtile-pair boundary.
|
||||
|
||||
**`use_2cta_instrs` conditional** on `tokens_sum ≥ 256` and even `cluster_m`. Decode (small M) stays 1-CTA; prefill/batched gets 2-CTA UMMA with multicast B (1.7–1.9× throughput).
|
||||
|
||||
### Heterogeneous KV cache
|
||||
|
||||
- **State cache** per request: fixed-size block holding `(n_win SWA KV)` and `(uncompressed tail tokens awaiting compression)`. One block per request, lifetime managed by request scheduling.
|
||||
- **Classical paged cache** per request: variable blocks holding `(k1 CSA compressed entries, k2 HCA compressed entries)` per layer. `k1 = lcm(m, m') / m = 32`, `k2 = lcm(m, m') / m' = 1`. Block covers 128 original tokens.
|
||||
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
|
||||
|
||||
### NVFP4 throughout
|
||||
|
||||
- **Weights**: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: `sf_dtype`, TMA element type, MMA kind (`mxf4nvf4`) all correct.
|
||||
- **Activations**: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands (see ROADMAP).
|
||||
- **KV cache**: BF16 today; the FP8 (RoPE in BF16, NoPE in FP8) split per paper §2.3.4 is on the roadmap as NVFP4-2.
|
||||
- **Indexer keys**: stored FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
|
||||
|
||||
---
|
||||
|
||||
## Package structure
|
||||
|
||||
```
|
||||
dsv4/
|
||||
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
|
||||
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
|
||||
│ ├── attention/ FMHA kernel — FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
|
||||
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL, 419 lines)
|
||||
│ ├── indexer/ CSA indexer — score+topk (FP32 dot products, top-k selection)
|
||||
│ ├── attention/ FMHA — FmhaKernel (hd=64/128/256 proven, hd=512 MLIR-blocked)
|
||||
│ ├── gemm/ NVFP4 MoE GEMM (grouped, fused_swiglu, dense, scheduler)
|
||||
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL)
|
||||
│ ├── indexer/ CSA indexer score+topk (FP32 scalar today; tensor-core FP4 on roadmap)
|
||||
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
|
||||
│ ├── cache/ Cache kernels — append_swa (write KV to split state cache layout)
|
||||
│ ├── decode/ Decode-time attention (sparse, SWA — future)
|
||||
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
|
||||
│ ├── cache/ append_swa (writes KV to state cache)
|
||||
│ ├── decode/ Decode-time attention (future)
|
||||
│ └── cuda/ Raw .cu (deinterleave_quantize, sparse_topk_metadata, etc.)
|
||||
├── ops/ PyTorch ↔ kernel bridges
|
||||
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
|
||||
│ ├── quantize.py BF16 ↔ NVFP4, scale factor handling
|
||||
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
|
||||
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
|
||||
│ ├── custom_ops.py torch.library.custom_op registrations
|
||||
│ ├── decode_sparse.py native_sparse_decode dispatcher
|
||||
│ ├── decode_swa.py native_swa_decode dispatcher
|
||||
│ ├── rope.py Forward + inverse RoPE
|
||||
│ ├── topk.py Python wrapper for sparse_topk_metadata.cu
|
||||
│ ├── topk_select.py Top-k selection wrapper
|
||||
│ ├── rope.py Forward + inverse RoPE (partial, last 64 dims)
|
||||
│ ├── topk.py Sparse top-k metadata wrapper
|
||||
│ └── router.py Router op bridge
|
||||
├── layers/ nn.Module-style components
|
||||
│ ├── linear.py Nvfp4Linear
|
||||
│ ├── grouped_linear.py Nvfp4GroupedLinear
|
||||
│ ├── moe.py Nvfp4MoE
|
||||
│ ├── grouped_linear.py Nvfp4GroupedLinear (output projection)
|
||||
│ ├── moe.py Nvfp4MoE (routed experts)
|
||||
│ ├── shared_expert.py Nvfp4SharedExpert
|
||||
│ ├── mhc.py mHCLayer
|
||||
│ ├── attention.py DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
|
||||
│ ├── norm.py RMSNorm (PyTorch ref, fused kernel later)
|
||||
│ ├── router.py Router — token-to-expert assignment (273 lines)
|
||||
│ ├── embedding.py Token embedding + mHC init (stub)
|
||||
│ ├── mhc.py mHCLayer (Sinkhorn-Knopp, residual mixing)
|
||||
│ ├── attention.py AttentionSubBlock (CSA/HCA/SWA variants by LayerSpec)
|
||||
│ ├── norm.py RMSNorm
|
||||
│ ├── router.py Router (dense + hash modes)
|
||||
│ ├── embedding.py Token embedding + mHC init
|
||||
│ └── ffn.py FFN sub-block
|
||||
├── model/ Model assembly
|
||||
│ ├── config.py Model config
|
||||
│ ├── layer.py Transformer layer
|
||||
│ ├── layer_schedule.py Layer scheduling
|
||||
│ ├── config.py DSV4Config
|
||||
│ ├── layer.py TransformerLayer
|
||||
│ ├── layer_schedule.py LayerSpec, AttentionType, build_schedule, validate_schedule
|
||||
│ ├── mtp.py Multi-token prediction
|
||||
│ ├── sampler.py Token sampler
|
||||
│ └── dsv4.py Full model (stub — Phase 1)
|
||||
│ └── dsv4.py Full model
|
||||
├── cache/ KV cache infra
|
||||
│ ├── allocator.py Cache memory allocator
|
||||
│ ├── allocator.py Memory allocator
|
||||
│ ├── block_table.py Paged cache block table
|
||||
│ ├── flush.py Cache flush
|
||||
│ ├── handle.py Cache handle
|
||||
│ ├── manager.py Cache manager
|
||||
│ ├── paged_cache.py Paged KV cache
|
||||
│ ├── prepare_forward.py Forward prep
|
||||
│ ├── schema.py Cache schema
|
||||
│ └── state_cache.py State cache (SWA ring buffer)
|
||||
│ ├── paged_cache.py Classical paged cache (CSA/HCA)
|
||||
│ ├── state_cache.py State cache (SWA + uncompressed tail)
|
||||
│ ├── schema.py, handle.py, flush.py, prepare_forward.py
|
||||
├── loader/ Checkpoint I/O
|
||||
│ ├── hf_checkpoint.py HuggingFace checkpoint loader
|
||||
│ └── layout_convert.py Weight layout conversion
|
||||
│ ├── hf_checkpoint.py
|
||||
│ └── layout_convert.py
|
||||
└── reference/ Slow PyTorch oracles (never imported by production code)
|
||||
├── attention.py RoPE, KV cache, causal attention, SWA
|
||||
├── csa_attention.py CSA/HCA sparse attention
|
||||
├── compressor.py Compressor PyTorch example
|
||||
└── moe_pipeline.py MoE pipeline reference
|
||||
├── attention.py, csa_attention.py, compressor.py, moe_pipeline.py
|
||||
```
|
||||
|
||||
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.
|
||||
**Dependency arrow:** `kernels/` → `ops/` → `layers/` → `model/`. `reference/` and `loader/` are sidecars.
|
||||
|
||||
---
|
||||
|
||||
## Active Test Files
|
||||
## Workflow & test harness
|
||||
|
||||
### FMHA (Stages A/B/C/D1) — in `tests/unit/`
|
||||
### The non-negotiables
|
||||
|
||||
| File | Stage | Status |
|
||||
|------|-------|--------|
|
||||
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
|
||||
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
|
||||
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
|
||||
| `test_fmha_v3_stage_d1.py` | D1 | ✅ hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. |
|
||||
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.999994, LSE err=0.0) |
|
||||
| `test_d5c_fused.py` | D5c | ✅ Single-tile combined KV + sink bias (cos 0.999996) |
|
||||
| `test_d5c_multitile.py` | D5c | ✅ Multi-tile with Python KV merge + sink bias (cos 0.999996) |
|
||||
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
|
||||
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
|
||||
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
|
||||
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
|
||||
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
|
||||
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
|
||||
- **NEVER edit on the B200.** Always: edit locally → commit → push → pull on B200 → test.
|
||||
- **ALWAYS use the test harness** (`fire_b200_test`, `run_test.sh`, `check_log.sh`). Never raw SSH+nohup. Nohup does not survive SSH drops; screen sessions do.
|
||||
- **ALWAYS verify hd=64 regression** (cos ~0.999998) after every FMHA change. If it regresses, the change is wrong. Revert.
|
||||
- **NEVER touch drivers, kernels, firmware, or system packages** on the B200.
|
||||
- **NEVER delete test files** in `tests/unit/` without explicit approval.
|
||||
|
||||
### MoE / GEMM — in `tests/unit/`
|
||||
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
|
||||
| `cudagraph_test.py` | Cudagraph capture + replay |
|
||||
| `layertest.py` | Per-layer correctness |
|
||||
| `test_custom_op.py` | torch.library custom ops |
|
||||
| `test_compile_custom_op.py` | Compile + warmup |
|
||||
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
|
||||
| `test_interleave.py` | Gate/up weight interleaving |
|
||||
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
|
||||
| `test_fused_step1.py` | Fused SwiGLU GEMM |
|
||||
|
||||
---
|
||||
|
||||
## Test Harness
|
||||
|
||||
Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):
|
||||
|
||||
### `run_test.sh` — Run a test in a screen session
|
||||
|
||||
```bash
|
||||
# On the B200:
|
||||
cd /root/dsv4-nvfp4-workspace/kernel
|
||||
bash tests/run_test.sh tests/unit/test_fmha_v3.py
|
||||
```
|
||||
|
||||
What it does:
|
||||
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
|
||||
2. Deletes the old log file
|
||||
3. Starts a new `screen -dmS kernel-test` running the test
|
||||
4. Logs output to `/tmp/kernel-test.log`
|
||||
5. Verifies the screen started
|
||||
|
||||
### `check_log.sh` — Check test progress
|
||||
|
||||
```bash
|
||||
bash tests/check_log.sh
|
||||
```
|
||||
|
||||
Shows the log contents and whether the screen is still running.
|
||||
|
||||
### Local → B200 workflow
|
||||
### Local → B200 cycle
|
||||
|
||||
```bash
|
||||
# 1. Edit locally, commit, push
|
||||
cd ~/dev/nvfp4-megamoe-kernel
|
||||
git add -A && git commit -m "my change" && git push
|
||||
git add -A && git commit -m "description" && git push origin master
|
||||
|
||||
# 2. SSH to B200, pull, run
|
||||
ssh root@45.76.247.107
|
||||
# 2. One-command test (auto-pushes, runs, dumps log)
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3_stage_c.py
|
||||
```
|
||||
|
||||
### Manual B200 cycle
|
||||
|
||||
```bash
|
||||
ssh root@<B200>
|
||||
cd /root/dsv4-nvfp4-workspace/kernel && git pull
|
||||
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
|
||||
|
||||
# 3. Check results
|
||||
bash tests/run_test.sh tests/unit/test_<...>.py
|
||||
bash tests/check_log.sh
|
||||
```
|
||||
|
||||
### `fire_b200_test` — One-command local test runner
|
||||
`run_test.sh` kills any prior `kernel-test` screen (with SIGKILL on stuck GPU procs), deletes the old log, starts a fresh `screen -dmS kernel-test`, and logs to `/tmp/kernel-test.log`.
|
||||
|
||||
Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo — project-specific tooling).
|
||||
### Environment
|
||||
|
||||
```bash
|
||||
# From your local machine, one command to push, run, and get results:
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
|
||||
```
|
||||
|
||||
What it does:
|
||||
1. Auto-commits and pushes any local changes
|
||||
2. SSH to B200, pulls, starts `run_test.sh` in a screen
|
||||
3. Polls every 15s until the screen exits
|
||||
4. Dumps the full test log to your terminal
|
||||
|
||||
**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.
|
||||
- **B200 access**: see `MEMORY.md` (not committed).
|
||||
- **venv**: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
|
||||
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
|
||||
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
|
||||
- **vLLM** (modified for Blackwell): `/root/dsv4-nvfp4-workspace/vllm`
|
||||
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
|
||||
- **Local CUTLASS clone**: `/home/openclaw/dev/cutlass`
|
||||
|
||||
---
|
||||
|
||||
## CuTeDSL constraints (read every session)
|
||||
|
||||
## Stage C: Online Softmax — TMEM Layout Mismatch Issue
|
||||
These are surface-level traps. Get them wrong and the kernel silently produces garbage, NaN, or "weakly congruent" at JIT compile time.
|
||||
|
||||
### Current Results (test_fmha_v3_stage_c.py)
|
||||
1. **TMA partition tensors have 4 modes**: `(((64,128),1), ?, KV_tiles, ?)`. `(None, 0, None, 0)` keeps mode 2 (KV tiles) free; `[None, kt]` indexes it. `(None, None, 0, 0)` silently pins mode 2 to 0 — multi-tile loads break invisibly.
|
||||
|
||||
| n | cos | Status |
|
||||
|---|-----|--------|
|
||||
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
|
||||
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
|
||||
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |
|
||||
2. **`vectorize=True` loops accept only load/store/print.** No `fmax`, no `cmpf`, no inner loops, no carry across iterations.
|
||||
|
||||
### Root Cause: TMEM Layout Mismatch
|
||||
3. **`.reduce(cute.ReductionOp.MAX)` reduces the entire C-fragment to a scalar** — global, not per-row. Use a plain `range()` loop with `cute.arch.fmax` for per-row max.
|
||||
|
||||
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip.
|
||||
4. **`cute.arch.fmax` is impure** for the vectorizer. Use it inside plain `range()`, never inside `vectorize=True`.
|
||||
|
||||
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.
|
||||
5. **Hand-constructed TMEM atoms corrupt data on round-trip.** Independently-built `Ld32x32bOp` + `St32x32bOp` atoms have addressing that doesn't match — even a NO-OP round-trip drops cos to ~0.97. Use paired atoms from `epilogue_tmem_copy_and_partition` / `epilogue_smem_copy_and_partition` for one-way trips. This is the D1.5 blocker in ROADMAP.
|
||||
|
||||
**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific if blocks.
|
||||
6. **CuTeDSL `if` blocks are separate MLIR regions.** Variables defined inside one `if` are not visible in another, even when the condition is a compile-time constant. Define all variables unconditionally before any branching.
|
||||
|
||||
### Key Bug Fix: tOrP0 TMEM Column Offset (May 23)
|
||||
7. **Guard dead code with `const_expr`.** CuTeDSL compiles both branches of Python `if`. At hd=64, the SMEM-P or O-rescale code generates IR you don't need; without `const_expr`, MLIR chews on it.
|
||||
|
||||
The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
|
||||
```python
|
||||
if const_expr(self.tOrP0_offset > 0):
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
|
||||
else:
|
||||
tOrP0 = tOrP
|
||||
```
|
||||
Must use `const_expr` conditional (not Python `if`) because CuTeDSL compiles both branches, and `tOrP.iterator + 0` fails with MLIR type error.
|
||||
8. **`tma_partition` and `flat_divide` may not survive inside `if warp_idx` blocks.** Construct partitioned tensors before warp branching, or in a regular Python helper function. (The MoE kernel calls `tma_partition` inside the epilogue warp's `if`, so this constraint may depend on context — print and verify.)
|
||||
|
||||
### Architecture (6-warp, current)
|
||||
9. **TMEM allocation must be a power of 2.** Round up after summing column requirements.
|
||||
|
||||
```
|
||||
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
|
||||
Warp 4: MMA (QK, PV)
|
||||
Warp 5: TMA (Q/K/V load)
|
||||
```
|
||||
10. **`composition` vs `logical_divide` produce different layouts** even when re-tiling the same tensor. `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. Copy atoms must match the tensor layout they were created with.
|
||||
|
||||
### TMEM Layout
|
||||
11. **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
|
||||
|
||||
```
|
||||
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
|
||||
Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
|
||||
Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
|
||||
```
|
||||
12. **`St32x32bOp` must use Float32, not BFloat16.** BFloat16 causes illegal memory access.
|
||||
|
||||
### Remaining for Multi-Tile Production
|
||||
13. **First PV must have `ACCUMULATE=False`.** Otherwise adds uninitialized TMEM contents to the output.
|
||||
|
||||
1. **Fix TMEM layout mismatch** — replace hand-constructed atom round-trips with correction_epilog pattern
|
||||
2. **Pipeline state cycling for n≥384** — kv_stage=2 can only buffer 2 tiles
|
||||
3. **12-warp layout** — separate softmax/correction/epilogue warps
|
||||
4. **O rescale for kt > 0** — must also use paired atoms or correction_epilog
|
||||
14. **`find_tmem_tensor_col_offset()` returns footprint size, not a safe offset.** Never use it as a TMEM placement.
|
||||
|
||||
15. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as `(hd, s_k)` MN-major inside CuTe via explicit `make_tensor` + `make_layout`.
|
||||
|
||||
---
|
||||
|
||||
## CuTeDSL Constraints (hard-won)
|
||||
## Lessons learned (the gold — read every session)
|
||||
|
||||
1. **`vectorize=True` loops: ONLY load/store/print** — no fmax, no cmpf, no inner loops, no carry
|
||||
2. **`.reduce(cute.ReductionOp.MAX)`:** reduces ENTIRE C-fragment to scalar — global max, not per-row
|
||||
3. **`cute.arch.fmax`:** impure for vectorizer — use plain `range()` loop
|
||||
4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)` — `(None,0,None,0)` keeps mode 2 (KV tiles) free, `[None, kt]` indexes it
|
||||
5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** — silently reads tile 0 forever. Use `(None,0,None,0)` instead.
|
||||
6. **`softmax_done_bar` NamedBarrier is reusable** across tiles
|
||||
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
|
||||
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
|
||||
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
|
||||
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
|
||||
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
|
||||
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000.
|
||||
13. **CuTeDSL MLIR backend cannot handle complex pipeline loops.** The MLIR→PTX optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python `range()` (unrolled) and `cutlass.range(unroll=1)` (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
|
||||
14. **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements. Use `const_expr(condition)` to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
|
||||
15. **External k_sub merge is mathematically impossible.** k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
|
||||
16. **`pv_n_tile` reduction is the easiest SMEM knob.** At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
|
||||
17. **O rescale TMEM round-trip with Ld32x32bOp/St32x32bOp is BROKEN.** Even a NO-OP round-trip (load O, multiply by 1.0, store back) corrupts data (cos 0.804 at s_k=256). The hand-constructed atoms don't preserve the C-fragment layout during round-trips. CUTLASS `correction_rescale` uses the same pattern — unclear why theirs works. **Workaround:** Python KV merge with per-segment LSE (cos 0.999998 for s_k up to 1024).
|
||||
18. **KV merge formula uses NORMALIZED outputs, not un-normalized.** The correct D5 merge for different token sets: `O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`. Using `O_i_unnorm` instead of `O_i_norm` gives cos ~0.91. The un-norm merge only works when both segments share the same `row_max` (global max), which isn't the case for separate KV segments.
|
||||
19. **`flat_divide` + `epilogue_tma_store` layout mismatch.** When using `cute.flat_divide` to create per-CTA GMEM views with runtime block coordinates (for multi-CTA grid), the resulting tensor layout is incompatible with CUTLASS's `epilogue_tma_store` pipeline, which expects the layout from `local_tile`. The tma_partition and epilogue must be refactored together to support multi-CTA grids.
|
||||
20. **`local_tile` does not support runtime coordinates.** `cute.local_tile(mQ, tiler, (runtime_val, None))` fails at trace time. Must use `cute.flat_divide(mQ, tiler)` instead, which creates a tiled view with all rest dimensions accessible via runtime indexing.
|
||||
21. **Sink bias domain correction.** Adding `attn_sink` directly to raw logits is wrong — it gets scaled by `scale_log2`. Fix: add `attn_sink / scale` to raw logits, so after `* scale_log2` it becomes `attn_sink * log2(e)`, correctly multiplying attention weights by `exp(attn_sink)`.
|
||||
22. **O normalization uses row_sum, NOT LSE.** `O_norm = O_unnorm / row_sum` is correct. `O_unnorm * exp(-LSE)` is WRONG because O_unnorm is max-shifted (divided by `2^row_max`), not raw `exp(S) @ V`. The kernel now outputs `row_sum` alongside LSE.
|
||||
23. **n_comp is compile-time, swa_len is runtime.** The `n_comp` parameter controls `const_expr` guards in the kernel and cannot vary between segments of the same kernel instance. `swa_len` is an `Int32` scalar and can vary per request. For multi-tile production, use a kernel cache keyed on `(n_comp, apply_sink_bias, head_dim, s_k)`.
|
||||
These cost real days to learn. They are listed in priority of how easy they are to repeat.
|
||||
|
||||
### Layout & TMA
|
||||
|
||||
- **TMA partition mode ordering** (the bug that ate a whole day): see CuTeDSL constraint #1 above. The wrong slice produces "reasonable" wrong outputs — cos 0.7–0.9, never NaN — so you can ship it without knowing.
|
||||
- **Square hides bugs.** (128,128) worked for every wrong approach to PV. Always test non-square shapes early.
|
||||
- **Print the shapes always.** Reasoning about TMEM layouts or TMA mode counts without running `cute.printf(cute.shape(t))` inside `@cute.kernel` is how every multi-day debug starts. Shapes are ground truth.
|
||||
- **`qk_mma_tiler` K-dim must equal `head_dim`**, not the MMA instruction's K sub-tile size. Hardcoding `qk_ik * 4 = 64` was the root cause of the hd>64 failure; the QK GEMM only computed half the dot product. Fix was one line; cos went from 0.78 to 0.999997 at hd=128.
|
||||
|
||||
### TMEM
|
||||
|
||||
- **Never assume TMEM round-trips are safe.** Verify with a NO-OP test (load → store unchanged) before adding any logic. The hand-constructed atoms produce ~3% error even on NO-OP.
|
||||
- **FMHA P store uses QK C-fragment composition, not PV A-fragment.** Two aliases of the same TMEM region. Mixing them up gives valid-looking garbage.
|
||||
- **Register bridge for P: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip the dual view.
|
||||
- **TMEM round-trip mismatch with `epilogue_tma_store`**: `epilogue_tma_store` reads O from TMEM using `get_tmem_load_op`'s layout. Hand-built atoms read with a different layout. Round-tripping through hand-built atoms transcodes the data, leaving 3% error.
|
||||
- **The correction-epilog pattern is the fix.** TMEM → registers (via paired t2r atom) → modify in registers → SMEM (via paired r2s atom) → GMEM (via TMA). One-way trip, no round-trip, no transcoding. The MoE kernel uses this and gets perfect results. See ROADMAP.
|
||||
|
||||
### CuTeDSL & MLIR
|
||||
|
||||
- **CuTeDSL `if` blocks create separate MLIR regions.** Variables defined in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` inside an `if warp_idx < mma_warp_id:` are not visible. Define unconditionally before any branching.
|
||||
- **CuTeDSL compiles both branches of Python `if`.** Wrap mode-specific dead code in `const_expr(condition)` to eliminate it. Critical for O rescale (`n_kv_tiles > 1`), LSE compute (`not normalize`), SMEM-P path.
|
||||
- **CuTeDSL MLIR backend cannot handle complex pipeline loops at hd=512.** Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger exponential-or-worse optimizer time. Tracer is fast (~0.8s); MLIR optimizer chews for 3+ hours. Workaround options in ROADMAP.
|
||||
- **Don't mix Python loops and pipeline ops.** Python `for` unrolls at trace time — N copies of pipeline acquire/release + TMA + GEMM blow up the IR. Prefer `cutlass.range(unroll=1)` for pipeline loops.
|
||||
|
||||
### Math & merging
|
||||
|
||||
- **External k_sub merge is mathematically impossible.** You cannot merge `softmax(Q_k0 @ K_k0^T) @ V` and `softmax(Q_k1 @ K_k1^T) @ V` into `softmax(Q @ K^T) @ V`. k_sub partitions are additive in **logit** space (`S = S_0 + S_1`); softmax is nonlinear. The D5 merge formula only works because sparse and SWA attend over **different token sets** (additive in weight space). In-kernel accumulation before softmax is the only correct approach for k_sub.
|
||||
- **D5 multi-tile KV merge IS valid.** Per-segment LSE + the formula `O = Σ exp(lse_i) · O_i / Σ exp(lse_i)` works because each segment is a separate softmax over a separate token range. This is the Python KV merge workaround that ships today; the in-kernel single-launch version requires the correction-epilog fix.
|
||||
- **Sink merge = single softmax over `[S_comp, S_swa + attn_sink]`.** The two-branch weighted merge formula in the paper is mathematically equivalent to adding `attn_sink` as a logit bias on the SWA positions and softmaxing once. One pass, one kernel. This obsoleted D5d.
|
||||
|
||||
### Numerics
|
||||
|
||||
- **Always test at hd=64 first.** If the proven TMEM-P path regresses, nothing else matters.
|
||||
- **`St32x32bOp` must be Float32**, not BFloat16. BFloat16 throws illegal memory access. (Yes, this is a CuTeDSL constraint — listing here because it's been forgotten more than once.)
|
||||
- **First PV `ACCUMULATE=False`.** Otherwise sums uninitialized TMEM into the output and you see ~50% error.
|
||||
|
||||
### Workflow
|
||||
|
||||
- **Never edit on the B200.** Edit locally, commit, push, pull, test. The B200 has no editor history; one bad save and the file is lost.
|
||||
- **Print shapes inside `@cute.kernel` at trace time.** `print(f"tBgK shape: {cute.shape(tBgK)}")` runs at compile time, not runtime, and is your only window into the JIT's view of layouts. This is the single most useful debugging line in CuTeDSL.
|
||||
|
||||
### SMEM budget
|
||||
|
||||
- **`pv_n_tile` is the easiest SMEM knob.** At hd > 256, reducing `pv_n_tile` from 256 to 128 halves sV and sC. Cost: 4 PV GEMM passes instead of 2 (PV is rarely the bottleneck). Simpler than SMEM overlap or Q tiling.
|
||||
- **`kv_stage` is the second-easiest.** Drop to 1 when budget gets tight at hd > 128; lose double-buffering on K/V but free 64+ KB.
|
||||
- **SMEM budget at various hd** (with `pv_n_tile=256` for hd≤256, `pv_n_tile=128` for hd>256, `kv_stage=2` for hd≤128 else 1):
|
||||
|
||||
| hd | sQ | sK | sV | sP | sC | Total | Limit |
|
||||
|---:|---:|---:|---:|---:|---:|------:|------:|
|
||||
| 64 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
|
||||
| 128 | 32 KB | 32 KB | 32 KB | — | 32 KB | 128 KB | 232 KB |
|
||||
| 256 | 64 KB | 64 KB | 64 KB | 0* | 32 KB | 224 KB | 232 KB |
|
||||
| 512 | 64 KB | 64 KB | 32 KB | 0* | 32 KB | 192 KB | 232 KB |
|
||||
|
||||
*TMEM-P path: sP allocation skipped via `const_expr` conditional.
|
||||
|
||||
---
|
||||
|
||||
## Key Lessons
|
||||
## Reference
|
||||
|
||||
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
|
||||
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
|
||||
3. **TMEM allocation must be power of 2.**
|
||||
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
|
||||
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
|
||||
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
|
||||
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
|
||||
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
|
||||
9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
|
||||
10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.
|
||||
|
||||
---
|
||||
|
||||
## Stage D: Full Decode Attention (revised May 23)
|
||||
|
||||
### Key Insight: The Indexer Solves Paging Upstream
|
||||
|
||||
The indexer now hands the kernel `selected_kv: [T, top_k, head_dim] BF16` — a **dense, materialized, dequantized** K/V tile. FMHA sees a dense `[T, top_k, 512]` tile, exactly like Stage A/B's existing `k` and `v` inputs. **The kernel doesn't need to know it's sparse.** Paged TMA, scattered HBM reads, FP8 dequantization — all handled by `gather_selected_kv` upstream.
|
||||
|
||||
The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (`n_win=128` per query), so it's a separate fused branch with a sink-weighted merge.
|
||||
|
||||
**One FMHA kernel serves all three DSV4 attention types:**
|
||||
- **CSA:** `compressed_kv` = top-k from indexer, `swa_kv` from cache → sink merge
|
||||
- **HCA:** `compressed_kv` = all classical pool entries (gather-all mode), `swa_kv` from cache → sink merge
|
||||
- **SWA-only (Flash layers 0-1):** `compressed_kv` = empty (`top_k=0`), only SWA runs. Sink merge degenerates to just `o_swa` after renormalization.
|
||||
|
||||
### Build Order
|
||||
|
||||
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, MOSTLY DONE)
|
||||
|
||||
Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
|
||||
|
||||
hd≤256: ✅ DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.
|
||||
|
||||
hd=512: ❌ BLOCKED. SMEM budget fixed (192KB, fits 232KB limit). Kernel structurally correct (tracer 0.8s). But CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments additive in logit space, not weight space). Need either: (a) pre-compile offline + cache cubin, (b) add no-softmax mode for S accumulation in Python, or (c) write hd=512 path in raw CUDA C++.
|
||||
|
||||
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
|
||||
|
||||
**D2 — Multi-query grid with head packing** (~1 day)
|
||||
|
||||
Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`. DSV4 is MQA — all `n_h=128` query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: `M_tile = 128` covers `M = T * n_h` rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.
|
||||
|
||||
Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.
|
||||
|
||||
**D3 — SWA sequence length mask** (~½ day)
|
||||
|
||||
The indexer's `top_k` is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always `[T, top_k, head_dim]` with the same `top_k` at compile time.
|
||||
|
||||
What varies: the SWA window holds up to `n_win=128` tokens but starts with fewer. Add `swa_lens: [batch] int32` as kernel input. Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`.
|
||||
|
||||
Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.
|
||||
|
||||
**D4 — Causal mask on SWA branch** (~½ day)
|
||||
|
||||
The compressed K the indexer selects is already from `s < floor(t/m)` (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. **The main path has no mask.**
|
||||
|
||||
The SWA branch needs a causal mask within the window. Add `is_causal: bool` constructor flag, apply `swa_idx > q_pos` masking to `-inf` in the SWA pass.
|
||||
|
||||
Done when: prefill mode produces correct output with the causal mask applied to SWA.
|
||||
|
||||
**D5 — SWA + sink merge** ✅ COMPLETE (May 26)
|
||||
|
||||
Per `dsv4/ops/decode_sparse.py`:
|
||||
```
|
||||
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
|
||||
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
||||
```
|
||||
|
||||
**Key insight (May 26):** This merge is mathematically identical to a single attention pass over concatenated KV with a logit bias on SWA positions:
|
||||
```
|
||||
S = [S_comp, S_swa + attn_sink]
|
||||
O = softmax(S) @ [V_comp; V_swa]
|
||||
```
|
||||
This means D5c is a **logit bias addition**, not a two-pass + merge kernel. D5d (fused in-kernel merge epilogue) is NO LONGER NEEDED.
|
||||
|
||||
**D5a ✅:** `normalize` flag + LSE + row_sums output. When False, emits un-normalized O + LSE + row_sum.
|
||||
|
||||
**D5b ✅:** Per-row LSE output (all 128 rows now write). Python KV merge with per-row LSE: cos 0.999994.
|
||||
|
||||
**D5c ✅:** Sink bias as logit modification. Parameters: `n_comp` (compressed KV length, compile-time), `apply_sink_bias` (compile-time flag), `sink_bias` (runtime FP32 tensor). Sink bias added to raw logits as `attn_sink / scale` so after `* scale_log2` it correctly becomes `attn_sink * log2(e)` in the exp2 domain. Multi-tile via Python KV merge: cos 0.999996.
|
||||
|
||||
**D5d:** NOT NEEDED. The sink bias approach makes a fused merge epilogue unnecessary.
|
||||
|
||||
Done when: ✅ End-to-end kernel produces correct attention against FP32 oracle.
|
||||
|
||||
**~~D5 (old) paged TMA~~ — REMOVED.** The indexer + gather handles all paging upstream.
|
||||
|
||||
### Kernel Architecture (after D5 — COMPLETE)
|
||||
|
||||
```
|
||||
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
|
||||
swa_lens [batch], sink_logits [n_h], request_ids [T]
|
||||
│
|
||||
└─ Single pass over concatenated KV [compressed_kv; swa_kv]:
|
||||
QK → online softmax (with sink bias on SWA, D3/D4 masking) → PV
|
||||
→ O_unnorm + LSE + row_sum
|
||||
→ External normalize: O = O_unnorm / row_sum
|
||||
→ (Multi-tile: Python KV merge across 128-token segments)
|
||||
```
|
||||
|
||||
### Reference Files
|
||||
|
||||
- Sink merge spec: `dsv4/ops/decode_sparse.py` (formula)
|
||||
- SWA decode: `dsv4/ops/decode_swa.py`
|
||||
- Attention reference: `dsv4/reference/attention.py`
|
||||
- CSA attention: `dsv4/reference/csa_attention.py`
|
||||
|
||||
### Stage C Note
|
||||
|
||||
When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.
|
||||
|
||||
---
|
||||
|
||||
## Stage E: Production Extraction (revised May 23)
|
||||
|
||||
### E1 — File placement
|
||||
|
||||
`dsv4/kernels/attention/fmha.py`. Currently contains `FmhaKernel` (migrated from test, hd=64 TMEM-P). Will gain parameterized `head_dim` and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.
|
||||
|
||||
### E2 — Constructor signature
|
||||
|
||||
```python
|
||||
class FmhaKernel:
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int, # 512 for DSV4
|
||||
num_query_heads: int, # 128 for Pro, 64 for Flash
|
||||
sliding_window: int, # 128
|
||||
top_k: int, # 512 (Flash) or 1024 (Pro)
|
||||
q_dtype=BFloat16,
|
||||
kv_dtype=BFloat16,
|
||||
o_dtype=BFloat16,
|
||||
qk_acc_dtype=Float32,
|
||||
pv_acc_dtype=Float32,
|
||||
is_causal: bool = False, # affects SWA mask only
|
||||
cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
|
||||
cluster_shape_mn: tuple = (1, 1),
|
||||
):
|
||||
```
|
||||
|
||||
All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.
|
||||
|
||||
### E3 — Call signature
|
||||
|
||||
```python
|
||||
def __call__(
|
||||
self,
|
||||
q: torch.Tensor, # [T, n_h, head_dim] BF16
|
||||
compressed_kv: torch.Tensor, # [T, top_k, head_dim] BF16 — from indexer gather
|
||||
swa_kv: torch.Tensor, # [batch, n_win, head_dim] BF16 — from cache prep
|
||||
swa_lens: torch.Tensor, # [batch] int32
|
||||
sink_logits: torch.Tensor, # [n_h] FP32
|
||||
request_ids: torch.Tensor, # [T] int32 — maps query to its SWA slot
|
||||
o: torch.Tensor, # [T, n_h, head_dim] BF16 — preallocated
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
```
|
||||
|
||||
Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.
|
||||
|
||||
### E4 — Kernel cache + warmup
|
||||
|
||||
Mirror `dsv4/ops/gemm_runner.py`'s `_compiled_kernel_cache`. Key on `(head_dim, num_query_heads, top_k, is_causal, ...)`. Pre-allocate at warmup, reuse at call. For DSV4, the cache has at most ~2 entries (Flash/Pro × causal/non).
|
||||
|
||||
### E5 — torch.library custom op
|
||||
|
||||
```python
|
||||
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
|
||||
def sparse_fmha_with_swa_op(
|
||||
q: torch.Tensor,
|
||||
compressed_kv: torch.Tensor,
|
||||
swa_kv: torch.Tensor,
|
||||
swa_lens: torch.Tensor,
|
||||
sink_logits: torch.Tensor,
|
||||
request_ids: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
runner_id: int,
|
||||
) -> None:
|
||||
runner = get_runner(runner_id)
|
||||
runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)
|
||||
```
|
||||
|
||||
Mutates `o` (preallocated buffer). Consistent with cudagraphs.
|
||||
|
||||
### E6 — Reference parity hook
|
||||
|
||||
`dsv4/reference/attention.py` stays as the FP32 oracle. New test: `tests/unit/test_fmha_kernel.py`.
|
||||
|
||||
```python
|
||||
def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
|
||||
q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
|
||||
ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
swa = torch.randn(4, n_win, hd, dtype=torch.bf16, device='cuda')
|
||||
swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32)
|
||||
sink = torch.randn(n_h, device='cuda')
|
||||
req_ids = torch.randint(0, 4, (T,), dtype=torch.int32)
|
||||
|
||||
# Oracle: pure FP32 spec
|
||||
o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
|
||||
o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
|
||||
e_sink = sink.exp()
|
||||
num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
|
||||
+ e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
|
||||
den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
|
||||
expected = num / den.unsqueeze(-1)
|
||||
|
||||
# Kernel
|
||||
o = torch.empty_like(expected, dtype=torch.bfloat16)
|
||||
fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
|
||||
fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)
|
||||
|
||||
torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)
|
||||
```
|
||||
|
||||
### E7 — Cleanup
|
||||
|
||||
Delete all debug test files. `test_fmha_v3.py` becomes `dsv4/kernels/attention/fmha.py`. Only `tests/unit/test_fmha_kernel.py` remains as the attention test.
|
||||
|
||||
---
|
||||
|
||||
## Environment
|
||||
|
||||
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
|
||||
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
|
||||
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
|
||||
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
|
||||
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
|
||||
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
|
||||
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`
|
||||
- **DeepSeek V4 paper**: `DeepSeek_V4.pdf` in the repo root.
|
||||
- **DeepGEMM** (V4-aligned reference kernels): https://github.com/deepseek-ai/DeepGEMM
|
||||
- **CUTLASS FMHA reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (B200) or `/home/openclaw/dev/cutlass` (local).
|
||||
- **Reference oracles**: `dsv4/reference/` (PyTorch FP32 — slow, never imported by production code).
|
||||
275
ROADMAP.md
Normal file
275
ROADMAP.md
Normal file
@@ -0,0 +1,275 @@
|
||||
# ROADMAP
|
||||
|
||||
Living document. Current state, active blockers, priority order, and what to build next. Architecture and lessons live in README.md — this file is for "what now."
|
||||
|
||||
**Last updated:** 2026-05-26
|
||||
|
||||
---
|
||||
|
||||
## Current status
|
||||
|
||||
### Working
|
||||
|
||||
| Component | hd | n | cos | Status |
|
||||
|---|---:|---:|---:|---|
|
||||
| FMHA TMEM-P | 64 | 128 | 0.999998 | ✅ |
|
||||
| FMHA TMEM-P / SMEM-P | 128 | 128 | 0.999997 | ✅ |
|
||||
| FMHA TMEM-P | 256 | 128 | 0.999998 | ✅ |
|
||||
| FMHA multi-tile (Python KV merge) | 64 | up to 1024 | 0.999998 | ✅ Workaround |
|
||||
| D3 SWA length mask (in-kernel) | 128 | 128 | 0.999996 | ✅ |
|
||||
| D4 causal mask on SWA (in-kernel) | 128 | 128 | 0.999996 | ✅ |
|
||||
| D5c sink merge single-tile | 64 | 128 | 0.999996 | ✅ |
|
||||
| D5c sink merge multi-tile (Python KV merge) | 64 | 256 | 0.999996 | ✅ |
|
||||
| Per-head multi-head launch | 64 | 128 | 0.999995 | ✅ n_h=1–128 |
|
||||
| MoE fused SwiGLU (NVFP4) | — | — | matches ref | ✅ Clamping in kernel |
|
||||
| Dense router (sqrt-softplus) | — | — | matches ref | ✅ |
|
||||
| Hash router | — | — | matches ref | ✅ |
|
||||
| `use_2cta_instrs` conditional | — | — | 1.7–1.9× speedup | ✅ M≥256 prefill |
|
||||
| NVFP4 primitives | — | — | E4M3 SF / mxf4nvf4 / 16-elem | ✅ Verified |
|
||||
|
||||
### Known blockers
|
||||
|
||||
| Blocker | Impact | Workaround | Fix path |
|
||||
|---|---|---|---|
|
||||
| **D1.5 TMEM round-trip corruption** | Hand-built atoms produce 3% error on NO-OP round-trip; blocks in-kernel multi-tile O rescale and in-kernel normalize | Emit un-normalized O + LSE; Python KV merge for s_k>128 | **Priority 1: correction-epilog rewrite** (sketch below) |
|
||||
| hd=512 MLIR backend hang | Cannot compile hd=512 kernel (>3hr optimizer time, structurally correct) | Run hd=512 via head-packed M with hd≤256 chunks; ship without hd=512 if needed for D2 | Pre-compile cubin / raw CUTLASS C++ / report NVIDIA bug |
|
||||
| D2 multi-CTA grid (flat_divide + epilogue_tma_store) | Per-head Python launch wastes 128 launches per decode step at Pro | Per-head launch (works, just slow) | Unblocked by correction-epilog rewrite (uses `flat_divide` + `tma_partition` like MoE does) |
|
||||
|
||||
---
|
||||
|
||||
## Priority 1: Correction epilog rewrite (unblocks D1.5 + a chain of follow-ons)
|
||||
|
||||
**Why this first:** Every downstream item needs the kernel to have a register-level slot in the epilogue for modification. The current `epilogue_tma_store` path with hand-built atoms doesn't have one. The correction-epilog pattern does.
|
||||
|
||||
**The pattern is already in the codebase** — `dsv4/kernels/gemm/fused_swiglu.py` uses it for the MoE SwiGLU epilogue (lines 2021, 2064–2229). Library helpers, paired atoms, one-way TMEM → registers → SMEM → GMEM. SwiGLU + clamping math sits between the t2r and r2s copies. That's the exact slot FMHA needs.
|
||||
|
||||
**What changes:**
|
||||
|
||||
Replace the FMHA epilogue (`dsv4/kernels/attention/fmha.py` lines 549–597 — the `epilogue_tma_store` call) with:
|
||||
|
||||
1. Setup (run once per kernel, outside the kt loop):
|
||||
- `tCtO_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tOtO0)`
|
||||
- `tCgC_transformed = utils.gemm.sm100.transform_partitioned_tensor_layout(tCgC)`
|
||||
- `tiled_copy_t2r, tTR_tO_base, tTR_rO = utils.gemm.sm100.epilogue_tmem_copy_and_partition(...)`
|
||||
- `tiled_copy_r2s, tRS_rC, tRS_sC = utils.gemm.sm100.epilogue_smem_copy_and_partition(...)`
|
||||
- TMA partition for C via `flat_divide(tCgC_transformed, epi_tile)` + `cpasync.tma_partition(...)`
|
||||
|
||||
2. Final epilogue (replaces the round-trip normalize):
|
||||
- Subtile loop, in each subtile: `cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)` → multiply by `inv_row_sum` in registers if `normalize=True` → cast to BF16 → `cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[...])` → TMA SMEM → GMEM.
|
||||
|
||||
3. Per-kt O rescale (replaces the broken hand-built round-trip on lines 524–544):
|
||||
- Inside the kt loop, when `kt > 0`: same t2r → multiply by `acc_scale` in registers → store back via the paired atom (`tiled_copy_t2r.retile_to_S()` or equivalent — verify exact API on B200).
|
||||
|
||||
**What unblocks:**
|
||||
- D1.5 issue 1 (round-trip corruption): gone.
|
||||
- D1.5 issue 2 (per-kt rescale): gone.
|
||||
- In-kernel multi-tile attention (single launch for s_k=1152, not 9).
|
||||
- NVFP4-1.2 (fuse FP4 quant into FMHA output → wo_a path): the register slot is where amax + FP4 pack go.
|
||||
- D2 multi-CTA grid: `flat_divide` + `tma_partition` path is the same one MoE uses successfully. The flat_divide vs local_tile mismatch resolves.
|
||||
|
||||
**Caveats to print and verify on B200:**
|
||||
- Exact CUTLASS helper API for the "store back to TMEM" direction (`retile_to_S` form vs separate helper vs same-base-tensor pattern).
|
||||
- Whether `transform_partitioned_tensor_layout` accepts `tOtO0` (TMEM iterator with offset) or needs a fresh tensor built at `tmem_ptr + self.tmem_o0_offset` with `tCtO_fake.layout`.
|
||||
- Whether `tma_partition` inside `if warp_idx < self.mma_warp_id` works in this kernel's region tree. The MoE kernel does it; if FMHA hits "weakly congruent," hoist the partition call above the warp branch.
|
||||
|
||||
**Done when:**
|
||||
- hd=64/128/256 regression cos ≥ 0.999998 holds with `normalize=True` and `normalize=False`.
|
||||
- New multi-tile s_k=256 test with `kt > 0` rescale gives cos ≥ 0.999998 (not the current 0.997 Python-merge workaround, the real in-kernel rescale).
|
||||
- Existing Python KV merge tests continue to pass (`test_d15_multi_kv.py`).
|
||||
|
||||
---
|
||||
|
||||
## Priority 2: Stage E — Production extraction
|
||||
|
||||
D5 is complete. The kernel works. Wrap it in a proper interface.
|
||||
|
||||
| Step | What | Status |
|
||||
|---|---|---|
|
||||
| E1 | File placement: `dsv4/kernels/attention/fmha.py` | ✅ Done |
|
||||
| E2 | Constructor signature (`head_dim`, `num_query_heads`, `sliding_window`, `top_k`, sink/causal flags, dtypes) | ⚠️ Partial — needs cleanup |
|
||||
| E3 | Call signature: `q`, `compressed_kv`, `swa_kv`, `swa_lens`, `sink_logits`, `request_ids`, `o`, `stream` | ⚠️ Needs sink_bias / row_sums integration |
|
||||
| E4 | Kernel cache + warmup, keyed on `(head_dim, num_query_heads, top_k, n_comp, apply_sink_bias, is_causal, ...)` | TODO |
|
||||
| E5 | `torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))` | TODO |
|
||||
| E6 | Reference parity test against FP32 oracle in `dsv4/reference/attention.py` | TODO |
|
||||
| E7 | Cleanup: delete debug test files, keep only `tests/unit/test_fmha_kernel.py` | TODO |
|
||||
|
||||
Notably absent from the call signature: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream by the indexer + gather kernel chain. FMHA sees a dense BF16 `[T, top_k, head_dim]` tile.
|
||||
|
||||
---
|
||||
|
||||
## Priority 3: NVFP4-1.1 — Fuse FP4 quant into MoE SwiGLU epilogue
|
||||
|
||||
**Independent of FMHA.** Biggest bandwidth win in the codebase. Can run in parallel with Priority 1.
|
||||
|
||||
Current:
|
||||
```
|
||||
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM
|
||||
↓
|
||||
quantize_activation_nvfp4 (separate kernel)
|
||||
↓
|
||||
padded_activated_fp4 → L2 GEMM
|
||||
```
|
||||
|
||||
Target:
|
||||
```
|
||||
padded_x_fp4 → L1 GEMM → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM
|
||||
```
|
||||
|
||||
The SwiGLU + clamp result already lives in registers at `tRS_rC.store(acc_vec_bf16)` (line 2207 of `fused_swiglu.py`). That's the slot for amax + FP4 pack.
|
||||
|
||||
**Per-microblock amax (16 contiguous elements):**
|
||||
1. shfl_xor butterfly reduction across the 4 threads that hold the 16 elements.
|
||||
2. FP8 E4M3 scale = amax / 6 (FP4 e2m1 max).
|
||||
3. Per-element FP4 pack: sign bit << 3 | (clamped val / scale).to(uint3). Two elements → one byte.
|
||||
4. 16 packed nibbles → 64-bit word → SMEM stage → TMA store.
|
||||
5. FP8 scale → separate scale-factor SMEM stage → TMA store to the L2 SFA buffer.
|
||||
|
||||
**Subtlety:** NVFP4 microblock = 16 elements. Port the same 16-element logic from `dsv4/ops/quantize.py`. Don't accidentally use the 32-element MXFP4 block.
|
||||
|
||||
**Done when:**
|
||||
- `padded_activated_fp4` and `padded_activated_x_sf` scratch buffers go away.
|
||||
- `quantize_activation_nvfp4` between L1 and L2 disappears.
|
||||
- L1 → L2 cosine matches reference (no regression from BF16 intermediate).
|
||||
- L2 GEMM reads FP4 scales produced by L1 epilogue.
|
||||
|
||||
---
|
||||
|
||||
## Priority 4: D2 multi-CTA grid
|
||||
|
||||
Currently per-head Python launch (works, cos 0.999995, but 128 launches per decode step at Pro).
|
||||
|
||||
Multi-CTA grid is unblocked by Priority 1 — the `flat_divide` + `tma_partition` path becomes available once the epilogue uses the MoE pattern.
|
||||
|
||||
**Grid:** `(num_M_tiles, num_query_heads, batch)` — at decode T=1: `(1, 128, batch)`.
|
||||
|
||||
**MQA K/V sharing:** start with independent K/V loads per CTA (each CTA loads its own copy). At decode hd=512, K/V per CTA is ~128 KB; 128 CTAs × 128 KB = 16 MB, well within HBM bandwidth. Cluster-wide sharing via `cluster_shape_mn=(1, num_query_heads, 1)` is a future optimization once profiling shows it matters.
|
||||
|
||||
**Q tensor layout:** Option 1 — `(batch, n_h, T, head_dim)` with head as a TMA mode (matches CUTLASS reference and allows per-head LSE output). Picked over Option 2 (heads packed into M) because it generalizes better to GQA later.
|
||||
|
||||
**Done when:**
|
||||
- `n_h=128, batch=4, T=1` at hd=512 produces correct output with single launch.
|
||||
- Per-head LSE writes to correct `mLSE[batch, head, m_row]` position.
|
||||
|
||||
---
|
||||
|
||||
## Priority 5: NVFP4-1.2 — Fuse FP4 quant into FMHA output → wo_a path
|
||||
|
||||
**Depends on Priority 1** (correction epilog gives the register slot).
|
||||
|
||||
Currently: FMHA emits BF16 → inverse RoPE produces BF16 → wo_a quantizes to FP4.
|
||||
|
||||
Target: register slot in FMHA epilogue does the divide-by-row_sum *and* inverse RoPE rotation *and* per-microblock amax + FP4 pack. wo_a reads FP4 directly.
|
||||
|
||||
Same pattern as Priority 3. Different home (FMHA epilogue, not MoE epilogue).
|
||||
|
||||
---
|
||||
|
||||
## Priority 6: NVFP4-2 — FP4 KV pipeline depth in FMHA
|
||||
|
||||
**Depends on Priority 1** being solid at BF16 KV first.
|
||||
|
||||
FP4 KV shrinks tiles ~4×; same SMEM budget supports more pipeline stages.
|
||||
|
||||
| KV dtype | Tile size (hd=512) | Stages that fit (192 KB budget) |
|
||||
|---|---:|---:|
|
||||
| BF16 | 128 KB | 2 |
|
||||
| FP8 | 64 KB | 4 |
|
||||
| FP4 | ~36 KB | 6 |
|
||||
|
||||
At 1M-context decode where KV reads dominate, deeper pipelines hide more TMA latency.
|
||||
|
||||
**Implementation:**
|
||||
- TMA loads FP4 NoPE dims (packed `e2m1_x2`) to SMEM slot 0.
|
||||
- TMA loads BF16 RoPE dims to SMEM slot 1.
|
||||
- TMA loads FP8 scale factors to SMEM slot 2.
|
||||
- SMEM dequant FP4 → BF16 in vectorized form (`* FP8_scale`, 16-element microblocks).
|
||||
- Concatenate `[NoPE, RoPE]` in SMEM.
|
||||
- MMA reads contiguous BF16 from SMEM.
|
||||
|
||||
**Test:** FP4+BF16 split input → identical output to pure BF16 input (dequant must be transparent).
|
||||
|
||||
---
|
||||
|
||||
## Priority 7: hd=512 fix
|
||||
|
||||
**Blocked.** Per Priority 4, multi-CTA grid + head-packed M means decode at hd=512 can route through `pv_n_tile=128` and `n_k_sub_tiles=2`, which compiles fine for hd=256. The hd=512 *single-kernel* compile is the missing piece for prefill efficiency, not correctness.
|
||||
|
||||
**Options:**
|
||||
1. Pre-compile hd=512 cubin offline (accept 1–2 hour compile if MLIR ever finishes — uncertain).
|
||||
2. Add no-softmax mode emitting raw S to GMEM, call twice for k_sub=0/1, accumulate in Python, softmax once. Two launches but no MLIR hang.
|
||||
3. Write hd=512 path in raw CUTLASS C++. Bypasses CuTeDSL MLIR entirely. Most realistic if NVIDIA can't fix the optimizer.
|
||||
4. Report CuTeDSL MLIR optimizer bug to NVIDIA.
|
||||
|
||||
Lower priority than the chain above — at decode T=1, n_h=128, hd=512 the head-packed approach already works without needing a single hd=512 kernel.
|
||||
|
||||
---
|
||||
|
||||
## Priority 8: Indexer FP4 tensor-core scoring (Stage F)
|
||||
|
||||
Paper §5.2.1: *"the QK path in the indexer of CSA, where QK activations are cached, loaded, and multiplied entirely in FP4."*
|
||||
|
||||
Current indexer (`dsv4/kernels/cuda/indexer_score_topk.cu`): scalar FP32 dot products, no tensor cores, spinlock-protected shared-memory heap. Single largest perf gap in the codebase. At 1M-context decode the indexer scores ~250K compressed entries per query token — the spinlock heap will not scale to top_k=1024.
|
||||
|
||||
**Target:** port DeepGEMM `fp8_paged_mqa_logits` to FP4 inputs with `tcgen05.mma.kind=mxf4nvf4`. Plus per-warp partial top-k merged with a final reduction tree (or radix-select). Plus FP32→BF16 score quantization per paper (2× speedup on top-k selector, 99.7% recall).
|
||||
|
||||
**Scope:** 2–3 weeks. Track for Stage F. Do not start until the FP4 epilogue patterns from Priorities 3 and 5 are established — they'll inform the indexer's FP4 load + score paths.
|
||||
|
||||
---
|
||||
|
||||
## Build order — recommended sequencing
|
||||
|
||||
```
|
||||
Now ─┬─ Priority 1 (correction epilog rewrite)
|
||||
│ │
|
||||
│ └─→ unblocks D1.5, D2 multi-CTA, NVFP4-1.2
|
||||
│
|
||||
├─ Priority 3 (NVFP4-1.1 fuse FP4 in SwiGLU) ← parallel, independent
|
||||
│
|
||||
↓
|
||||
Verify hd=64/128/256 regressions hold
|
||||
│
|
||||
↓
|
||||
Priority 2 (Stage E production extraction)
|
||||
│
|
||||
↓
|
||||
Priority 4 (D2 multi-CTA grid)
|
||||
│
|
||||
↓
|
||||
Priority 5 (NVFP4-1.2 fuse FP4 in FMHA output)
|
||||
│
|
||||
↓
|
||||
Priority 6 (NVFP4-2 FP4 KV pipeline)
|
||||
│
|
||||
↓
|
||||
Priority 7 (hd=512 fix — only if prefill efficiency demands it)
|
||||
│
|
||||
↓
|
||||
Priority 8 (indexer FP4 tensor-core scoring) — Stage F
|
||||
```
|
||||
|
||||
Priority 3 has no dependency on Priorities 1 or 2 and can run on a parallel branch.
|
||||
|
||||
---
|
||||
|
||||
## Speculative — beyond what the V4 paper validated
|
||||
|
||||
Listed for completeness. **Do not implement without explicit sign-off.**
|
||||
|
||||
1. **NVFP4 compressed KV NOPE dims** (paper validated FP8 for compressed KV; FP4 would halve cache again). Risk: compounds quantization noise on already-lossy compressed KV.
|
||||
2. **MXFP4 vs NVFP4 for indexer scoring** — not validated for indexer specifically.
|
||||
3. **NVFP4 for full attention Q×K^T GEMM** — closed. Cos 0.86 vs FP32 in earlier tests. Attention stays BF16/FP32.
|
||||
4. **Per-token FP8 activation scaling in FMHA** — not validated. Out of scope.
|
||||
5. **2:4 structured sparsity on FP4 expert weights** — V4 not trained with structured sparsity. Off the table for the released checkpoint.
|
||||
6. **NVFP4 LM head + MTP head** — big VRAM win (~1.4 GB saved on Pro). Modest quality risk on rare-token logits. Test against held-out eval before shipping.
|
||||
|
||||
---
|
||||
|
||||
## Key numbers to remember
|
||||
|
||||
| Config | n_h | top_k | s_k decode | n_kv_tiles | Multi-tile? |
|
||||
|---|---:|---:|---:|---:|:---|
|
||||
| Flash decode | 64 | 512 | 640 | 5 | YES |
|
||||
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
|
||||
| Current single-tile test | 1 | — | 128 | 1 | NO |
|
||||
|
||||
Production decode needs the multi-tile path (Priority 1) working in-kernel. Today's Python KV merge ships correct results at the cost of 5–9 launches per step.
|
||||
636
archived_plans/README.md
Normal file
636
archived_plans/README.md
Normal file
@@ -0,0 +1,636 @@
|
||||
# DSV4 Inference Kernel
|
||||
|
||||
## ⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
|
||||
|
||||
**THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.**
|
||||
|
||||
After `cpasync.tma_partition()`, the output GMEM tensor has **4 modes** (verified on B200):
|
||||
|
||||
```
|
||||
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
|
||||
mode 0 1 2 3
|
||||
```
|
||||
|
||||
**Mode 2 is the GMEM tile dimension.** The dimension you index with `kt` to load different K/V tiles.
|
||||
|
||||
### THE WRONG WAY (what we did — silently loads from tile 0 forever):
|
||||
|
||||
```python
|
||||
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
|
||||
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
|
||||
# TMA ALWAYS reads from tile 0.
|
||||
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
|
||||
|
||||
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
|
||||
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
|
||||
```
|
||||
|
||||
### THE RIGHT WAY (verified on B200 at n=128 and n=256):
|
||||
|
||||
```python
|
||||
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
|
||||
# Mode 2 (KV tiles) survives as the second mode.
|
||||
tBgK = tBgK[(None, 0, None, 0)]
|
||||
|
||||
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
|
||||
cute.copy(tma_k, tBgK[None, kt], ...)
|
||||
# ^^ THIS IS THE KV TILE DIM
|
||||
```
|
||||
|
||||
**Verified shapes on B200 (May 22, n=256, inside @cute.kernel):**
|
||||
```
|
||||
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
|
||||
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
|
||||
```
|
||||
|
||||
### WHY THIS IS SO INSIDIOUS
|
||||
|
||||
1. **No error, no warning.** The slice `tBgK[(None,None,0,0)]` silently sets mode 2 to 0.
|
||||
2. **Single-tile (n=128) works perfectly.** With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
|
||||
3. **Multi-tile tests produce "reasonable" output.** The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
|
||||
4. **The strides are all 0.** Printing `tBgK.layout.stride` shows all zeros for TMA tensors. You can't detect the bug from strides alone.
|
||||
5. **`cute.printf` shows `kv_coord=0`.** We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
|
||||
6. **The 8-mode theory was wrong.** We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
|
||||
|
||||
### THE LESSON
|
||||
|
||||
**PRINT THE SHAPES. ALWAYS.** Run `print(f"tBgK: shape={cute.shape(tBgK)}")` inside `@cute.kernel` at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.
|
||||
|
||||
**The correct pre-slice depends on which mode is the GMEM tile iteration axis.** For our `local_tile` + `partition_B` + `group_modes(0,3)` pattern, mode 2 is the KV tile axis. `(None,0,None,0)` keeps it free. `(None,None,0,0)` collapses it to 0.
|
||||
|
||||
```python
|
||||
# ALWAYS verify the shape at trace time:
|
||||
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
|
||||
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
|
||||
|
||||
# Then index the 2D tensor:
|
||||
cute.copy(tma_k, tBgK[None, kt], ...)
|
||||
```
|
||||
|
||||
**IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.**
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
DSV4 is **not MLA**. It uses **CSA (Compressed Sparse Attention, m=4)** and **HCA (Heavily Compressed Attention, m′=128)**. KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.
|
||||
|
||||
```
|
||||
DSV4 inference pipeline — component status
|
||||
==========================================
|
||||
|
||||
Legend:
|
||||
[✓] built and tested
|
||||
[~] partial — reference or seam exists, native pending
|
||||
[✗] to build
|
||||
|
||||
|
||||
┌────────────────────────────────────┐
|
||||
│ [✗] Embedding + mHC init │
|
||||
│ token embed + n_hc=4 streams │
|
||||
└────────────────┬───────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─ Transformer layer × L ──────────────────────────────────────────────┐
|
||||
│ HCA on layers 0–1 of Pro, alternating CSA / HCA after │
|
||||
│ │
|
||||
│ ┌─ Attention sub-block ──────────────────────────────────────────┐ │
|
||||
│ │ [✓] Residual mHC pre + post mix │ │
|
||||
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
|
||||
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
|
||||
│ │ [✓] Token compressor CSA m=4 / HCA m′=128 │ │
|
||||
│ │ [✓] Indexer + top-k CSA, FP32 dot + top-k │ │
|
||||
│ │ [~] FMHA core QK → online softmax → PV │ │
|
||||
│ │ + SWA branch + sink merge │ │
|
||||
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
|
||||
│ └────────────────────────────────────────────────────────────────┘ │
|
||||
│ │
|
||||
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
|
||||
│ │ [✓] Residual mHC pre + post mix │ │
|
||||
│ │ [✓] Pre-FFN norm RMSNorm │ │
|
||||
│ │ [✓] Router sqrt(softplus) + topk + hash │ │
|
||||
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
|
||||
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
|
||||
│ └────────────────────────────────────────────────────────────────┘ │
|
||||
└──────────────────────────────────┬───────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌──────────────────────────────────────────────────────────────────────┐
|
||||
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
┌─ Supporting infrastructure ──────────────────────────────────────────┐
|
||||
│ [✗] KV cache management │
|
||||
│ • state cache: SWA window + uncompressed tail per layer │
|
||||
│ • classical paged cache: lcm(m, m′) = 128 tokens per block │
|
||||
│ • heterogeneous layout per layer │
|
||||
└──────────────────────────────────────────────────────────────────────┘
|
||||
|
||||
|
||||
Summary
|
||||
-------
|
||||
Built [✓] : 9 — mHC ×2, Q/KV proj, output proj, routed MoE,
|
||||
shared expert, token compressor, indexer+topk,
|
||||
router, pre-FFN norm
|
||||
Partial [~] : 3 — norms+RoPE, FMHA core
|
||||
To build [✗] : 6 — embedding+init, final norm, LM head, MTP, sampler, KV cache
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Status (May 26, 2026 — 18:40 UTC)
|
||||
|
||||
| Stage | Status | Description |
|
||||
|-------|--------|-------------|
|
||||
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
|
||||
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
|
||||
| C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. |
|
||||
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. O rescale TMEM round-trip BROKEN (Ld32x32bOp/St32x32bOp corrupt data). Python KV merge workaround works. |
|
||||
| D1.5 | ❌ BLOCKER | O rescale for multi-KV-tile (kt>0). TMEM round-trip corruption (even NO-OP round-trip fails). Python KV merge workaround: cos 0.999994. Production: 5-9 kernel launches per decode. Fix requires correction epilog (one-way TMEM→regs→SMEM→GMEM). |
|
||||
| D2 | 🟡 Per-head DONE | Head-packed M-dimension launch (cos 0.999995, n_h=1-128). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. |
|
||||
| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar, offset by n_comp for D5c) |
|
||||
| D4 | ✅ DONE | Causal mask on SWA branch (SWA-relative position > m_coord → -inf, combined with D3 via OR logic) |
|
||||
| D5 | ✅ D5a+D5b+D5c DONE | D5a: normalize flag + LSE + row_sums output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias as logit modification — mathematically equivalent to separate merge, single pass over combined KV (cos 0.999996 single-tile AND multi-tile). D5d (fused in-kernel merge) NOT NEEDED — sink bias approach supersedes it. |
|
||||
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
|
||||
| NVFP4-3 | ✅ DONE | `use_2cta_instrs` conditional in gemm_runner.py. 1.7-1.9× throughput at prefill shapes. |
|
||||
|
||||
---
|
||||
|
||||
## Package Structure
|
||||
|
||||
```
|
||||
dsv4/
|
||||
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
|
||||
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
|
||||
│ ├── attention/ FMHA kernel — FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
|
||||
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL, 419 lines)
|
||||
│ ├── indexer/ CSA indexer — score+topk (FP32 dot products, top-k selection)
|
||||
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
|
||||
│ ├── cache/ Cache kernels — append_swa (write KV to split state cache layout)
|
||||
│ ├── decode/ Decode-time attention (sparse, SWA — future)
|
||||
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
|
||||
├── ops/ PyTorch ↔ kernel bridges
|
||||
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
|
||||
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
|
||||
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
|
||||
│ ├── custom_ops.py torch.library.custom_op registrations
|
||||
│ ├── decode_sparse.py native_sparse_decode dispatcher
|
||||
│ ├── decode_swa.py native_swa_decode dispatcher
|
||||
│ ├── rope.py Forward + inverse RoPE
|
||||
│ ├── topk.py Python wrapper for sparse_topk_metadata.cu
|
||||
│ ├── topk_select.py Top-k selection wrapper
|
||||
│ └── router.py Router op bridge
|
||||
├── layers/ nn.Module-style components
|
||||
│ ├── linear.py Nvfp4Linear
|
||||
│ ├── grouped_linear.py Nvfp4GroupedLinear
|
||||
│ ├── moe.py Nvfp4MoE
|
||||
│ ├── shared_expert.py Nvfp4SharedExpert
|
||||
│ ├── mhc.py mHCLayer
|
||||
│ ├── attention.py DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
|
||||
│ ├── norm.py RMSNorm (PyTorch ref, fused kernel later)
|
||||
│ ├── router.py Router — token-to-expert assignment (273 lines)
|
||||
│ ├── embedding.py Token embedding + mHC init (stub)
|
||||
│ └── ffn.py FFN sub-block
|
||||
├── model/ Model assembly
|
||||
│ ├── config.py Model config
|
||||
│ ├── layer.py Transformer layer
|
||||
│ ├── layer_schedule.py Layer scheduling
|
||||
│ ├── mtp.py Multi-token prediction
|
||||
│ ├── sampler.py Token sampler
|
||||
│ └── dsv4.py Full model (stub — Phase 1)
|
||||
├── cache/ KV cache infra
|
||||
│ ├── allocator.py Cache memory allocator
|
||||
│ ├── block_table.py Paged cache block table
|
||||
│ ├── flush.py Cache flush
|
||||
│ ├── handle.py Cache handle
|
||||
│ ├── manager.py Cache manager
|
||||
│ ├── paged_cache.py Paged KV cache
|
||||
│ ├── prepare_forward.py Forward prep
|
||||
│ ├── schema.py Cache schema
|
||||
│ └── state_cache.py State cache (SWA ring buffer)
|
||||
├── loader/ Checkpoint I/O
|
||||
│ ├── hf_checkpoint.py HuggingFace checkpoint loader
|
||||
│ └── layout_convert.py Weight layout conversion
|
||||
└── reference/ Slow PyTorch oracles (never imported by production code)
|
||||
├── attention.py RoPE, KV cache, causal attention, SWA
|
||||
├── csa_attention.py CSA/HCA sparse attention
|
||||
├── compressor.py Compressor PyTorch example
|
||||
└── moe_pipeline.py MoE pipeline reference
|
||||
```
|
||||
|
||||
**Mental model:** `kernels/` → `ops/` → `layers/` → `model/` (dependency flows left to right). `reference/` and `loader/` are sidecars.
|
||||
|
||||
---
|
||||
|
||||
## Active Test Files
|
||||
|
||||
### FMHA (Stages A/B/C/D1) — in `tests/unit/`
|
||||
|
||||
| File | Stage | Status |
|
||||
|------|-------|--------|
|
||||
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
|
||||
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
|
||||
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
|
||||
| `test_fmha_v3_stage_d1.py` | D1 | ✅ hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. |
|
||||
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.999994, LSE err=0.0) |
|
||||
| `test_d5c_fused.py` | D5c | ✅ Single-tile combined KV + sink bias (cos 0.999996) |
|
||||
| `test_d5c_multitile.py` | D5c | ✅ Multi-tile with Python KV merge + sink bias (cos 0.999996) |
|
||||
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
|
||||
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
|
||||
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
|
||||
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
|
||||
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
|
||||
| `test_qk_softmax.py` | A+B | ✅ QK + identity softmax, no PV |
|
||||
|
||||
### MoE / GEMM — in `tests/unit/`
|
||||
|
||||
| File | What |
|
||||
|------|------|
|
||||
| `test_cutedsl.py` | NVFP4 grouped GEMM kernel |
|
||||
| `cudagraph_test.py` | Cudagraph capture + replay |
|
||||
| `layertest.py` | Per-layer correctness |
|
||||
| `test_custom_op.py` | torch.library custom ops |
|
||||
| `test_compile_custom_op.py` | Compile + warmup |
|
||||
| `test_fp4_roundtrip.py` | BF16 → NVFP4 → BF16 roundtrip |
|
||||
| `test_interleave.py` | Gate/up weight interleaving |
|
||||
| `test_interleave_gemm.py` | Interleaved GEMM correctness |
|
||||
| `test_fused_step1.py` | Fused SwiGLU GEMM |
|
||||
|
||||
---
|
||||
|
||||
## Test Harness
|
||||
|
||||
Scripts in `tests/` for running tests on the B200 (`root@45.76.247.107`):
|
||||
|
||||
### `run_test.sh` — Run a test in a screen session
|
||||
|
||||
```bash
|
||||
# On the B200:
|
||||
cd /root/dsv4-nvfp4-workspace/kernel
|
||||
bash tests/run_test.sh tests/unit/test_fmha_v3.py
|
||||
```
|
||||
|
||||
What it does:
|
||||
1. Kills any existing `kernel-test` screen and **SIGKILLs all child processes** (handles deadlocked GPU procs that ignore SIGHUP)
|
||||
2. Deletes the old log file
|
||||
3. Starts a new `screen -dmS kernel-test` running the test
|
||||
4. Logs output to `/tmp/kernel-test.log`
|
||||
5. Verifies the screen started
|
||||
|
||||
### `check_log.sh` — Check test progress
|
||||
|
||||
```bash
|
||||
bash tests/check_log.sh
|
||||
```
|
||||
|
||||
Shows the log contents and whether the screen is still running.
|
||||
|
||||
### Local → B200 workflow
|
||||
|
||||
```bash
|
||||
# 1. Edit locally, commit, push
|
||||
cd ~/dev/nvfp4-megamoe-kernel
|
||||
git add -A && git commit -m "my change" && git push
|
||||
|
||||
# 2. SSH to B200, pull, run
|
||||
ssh root@45.76.247.107
|
||||
cd /root/dsv4-nvfp4-workspace/kernel && git pull
|
||||
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
|
||||
|
||||
# 3. Check results
|
||||
bash tests/check_log.sh
|
||||
```
|
||||
|
||||
### `fire_b200_test` — One-command local test runner
|
||||
|
||||
Lives in `~/.openclaw/workspace/fire_b200_test` (NOT in the repo — project-specific tooling).
|
||||
|
||||
```bash
|
||||
# From your local machine, one command to push, run, and get results:
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
|
||||
```
|
||||
|
||||
What it does:
|
||||
1. Auto-commits and pushes any local changes
|
||||
2. SSH to B200, pulls, starts `run_test.sh` in a screen
|
||||
3. Polls every 15s until the screen exits
|
||||
4. Dumps the full test log to your terminal
|
||||
|
||||
**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.
|
||||
|
||||
---
|
||||
|
||||
|
||||
## Stage C: Online Softmax — TMEM Layout Mismatch Issue
|
||||
|
||||
### Current Results (test_fmha_v3_stage_c.py)
|
||||
|
||||
| n | cos | Status |
|
||||
|---|-----|--------|
|
||||
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
|
||||
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
|
||||
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |
|
||||
|
||||
### Root Cause: TMEM Layout Mismatch
|
||||
|
||||
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip.
|
||||
|
||||
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.
|
||||
|
||||
**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific if blocks.
|
||||
|
||||
### Key Bug Fix: tOrP0 TMEM Column Offset (May 23)
|
||||
|
||||
The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
|
||||
```python
|
||||
if const_expr(self.tOrP0_offset > 0):
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
|
||||
else:
|
||||
tOrP0 = tOrP
|
||||
```
|
||||
Must use `const_expr` conditional (not Python `if`) because CuTeDSL compiles both branches, and `tOrP.iterator + 0` fails with MLIR type error.
|
||||
|
||||
### Architecture (6-warp, current)
|
||||
|
||||
```
|
||||
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
|
||||
Warp 4: MMA (QK, PV)
|
||||
Warp 5: TMA (Q/K/V load)
|
||||
```
|
||||
|
||||
### TMEM Layout
|
||||
|
||||
```
|
||||
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
|
||||
Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
|
||||
Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
|
||||
```
|
||||
|
||||
### Remaining for Multi-Tile Production
|
||||
|
||||
1. **Fix TMEM layout mismatch** — replace hand-constructed atom round-trips with correction_epilog pattern
|
||||
2. **Pipeline state cycling for n≥384** — kv_stage=2 can only buffer 2 tiles
|
||||
3. **12-warp layout** — separate softmax/correction/epilogue warps
|
||||
4. **O rescale for kt > 0** — must also use paired atoms or correction_epilog
|
||||
|
||||
---
|
||||
|
||||
## CuTeDSL Constraints (hard-won)
|
||||
|
||||
1. **`vectorize=True` loops: ONLY load/store/print** — no fmax, no cmpf, no inner loops, no carry
|
||||
2. **`.reduce(cute.ReductionOp.MAX)`:** reduces ENTIRE C-fragment to scalar — global max, not per-row
|
||||
3. **`cute.arch.fmax`:** impure for vectorizer — use plain `range()` loop
|
||||
4. **TMA partition tensors have 4 modes:** `(((64,128),1), ?, KV_tiles, ?)` — `(None,0,None,0)` keeps mode 2 (KV tiles) free, `[None, kt]` indexes it
|
||||
5. **`tBgK[(None, None, 0, 0)]` pins mode 2 to 0** — silently reads tile 0 forever. Use `(None,0,None,0)` instead.
|
||||
6. **`softmax_done_bar` NamedBarrier is reusable** across tiles
|
||||
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
|
||||
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
|
||||
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
|
||||
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
|
||||
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
|
||||
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000.
|
||||
13. **CuTeDSL MLIR backend cannot handle complex pipeline loops.** The MLIR→PTX optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python `range()` (unrolled) and `cutlass.range(unroll=1)` (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
|
||||
14. **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements. Use `const_expr(condition)` to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
|
||||
15. **External k_sub merge is mathematically impossible.** k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
|
||||
16. **`pv_n_tile` reduction is the easiest SMEM knob.** At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
|
||||
17. **O rescale TMEM round-trip with Ld32x32bOp/St32x32bOp is BROKEN.** Even a NO-OP round-trip (load O, multiply by 1.0, store back) corrupts data (cos 0.804 at s_k=256). The hand-constructed atoms don't preserve the C-fragment layout during round-trips. CUTLASS `correction_rescale` uses the same pattern — unclear why theirs works. **Workaround:** Python KV merge with per-segment LSE (cos 0.999998 for s_k up to 1024).
|
||||
18. **KV merge formula uses NORMALIZED outputs, not un-normalized.** The correct D5 merge for different token sets: `O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`. Using `O_i_unnorm` instead of `O_i_norm` gives cos ~0.91. The un-norm merge only works when both segments share the same `row_max` (global max), which isn't the case for separate KV segments.
|
||||
19. **`flat_divide` + `epilogue_tma_store` layout mismatch.** When using `cute.flat_divide` to create per-CTA GMEM views with runtime block coordinates (for multi-CTA grid), the resulting tensor layout is incompatible with CUTLASS's `epilogue_tma_store` pipeline, which expects the layout from `local_tile`. The tma_partition and epilogue must be refactored together to support multi-CTA grids.
|
||||
20. **`local_tile` does not support runtime coordinates.** `cute.local_tile(mQ, tiler, (runtime_val, None))` fails at trace time. Must use `cute.flat_divide(mQ, tiler)` instead, which creates a tiled view with all rest dimensions accessible via runtime indexing.
|
||||
21. **Sink bias domain correction.** Adding `attn_sink` directly to raw logits is wrong — it gets scaled by `scale_log2`. Fix: add `attn_sink / scale` to raw logits, so after `* scale_log2` it becomes `attn_sink * log2(e)`, correctly multiplying attention weights by `exp(attn_sink)`.
|
||||
22. **O normalization uses row_sum, NOT LSE.** `O_norm = O_unnorm / row_sum` is correct. `O_unnorm * exp(-LSE)` is WRONG because O_unnorm is max-shifted (divided by `2^row_max`), not raw `exp(S) @ V`. The kernel now outputs `row_sum` alongside LSE.
|
||||
23. **n_comp is compile-time, swa_len is runtime.** The `n_comp` parameter controls `const_expr` guards in the kernel and cannot vary between segments of the same kernel instance. `swa_len` is an `Int32` scalar and can vary per request. For multi-tile production, use a kernel cache keyed on `(n_comp, apply_sink_bias, head_dim, s_k)`.
|
||||
|
||||
---
|
||||
|
||||
## Key Lessons
|
||||
|
||||
1. **NEVER use `find_tmem_tensor_col_offset()` as TMEM placement.** It returns footprint size, not a safe offset.
|
||||
2. **FMHA never trusts DLPack tensor layouts.** Reconstruct V as (hd, s_k) MN-major inside CuTe.
|
||||
3. **TMEM allocation must be power of 2.**
|
||||
4. **Square hides bugs.** (128,128) worked for every wrong approach. Always test non-square.
|
||||
5. **St32x32bOp MUST use Float32**, NOT BFloat16. BFloat16 causes illegal memory access.
|
||||
6. **First PV ACCUMULATE=False.** Otherwise adds uninitialized TMEM to output.
|
||||
7. **FMHA P store uses QK C-fragment composition, NOT PV A-fragment.** Two aliases, same TMEM.
|
||||
8. **Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout).** Do not skip this.
|
||||
9. **PRINT THE SHAPES. ALWAYS.** Reasoning about TMEM layouts without evidence is how we waste days.
|
||||
10. **Never assume TMEM round-trips are safe.** Verify with NO-OP tests before adding logic.
|
||||
|
||||
---
|
||||
|
||||
## Stage D: Full Decode Attention (revised May 23)
|
||||
|
||||
### Key Insight: The Indexer Solves Paging Upstream
|
||||
|
||||
The indexer now hands the kernel `selected_kv: [T, top_k, head_dim] BF16` — a **dense, materialized, dequantized** K/V tile. FMHA sees a dense `[T, top_k, 512]` tile, exactly like Stage A/B's existing `k` and `v` inputs. **The kernel doesn't need to know it's sparse.** Paged TMA, scattered HBM reads, FP8 dequantization — all handled by `gather_selected_kv` upstream.
|
||||
|
||||
The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (`n_win=128` per query), so it's a separate fused branch with a sink-weighted merge.
|
||||
|
||||
**One FMHA kernel serves all three DSV4 attention types:**
|
||||
- **CSA:** `compressed_kv` = top-k from indexer, `swa_kv` from cache → sink merge
|
||||
- **HCA:** `compressed_kv` = all classical pool entries (gather-all mode), `swa_kv` from cache → sink merge
|
||||
- **SWA-only (Flash layers 0-1):** `compressed_kv` = empty (`top_k=0`), only SWA runs. Sink merge degenerates to just `o_swa` after renormalization.
|
||||
|
||||
### Build Order
|
||||
|
||||
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, MOSTLY DONE)
|
||||
|
||||
Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
|
||||
|
||||
hd≤256: ✅ DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.
|
||||
|
||||
hd=512: ❌ BLOCKED. SMEM budget fixed (192KB, fits 232KB limit). Kernel structurally correct (tracer 0.8s). But CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments additive in logit space, not weight space). Need either: (a) pre-compile offline + cache cubin, (b) add no-softmax mode for S accumulation in Python, or (c) write hd=512 path in raw CUDA C++.
|
||||
|
||||
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
|
||||
|
||||
**D2 — Multi-query grid with head packing** (~1 day)
|
||||
|
||||
Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`. DSV4 is MQA — all `n_h=128` query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: `M_tile = 128` covers `M = T * n_h` rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.
|
||||
|
||||
Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.
|
||||
|
||||
**D3 — SWA sequence length mask** (~½ day)
|
||||
|
||||
The indexer's `top_k` is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always `[T, top_k, head_dim]` with the same `top_k` at compile time.
|
||||
|
||||
What varies: the SWA window holds up to `n_win=128` tokens but starts with fewer. Add `swa_lens: [batch] int32` as kernel input. Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`.
|
||||
|
||||
Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.
|
||||
|
||||
**D4 — Causal mask on SWA branch** (~½ day)
|
||||
|
||||
The compressed K the indexer selects is already from `s < floor(t/m)` (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. **The main path has no mask.**
|
||||
|
||||
The SWA branch needs a causal mask within the window. Add `is_causal: bool` constructor flag, apply `swa_idx > q_pos` masking to `-inf` in the SWA pass.
|
||||
|
||||
Done when: prefill mode produces correct output with the causal mask applied to SWA.
|
||||
|
||||
**D5 — SWA + sink merge** ✅ COMPLETE (May 26)
|
||||
|
||||
Per `dsv4/ops/decode_sparse.py`:
|
||||
```
|
||||
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
|
||||
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
||||
```
|
||||
|
||||
**Key insight (May 26):** This merge is mathematically identical to a single attention pass over concatenated KV with a logit bias on SWA positions:
|
||||
```
|
||||
S = [S_comp, S_swa + attn_sink]
|
||||
O = softmax(S) @ [V_comp; V_swa]
|
||||
```
|
||||
This means D5c is a **logit bias addition**, not a two-pass + merge kernel. D5d (fused in-kernel merge epilogue) is NO LONGER NEEDED.
|
||||
|
||||
**D5a ✅:** `normalize` flag + LSE + row_sums output. When False, emits un-normalized O + LSE + row_sum.
|
||||
|
||||
**D5b ✅:** Per-row LSE output (all 128 rows now write). Python KV merge with per-row LSE: cos 0.999994.
|
||||
|
||||
**D5c ✅:** Sink bias as logit modification. Parameters: `n_comp` (compressed KV length, compile-time), `apply_sink_bias` (compile-time flag), `sink_bias` (runtime FP32 tensor). Sink bias added to raw logits as `attn_sink / scale` so after `* scale_log2` it correctly becomes `attn_sink * log2(e)` in the exp2 domain. Multi-tile via Python KV merge: cos 0.999996.
|
||||
|
||||
**D5d:** NOT NEEDED. The sink bias approach makes a fused merge epilogue unnecessary.
|
||||
|
||||
Done when: ✅ End-to-end kernel produces correct attention against FP32 oracle.
|
||||
|
||||
**~~D5 (old) paged TMA~~ — REMOVED.** The indexer + gather handles all paging upstream.
|
||||
|
||||
### Kernel Architecture (after D5 — COMPLETE)
|
||||
|
||||
```
|
||||
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
|
||||
swa_lens [batch], sink_logits [n_h], request_ids [T]
|
||||
│
|
||||
└─ Single pass over concatenated KV [compressed_kv; swa_kv]:
|
||||
QK → online softmax (with sink bias on SWA, D3/D4 masking) → PV
|
||||
→ O_unnorm + LSE + row_sum
|
||||
→ External normalize: O = O_unnorm / row_sum
|
||||
→ (Multi-tile: Python KV merge across 128-token segments)
|
||||
```
|
||||
|
||||
### Reference Files
|
||||
|
||||
- Sink merge spec: `dsv4/ops/decode_sparse.py` (formula)
|
||||
- SWA decode: `dsv4/ops/decode_swa.py`
|
||||
- Attention reference: `dsv4/reference/attention.py`
|
||||
- CSA attention: `dsv4/reference/csa_attention.py`
|
||||
|
||||
### Stage C Note
|
||||
|
||||
When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.
|
||||
|
||||
---
|
||||
|
||||
## Stage E: Production Extraction (revised May 23)
|
||||
|
||||
### E1 — File placement
|
||||
|
||||
`dsv4/kernels/attention/fmha.py`. Currently contains `FmhaKernel` (migrated from test, hd=64 TMEM-P). Will gain parameterized `head_dim` and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.
|
||||
|
||||
### E2 — Constructor signature
|
||||
|
||||
```python
|
||||
class FmhaKernel:
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int, # 512 for DSV4
|
||||
num_query_heads: int, # 128 for Pro, 64 for Flash
|
||||
sliding_window: int, # 128
|
||||
top_k: int, # 512 (Flash) or 1024 (Pro)
|
||||
q_dtype=BFloat16,
|
||||
kv_dtype=BFloat16,
|
||||
o_dtype=BFloat16,
|
||||
qk_acc_dtype=Float32,
|
||||
pv_acc_dtype=Float32,
|
||||
is_causal: bool = False, # affects SWA mask only
|
||||
cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
|
||||
cluster_shape_mn: tuple = (1, 1),
|
||||
):
|
||||
```
|
||||
|
||||
All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.
|
||||
|
||||
### E3 — Call signature
|
||||
|
||||
```python
|
||||
def __call__(
|
||||
self,
|
||||
q: torch.Tensor, # [T, n_h, head_dim] BF16
|
||||
compressed_kv: torch.Tensor, # [T, top_k, head_dim] BF16 — from indexer gather
|
||||
swa_kv: torch.Tensor, # [batch, n_win, head_dim] BF16 — from cache prep
|
||||
swa_lens: torch.Tensor, # [batch] int32
|
||||
sink_logits: torch.Tensor, # [n_h] FP32
|
||||
request_ids: torch.Tensor, # [T] int32 — maps query to its SWA slot
|
||||
o: torch.Tensor, # [T, n_h, head_dim] BF16 — preallocated
|
||||
stream: cuda.CUstream,
|
||||
):
|
||||
```
|
||||
|
||||
Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.
|
||||
|
||||
### E4 — Kernel cache + warmup
|
||||
|
||||
Mirror `dsv4/ops/gemm_runner.py`'s `_compiled_kernel_cache`. Key on `(head_dim, num_query_heads, top_k, is_causal, ...)`. Pre-allocate at warmup, reuse at call. For DSV4, the cache has at most ~2 entries (Flash/Pro × causal/non).
|
||||
|
||||
### E5 — torch.library custom op
|
||||
|
||||
```python
|
||||
@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
|
||||
def sparse_fmha_with_swa_op(
|
||||
q: torch.Tensor,
|
||||
compressed_kv: torch.Tensor,
|
||||
swa_kv: torch.Tensor,
|
||||
swa_lens: torch.Tensor,
|
||||
sink_logits: torch.Tensor,
|
||||
request_ids: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
runner_id: int,
|
||||
) -> None:
|
||||
runner = get_runner(runner_id)
|
||||
runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)
|
||||
```
|
||||
|
||||
Mutates `o` (preallocated buffer). Consistent with cudagraphs.
|
||||
|
||||
### E6 — Reference parity hook
|
||||
|
||||
`dsv4/reference/attention.py` stays as the FP32 oracle. New test: `tests/unit/test_fmha_kernel.py`.
|
||||
|
||||
```python
|
||||
def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
|
||||
q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
|
||||
ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
swa = torch.randn(4, n_win, hd, dtype=torch.bf16, device='cuda')
|
||||
swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32)
|
||||
sink = torch.randn(n_h, device='cuda')
|
||||
req_ids = torch.randint(0, 4, (T,), dtype=torch.int32)
|
||||
|
||||
# Oracle: pure FP32 spec
|
||||
o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
|
||||
o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
|
||||
e_sink = sink.exp()
|
||||
num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
|
||||
+ e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
|
||||
den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
|
||||
expected = num / den.unsqueeze(-1)
|
||||
|
||||
# Kernel
|
||||
o = torch.empty_like(expected, dtype=torch.bfloat16)
|
||||
fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
|
||||
fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)
|
||||
|
||||
torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)
|
||||
```
|
||||
|
||||
### E7 — Cleanup
|
||||
|
||||
Delete all debug test files. `test_fmha_v3.py` becomes `dsv4/kernels/attention/fmha.py`. Only `tests/unit/test_fmha_kernel.py` remains as the attention test.
|
||||
|
||||
---
|
||||
|
||||
## Environment
|
||||
|
||||
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
|
||||
- venv: `source /root/dsv4-nvfp4-workspace/venv/bin/activate`
|
||||
- PYTHONPATH: `/root/dsv4-nvfp4-workspace/kernel`
|
||||
- Model: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
|
||||
- vLLM repo: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
|
||||
- CUTLASS FMHA reference: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
|
||||
- Local CUTLASS clone: `/home/openclaw/dev/cutlass`
|
||||
@@ -14,7 +14,6 @@ from cutlass.utils.blackwell_helpers import get_smem_store_op
|
||||
from cutlass.utils.gemm.sm100 import (
|
||||
transform_partitioned_tensor_layout,
|
||||
epilogue_tmem_copy_and_partition,
|
||||
epilogue_smem_copy_and_partition,
|
||||
)
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
@@ -399,50 +398,24 @@ class FmhaKernel:
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# ============================================================
|
||||
# CORRECTION EPILOGUE SETUP (paired atoms, one-way TMEM→REGS→SMEM→GMEM)
|
||||
# Pattern proven in dsv4/kernels/gemm/fused_swiglu.py (lines 2021-2076).
|
||||
# Replaces broken hand-constructed TMEM round-trip (D1.5 fix).
|
||||
# O RESCALE PAIRED ATOMS (D1.5 fix, multi-KV-tile only)
|
||||
# ============================================================
|
||||
# Build the O accumulator tensor at the TMEM pointer + o0_offset.
|
||||
# This is the TMEM source for the correction epilogue.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
|
||||
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
|
||||
|
||||
# Paired atoms: tiled_copy_t2r (TMEM→REGS) and tiled_copy_r2s (REGS→SMEM)
|
||||
# share addressing so the round trip is lossless.
|
||||
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_transformed, tCgC_transformed,
|
||||
epi_tile, self.use_2cta_instrs,
|
||||
)
|
||||
# Register tile for BF16 output (after normalization/conversion)
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rO.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC,
|
||||
)
|
||||
|
||||
# TMA partition for SMEM → GMEM store (same as MoE pattern)
|
||||
tCgC_epi = cute.flat_divide(tCgC_transformed, epi_tile)
|
||||
bSG_sC, bSG_gC_partitioned = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
# Single-CTA grid: all block coordinates are 0
|
||||
bSG_gC = bSG_gC_partitioned[(None, None, None, 0, 0, 0)]
|
||||
|
||||
# Epilogue sync barrier + C-store pipeline
|
||||
epilog_sync_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
|
||||
# Group modes for the subtile iteration (same pattern as fused_swiglu)
|
||||
tTR_tO_grouped = cute.group_modes(tTR_tO_base, 3, cute.rank(tTR_tO_base))
|
||||
bSG_gC_grouped = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
||||
subtile_cnt = cute.size(tTR_tO_grouped.shape, mode=[3])
|
||||
# Replace broken hand-constructed Ld32x32bOp/St32x32bOp round-trip
|
||||
# with paired atoms from epilogue_tmem_copy_and_partition.
|
||||
# The paired atoms share addressing, so the TMEM→REGS→modify→TMEM
|
||||
# cycle is lossless (unlike independently constructed atoms).
|
||||
# Only needed when n_kv_tiles > 1 (multi-KV-tile O rescale).
|
||||
# ============================================================
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
tCtO_transformed = transform_partitioned_tensor_layout(tCtO_base)
|
||||
tCgC_transformed = transform_partitioned_tensor_layout(tCgC)
|
||||
tiled_copy_t2r, tTR_tO_base, tTR_rO = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_transformed, tCgC_transformed,
|
||||
epi_tile, self.use_2cta_instrs,
|
||||
)
|
||||
tTR_tO_grouped = cute.group_modes(tTR_tO_base, 3, cute.rank(tTR_tO_base))
|
||||
subtile_cnt = cute.size(tTR_tO_grouped.shape, mode=[3])
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
@@ -543,21 +516,20 @@ class FmhaKernel:
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# O rescale for kt > 0 using paired atoms (D1.5 fix).
|
||||
# One-way TMEM→REGS (multiply by acc_scale) → TMEM via paired store atom.
|
||||
# The paired atom's addressing is consistent for both load and store,
|
||||
# TMEM→REGS (paired load), multiply by acc_scale,
|
||||
# REGS→TMEM (paired store via retile_to_S).
|
||||
# The paired atom's addressing is consistent for load and store,
|
||||
# so this does NOT suffer from the layout mismatch that broke the
|
||||
# hand-constructed Ld32x32bOp/St32x32bOp round-trip.
|
||||
if const_expr(self.n_kv_tiles > 1):
|
||||
if kt > 0:
|
||||
for subtile_idx in cutlass.range(subtile_cnt, unroll=1):
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
tTR_tO_mn = tTR_tO_grouped[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
|
||||
# Modify in registers — acc_scale is per-row, same for all elements
|
||||
# in this thread's fragment. (Each thread handles one row.)
|
||||
# Modify in registers
|
||||
for k in cutlass.range(cute.size(tTR_rO), vectorize=True):
|
||||
tTR_rO[k] = tTR_rO[k] * acc_scale
|
||||
# Store back to TMEM via the paired atom's store direction.
|
||||
# Use retile_to_S() to get the store-compatible layout.
|
||||
# Store back to TMEM via paired atom's store direction
|
||||
cute.copy(tiled_copy_t2r.retile_to_S(), tTR_rO, tTR_tO_mn)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
@@ -568,52 +540,32 @@ class FmhaKernel:
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# ============================================================
|
||||
# CORRECTION EPILOGUE: one-way TMEM → REGS → SMEM → GMEM
|
||||
# EPILOGUE: TMA store O to GMEM + compute LSE
|
||||
# ============================================================
|
||||
# Uses paired atoms from epilogue_tmem_copy_and_partition /
|
||||
# epilogue_smem_copy_and_partition (same pattern as fused_swiglu.py).
|
||||
# This is the D1.5 fix: no TMEM round-trip corruption because we
|
||||
# use library-paired atoms for the one-way trip through registers.
|
||||
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
|
||||
# We use epilogue_tma_store which reads O from TMEM directly via
|
||||
# the correct get_tmem_load_op layout — no round-trip needed.
|
||||
#
|
||||
# For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures
|
||||
# O is correctly rescaled before this epilogue reads it.
|
||||
#
|
||||
# External normalization (D5a path): kernel outputs un-normalized O +
|
||||
# LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum.
|
||||
# This is exact and composes with D5c sink bias merge.
|
||||
# ============================================================
|
||||
|
||||
# Compute inv_row_sum for normalization (in registers, no TMEM round-trip)
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
if const_expr(self.normalize):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# Iterate over output subtiles: TMEM → REGS → (normalize/convert) → SMEM → GMEM
|
||||
for subtile_idx in cutlass.range(subtile_cnt, unroll=1):
|
||||
# TMEM → REGS using the paired atom (lossless)
|
||||
tTR_tO_mn = tTR_tO_grouped[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tO_mn, tTR_rO)
|
||||
|
||||
# Register-level modification:
|
||||
# - normalize=True: divide by row_sum, cast to BF16
|
||||
# - normalize=False: just cast to BF16 (un-normalized O for D5a)
|
||||
acc_vec = tiled_copy_r2s.retile(tTR_rO).load()
|
||||
if const_expr(self.normalize):
|
||||
acc_vec = acc_vec * inv_row_sum
|
||||
tRS_rC.store(acc_vec.to(self.c_dtype))
|
||||
|
||||
# REGS → SMEM
|
||||
c_buffer = subtile_idx % self.num_c_stage
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
# SMEM → GMEM (one warp does the TMA store)
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(
|
||||
tma_c,
|
||||
bSG_sC[(None, c_buffer)],
|
||||
bSG_gC_grouped[(None, subtile_idx)],
|
||||
)
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_acquire()
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
|
||||
@@ -624,6 +576,9 @@ class FmhaKernel:
|
||||
# sfw_idx maps directly to the row index in the attention matrix.
|
||||
# All 128 threads write independently to mLSE[sfw_idx] — no sync needed.
|
||||
if const_expr(not self.normalize):
|
||||
_row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
_row_max_safe = Float32(0.0)
|
||||
_ln2 = Float32(0.6931471805599453) # ln(2)
|
||||
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
|
||||
mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val
|
||||
|
||||
Reference in New Issue
Block a user