From 44fc59b8fa88d7f935b4c2fdfb2decdd003828da Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sat, 23 May 2026 19:53:45 +0000
Subject: [PATCH] auto: pre-test commit

---
 STAGE_D1.3.md | 44 ++++++++++++++++++++++++++++++++------------
 1 file changed, 32 insertions(+), 12 deletions(-)

diff --git a/STAGE_D1.3.md b/STAGE_D1.3.md
index c719c5c5..21a1c03d 100644
--- a/STAGE_D1.3.md
+++ b/STAGE_D1.3.md
@@ -132,16 +132,36 @@ else:
 - ❌ Need to implement full 128-value mapping per thread
 - ❌ Need to get QK coordinates for each of thread's 128 P values
 
-### Next Steps:
-1. Create coordinate tensor `cP_qk = cute.make_identity_tensor(tStS0.shape)`
-2. Partition it same way as `rP_bf16` (through `tTMEM_LOADcP`)
-3. In softmax loop, for each fragment j and element k:
-   - Get P value from `rP_bf16_frg` or directly from `tTMEM_LOADrS_frg`
-   - Get coordinate from partitioned coordinate tensor
-   - Map to PV coordinate using `qk_to_pv_coord`
-   - Write to SMEM: `sP[dst_coord] = value`
+### Implementation Status (2026-05-23 19:55 UTC)
+✅ **Implemented full SMEM-P with coordinate mapping**
+- Created coordinate tensor `cS` (already existed for row_max)
+- Partitioned as `tTMEM_LOADcS_frg` matching P value fragments
+- In softmax loop, for each `(k,j)`:
+  - Get QK coordinate `(m,n)` from `tTMEM_LOADcS_frg[k,j]`
+  - Map to PV SMEM coordinate using formula
+  - Write P value (or test pattern) to `sP[pv_coord]`
 
-### Time Pressure:
-- Got working coordinate mapping
-- Need to implement full mapping (~15-30 min)
-- Then test and debug
\ No newline at end of file
+❌ **Result:** Cosine ~0.02 (near zero correlation)
+- Kernel compiles and runs
+- PV reads SOMETHING from SMEM (output non-zero)
+- But mapping appears wrong (random correlation)
+- Output scaling huge (280k vs reference 0.2)
+
+### Possible Issues:
+1. **Coordinate mapping formula wrong** — PV A-operand layout might differ
+2. **SMEM swizzle mismatch** — tensor indexing might not handle swizzle correctly
+3. **Thread collisions** — Multiple threads writing same SMEM location
+4. **P value normalization** — Unnormalized P values cause scaling issues
+
+### Debug Attempts:
+1. Test pattern `(k+j)*0.01` → cosine 0.02
+2. Linear index `m*128+n` → cosine 0.006 (huge output as expected)
+3. Both show mapping is bijective but wrong locations
+
+### Next Actions:
+1. Verify coordinate mapping by computing SMEM offset manually
+2. Check if PV expects transposed P matrix
+3. Examine PV MMA tiler and SMEM layout generation
+4. Consider alternative: fix TMEM layout generation instead
+
+**Time spent:** ~45 minutes. Have working but incorrect SMEM-P.
\ No newline at end of file
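
The first of the Next Actions ("Verify coordinate mapping by computing SMEM offset manually") can be prototyped on the host before touching the kernel. The sketch below is plain Python, not CuTe DSL, and does not reproduce the actual mapping formula, which is not shown in this patch. The 128x128 tile shape is an assumption (only the 128-wide row is implied by the `m*128+n` test), and `row_major_offset`, `transposed_offset`, `check_mapping`, and `count_disagreements` are hypothetical helper names. Swizzle is not modeled. It checks a candidate QK-to-SMEM mapping for thread collisions and counts how many placements differ between two candidates, which is one way to separate "not bijective" from "bijective but wrong locations".

```python
from collections import Counter

# Assumed tile shape: 128 columns is implied by the `m*128+n` test pattern,
# 128 rows is a guess for illustration only.
TILE_M = 128
TILE_N = 128


def row_major_offset(m, n):
    """Hypothetical candidate: P stored row-major (the `m*128+n` linear index)."""
    return m * TILE_N + n


def transposed_offset(m, n):
    """Hypothetical candidate: PV expects P transposed."""
    return n * TILE_M + m


def check_mapping(offset_fn, tile_m=TILE_M, tile_n=TILE_N):
    """Return (is_bijective, collisions, out_of_range) for a candidate mapping.

    collisions maps each offset written more than once to its write count;
    out_of_range lists offsets that fall outside the tile.
    """
    size = tile_m * tile_n
    counts = Counter(
        offset_fn(m, n) for m in range(tile_m) for n in range(tile_n)
    )
    collisions = {off: c for off, c in counts.items() if c > 1}
    out_of_range = [off for off in counts if not 0 <= off < size]
    is_bijective = not collisions and not out_of_range and len(counts) == size
    return is_bijective, collisions, out_of_range


def count_disagreements(fn_a, fn_b, tile_m=TILE_M, tile_n=TILE_N):
    """Count coordinates that two candidate mappings place at different offsets."""
    return sum(
        fn_a(m, n) != fn_b(m, n)
        for m in range(tile_m)
        for n in range(tile_n)
    )


if __name__ == "__main__":
    for name, fn in [("row-major", row_major_offset),
                     ("transposed", transposed_offset)]:
        ok, collisions, oob = check_mapping(fn)
        print(name, "bijective:", ok,
              "collisions:", len(collisions), "out of range:", len(oob))
    print("placements that differ:",
          count_disagreements(row_major_offset, transposed_offset))
```

Both candidates here are bijective yet disagree on almost every element, so a bijectivity check alone cannot confirm a mapping. To use this against the kernel, replace one candidate with the formula from the softmax loop and the other with offsets derived from the PV A-operand SMEM layout.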