Update README: reflect Stage C migration, built indexer/router/compressor, SMEM-P path, CuTeDSL scoping lesson

This commit is contained in:
2026-05-23 05:42:44 +00:00
parent b39301ebc6
commit 787a25516d

View File

@@ -98,8 +98,8 @@ Legend:
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
│ │ [~] Token compressor CSA m=4 / HCA m=128 │ │
│ │ [] Indexer + top-k CSA only, FP4 QK │ │
│ │ [] Token compressor CSA m=4 / HCA m=128 │ │
│ │ [] Indexer + top-k CSA, FP32 dot + top-k │ │
│ │ [~] FMHA core QK → online softmax → PV │ │
│ │ + SWA branch + sink merge │ │
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
@@ -107,8 +107,8 @@ Legend:
│ │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Pre-FFN norm RMSNorm │ │
│ │ [] Router sqrt(softplus) + topk + hash │ │
│ │ [] Pre-FFN norm RMSNorm │ │
│ │ [] Router sqrt(softplus) + topk + hash │ │
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
│ └────────────────────────────────────────────────────────────────┘ │
@@ -129,24 +129,23 @@ Legend:
Summary
-------
Built [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE,
shared expert
Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
pre-FFN norm
To build [✗] : 8 — embedding+init, indexer+top-k, router,
final norm, LM head, MTP, sampler, KV cache
Built [✓] : 9 — mHC ×2, Q/KV proj, output proj, routed MoE,
shared expert, token compressor, indexer+topk,
router, pre-FFN norm
Partial [~] : 3 — norms+RoPE, FMHA core
To build [✗] : 6 — embedding+init, final norm, LM head, MTP, sampler, KV cache
```
---
## Status (May 23, 2026 — 02:55 UTC)
## Status (May 23, 2026 — 05:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ SINGLE-TILE OK, MULTI-TILE 3% ERROR | n=128 cos 0.973. n=256 cos 0.793. TMEM layout mismatch. See below. |
| D1 | TODO | Parameterize HEAD_DIM (64 → 512) |
| C | ✅ MIGRATED TO MODULE | Real online softmax + normalize. n=128 cos 0.973. Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaV3StageC`. TMEM layout mismatch still present (3% error). |
| D1 | 🔨 IN PROGRESS | Parameterize HEAD_DIM (64 → 512). SMEM-P path for hd>64 (register→SMEM copy TODO). |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
@@ -161,8 +160,11 @@ Summary
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel (stub — extraction is Stage E)
│ ├── compressor/ CSA/HCA token-level compressor
│ ├── attention/ FMHA kernel — FmhaV3StageC (migrated from tests), SMEM-P stub
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL, 419 lines)
│ ├── indexer/ CSA indexer — score+topk (FP32 dot products, top-k selection)
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
│ ├── cache/ Cache kernels — append_swa (write KV to split state cache layout)
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
@@ -173,17 +175,40 @@ dsv4/
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
── topk.py Python wrapper for sparse_topk_metadata.cu
── topk.py Python wrapper for sparse_topk_metadata.cu
│ ├── topk_select.py Top-k selection wrapper
│ └── router.py Router op bridge
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
── (stubs: attention, ffn, router, norm, embedding)
├── model/ Model assembly (stubs — Phase 1)
├── cache/ KV cache infra (stubs — Phase 3)
├── loader/ Checkpoint I/O (stubs — Phase 1)
── attention.py DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
│ ├── norm.py RMSNorm (PyTorch ref, fused kernel later)
│ ├── router.py Router — token-to-expert assignment (273 lines)
│ ├── embedding.py Token embedding + mHC init (stub)
│ └── ffn.py FFN sub-block
├── model/ Model assembly
│ ├── config.py Model config
│ ├── layer.py Transformer layer
│ ├── layer_schedule.py Layer scheduling
│ ├── mtp.py Multi-token prediction
│ ├── sampler.py Token sampler
│ └── dsv4.py Full model (stub — Phase 1)
├── cache/ KV cache infra
│ ├── allocator.py Cache memory allocator
│ ├── block_table.py Paged cache block table
│ ├── flush.py Cache flush
│ ├── handle.py Cache handle
│ ├── manager.py Cache manager
│ ├── paged_cache.py Paged KV cache
│ ├── prepare_forward.py Forward prep
│ ├── schema.py Cache schema
│ └── state_cache.py State cache (SWA ring buffer)
├── loader/ Checkpoint I/O
│ ├── hf_checkpoint.py HuggingFace checkpoint loader
│ └── layout_convert.py Weight layout conversion
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
@@ -197,13 +222,16 @@ dsv4/
## Active Test Files
### FMHA (Stages A/B/C) — in `tests/unit/`
### FMHA (Stages A/B/C/D1) — in `tests/unit/`
| File | Stage | Status |
|------|-------|--------|
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c.py` | C | ⚠️ Real online softmax + normalize, n=128 cos 0.973 (TMEM layout mismatch), n=256 cos 0.793 |
| `test_fmha_v3_stage_c.py` | C | Real online softmax + normalize, n=128 cos 0.973. **Also migrated to `dsv4/kernels/attention/fmha.py` as `FmhaV3StageC`.** |
| `test_fmha_v3_stage_d1.py` | D1 | 🔨 Parameterized hd + SMEM-P path (WIP) |
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
| `test_128_128_vdiag.py` | A+B | ✅ (128,128) PV baseline |
| `test_qkonly.py` | A | ✅ QK with split Q/KV pipelines |
@@ -375,6 +403,7 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
7. **Hand-constructed TMEM atoms corrupt data on round-trip:** `Ld32x32bOp` + `St32x32bOp` built independently introduce ~3% error. Use `get_tmem_load_op` + `get_smem_store_op` paired atoms for one-way trips.
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
---
@@ -408,10 +437,14 @@ The SWA branch is the only "irregular" thing: it reads from the state cache's ri
### Build Order
**D1 — Parameterize HEAD_DIM** (~½ day)
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, in progress)
Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
**Two P staging paths:**
- **TMEM-P** (hd≤64): P stored to TMEM via register bridge. PV reads from TMEM. Proven at cos 0.973.
- **SMEM-P** (hd>64): P stored to SMEM via PV A-operand layout. PV reads from SMEM. Avoids QK↔PV TMEM layout mismatch at large hd. **Register→SMEM copy needs `make_tiled_copy_C(store_atom, qk_mma)` to partition threads by QK C-fragment.** The SMEM rendezvous pattern: softmax writes P to SMEM at logical (row, col) addresses using `p_smem_s` layout, MMA warp reads from same SMEM. Barrier in between.
Risk at HEAD_DIM=512: TMEM column budget. `_setup` already does `find_tmem_tensor_col_offset(tOtO)` dynamically. Verify the total fits in 512 TMEM columns. If not, reduce `kv_stage` from 2 to 1 (lose K/V double-buffering) before sacrificing math.
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
@@ -499,7 +532,7 @@ When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" t
### E1 — File placement
`dsv4/kernels/attention/fmha.py`. Class: `FmhaKernel`. Constructor takes all dimensions and dtypes, no module-level constants. Drop `if __name__ == "__main__"` test block.
`dsv4/kernels/attention/fmha.py`. Currently contains `FmhaV3StageC` (exact migration from test). Will become `FmhaKernel` once D1 parameterization is complete and the SMEM-P path is working. Constructor takes all dimensions and dtypes, no module-level constants.
### E2 — Constructor signature
@@ -524,6 +557,8 @@ class FmhaKernel:
All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.
**Naming convention:** The class will be `FmhaKernel` once D1 is complete (replacing the current `FmhaV3StageC`). The progression: `FmhaV3StageC` (hd=64, TMEM-P only) → `FmhaKernel` (parameterized hd, TMEM-P + SMEM-P). The old name stays in the test file for regression.
### E3 — Call signature
```python