Update docs: D1.5 TMEM round-trip fundamentally broken, Python KV merge is production path
This commit is contained in:
@@ -165,21 +165,19 @@ pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit S
|
||||
|
||||
---
|
||||
|
||||
## D1.5: Correction Epilogue (TMEM Round-Trip Error) + O Rescale
|
||||
## D1.5: Multi-KV-Tile O Rescale — TMEM Round-Trip Fundamentally Broken (2026-05-26)
|
||||
|
||||
**Issue 1: TMEM round-trip error.** Hand-constructed `Ld32x32bOp`/`St32x32bOp` atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.
|
||||
**Root cause:** `Ld32x32bOp` and `St32x32bOp` have DIFFERENT column mappings at the hardware level. No layout transformation can fix this.
|
||||
|
||||
**Current workaround:** Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.
|
||||
**Investigation (2026-05-26):**
|
||||
- Attempted paired-atom approach using `epilogue_tmem_copy_and_partition` for TMEM load + `retile_to_S()` for TMEM store — `retile_to_S()` doesn't exist on TiledCopy
|
||||
- Even if API worked, the round-trip is fundamentally broken: load and store atoms have different column mappings
|
||||
- Full correction epilogue (one-way TMEM→REGS→SMEM→GMEM) causes 20+ minute MLIR compilation hang
|
||||
- The compilation hang is from `transform_partitioned_tensor_layout` + `flat_divide` + `tma_partition` inside the softmax warp block
|
||||
|
||||
**Proper fix (future):** Use CUTLASS `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.
|
||||
**Production path:** Python KV merge (cos 0.999998 for s_k up to 1024).
|
||||
|
||||
**Priority:** MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/d. Also blocks NVFP4-1.2 (inverse RoPE FP4 fuse).
|
||||
|
||||
**Issue 2: O rescale for kt>0 (multi-KV-tile).** CONFIRMED BROKEN (May 24). Even a NO-OP round-trip (load O, multiply by 1.0, store back) produces cos 0.804 at s_k=256. The `Ld32x32bOp`/`St32x32bOp` atoms corrupt data regardless of the rescale factor. The same atoms in CUTLASS `correction_rescale` use the same pattern — unclear why theirs works with 12-warp layout but ours fails with 6-warp.
|
||||
|
||||
**Workaround (VERIFIED):** Python KV merge with per-segment LSE. Run kernel with s_k=128 (1 KV tile, no rescale) per segment. Merge using: `O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`. Verified cos 0.999998 for s_k=256, 384, 512, 1024 at hd=64. **Caveat:** requires per-row LSE output (currently only row 0 is written; per-row verified correct with max err 0.000001 but CuTe tensor indexing needs work for full per-row output).
|
||||
|
||||
**Priority:** HIGH for production (DSV4 Pro needs s_k=1024). The Python merge works but adds kernel launch overhead (8 launches for s_k=1024). Fused in-kernel rescale requires fixing the TMEM round-trip or using a different accumulator strategy.
|
||||
**Future fix:** Restructure PV to accumulate into REGS/SMEM instead of TMEM. Enables one-way correction epilogue. Major refactor (1-2 days).
|
||||
|
||||
---
|
||||
|
||||
|
||||
Reference in New Issue
Block a user