# STATUS — DSV4 Inference Kernel (post-cleanup 2026-05-30)

## Production Path

**One FMHA kernel:** `fmha_6warp_tma_multirow_multitile.cuh`, a 6-warp kernel using TMA loads, UMMA (`tcgen05.mma`, SS variant: both operands from shared memory), an in-kernel multi-tile SMEM accumulator, and multi-row softmax. The kernel is exposed through a C API in `fmha_multitile_capi.cu`, loaded from Python with ctypes by `fmha_multitile_op.py`, and dispatched from `production.py`.

**Head dims:** 64, 128, 256, 512. **T=1 decode** verified (cosine similarity ≥ 0.999996). **T>1 prefill** runs through the multi-row path (P5, P7).

**No CuTeDSL runtime dependency.** All kernel code is raw CUDA C++. The CuTeDSL kernel (`fmha.py`), the Python KV merge, and `FmhaKernel` have all been deleted.
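The head-dim and T routing rules above can be summarized as a small check. This is a hypothetical sketch, not the real `production.py` dispatch: `SUPPORTED_HEAD_DIMS`, `select_fmha_mode`, and the mode labels are names invented for illustration.

```python
# Hypothetical sketch of the routing rules stated in this STATUS file.
# SUPPORTED_HEAD_DIMS, select_fmha_mode and the mode labels are
# illustrative names, not the real production.py API.
SUPPORTED_HEAD_DIMS = (64, 128, 256, 512)


def select_fmha_mode(head_dim: int, num_query_tokens: int) -> str:
    """Return "decode" for T == 1 and "multirow_prefill" for T > 1.

    Raises ValueError for unsupported head dims or T < 1.
    """
    if head_dim not in SUPPORTED_HEAD_DIMS:
        raise ValueError(
            f"head_dim={head_dim} not supported; expected one of {SUPPORTED_HEAD_DIMS}"
        )
    if num_query_tokens < 1:
        raise ValueError(f"T must be >= 1, got {num_query_tokens}")
    # T == 1: one query row per head (the verified decode case).
    if num_query_tokens == 1:
        return "decode"
    # T > 1: multiple query rows, handled by the multi-row path (prefill).
    return "multirow_prefill"


if __name__ == "__main__":
    print(select_fmha_mode(512, 1))   # decode
    print(select_fmha_mode(128, 64))  # multirow_prefill
```

Per this file, both cases run on the single production kernel; the label only records which row handling T selects.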
## Live Attention Files

| File | Role |
|---|---|
| `fmha_6warp_tma_multirow_multitile.cuh` | Production kernel |
| `fmha_common.cuh` | Shared types and definitions |
| `fmha_tma.cuh` | TMA descriptor helpers |
| `fmha_umma_desc.cuh` | UMMA descriptor creation |
| `fmha_multitile_capi.cu` | C API wrapper (compiled with nvcc) |
| `fmha_multitile_op.py` | ctypes loader |
| `production.py` | Public API (`dsv4_attention`) |
| `__init__.py` | Bridge to layers (sparse/dense/SWA) |
## Stage E Checklist (from ROADMAP/NEXT_PRIORITIES_PART_2)

- [x] **E1:** Wire `LayerCacheHandle` → gather methods ✅
- [x] **E2:** E2E smoke tests (SWA + CSA + HCA) ✅
- [x] **E3:** `DSV4Model` class ✅
- [x] **E4:** Removed `torch.cuda.synchronize` ✅
- [x] **E5:** Batch loop folded into the kernel grid ✅
- [x] **Single-shot inference:** Full 61-layer pipeline runs on B200 ✅
  - FMHA kernel verified: head dim 512, 128 query heads, all layers correct
  - Output is still expected to be garbage without mHC, MoE, and KV cache (architecture gaps, not kernel issues)
- [ ] **E6:** FP4 output fusion for FMHA → wo_a
- [ ] **E7:** Lightning indexer FP4 tensor-core scoring
- [ ] **E8:** Multi-CTA grid for prefill
- [ ] **E9:** CUDA graph capture
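The "verified" criterion in this checklist, and the T=1 decode bar quoted under Production Path (cosine similarity ≥ 0.999996), can be illustrated with a small pure-Python gate. This is a sketch under assumptions: `decode_attention_ref`, `cosine_similarity`, and `passes_cos_gate` are hypothetical helpers, the reference is plain softmax(q·Kᵀ/√d)·V for one head with T=1, and the perturbed vector only stands in for kernel output. Real checks would compare the kernel's GPU tensors against a reference implementation.

```python
import math
import random
from typing import List, Sequence

COS_THRESHOLD = 0.999996  # bar quoted for T=1 decode in this STATUS file


def decode_attention_ref(q: Sequence[float], keys: List[Sequence[float]],
                         values: List[Sequence[float]]) -> List[float]:
    """Hypothetical reference: one head, one query (T=1), softmax(q.K^T / sqrt(d)) V."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = math.fsum(x * y for x, y in zip(a, b))
    na = math.sqrt(math.fsum(x * x for x in a))
    nb = math.sqrt(math.fsum(y * y for y in b))
    return dot / (na * nb)


def passes_cos_gate(out: Sequence[float], ref: Sequence[float],
                    threshold: float = COS_THRESHOLD) -> bool:
    return cosine_similarity(out, ref) >= threshold


if __name__ == "__main__":
    rng = random.Random(0)
    d, s = 64, 16  # head dim 64, 16 cached keys
    q = [rng.gauss(0, 1) for _ in range(d)]
    keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(s)]
    values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(s)]
    ref = decode_attention_ref(q, keys, values)
    # Stand-in for kernel output: up to 0.1% relative error per element.
    noisy = [x * (1 + rng.uniform(-1e-3, 1e-3)) for x in ref]
    print(passes_cos_gate(ref, ref), passes_cos_gate(noisy, ref))  # True True
```

Flattening the whole output into one vector gives a single number per comparison; a per-head check is stricter, because one bad head can hide inside a large flattened tensor.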
## Cleanup Done (C1–C7)

- **Deleted:** `fmha.py`, `fmha_sm100.cuh`, `fmha_sm100_tc.cuh`, `fmha_sm100_launch.cu`, `fmha_epilogue_sm100.cuh`, `decode_sparse.py`, `decode_swa.py`, `kernels/decode/`, 46 `test_d*.py` probes, and root scratch files. `fmha_qk_verify.cuh` moved to `tests/unit/`; `archive/` moved to `archived_plans/code_archive/`.
- **Removed:** `FmhaKernel` import, CuTeDSL slow path, Python KV merge, and the `torch.cuda.synchronize` call in `_run_fmha_segmented` (that function is now deleted).