# STAGE_D1.3.md — SMEM-P Manual Addressing Notes

## Problem Statement

Copy P (attention probabilities) from TMEM to SMEM with a layout conversion:

- **Source:** TMEM layout from the QK C-fragment
- **Destination:** SMEM layout for the PV A-operand
- Both represent the same 128×128 P matrix, with different tiling

## Layouts Observed (from debug prints)

### QK C-fragment Layout (TMEM)

```
tStS0 layout: ((128,128),1,1):((65536,1),0,0)
tStS0 shape: ((128, 128), 1, 1)
```

Interpretation: shape (128,128) with strides (65536, 1). Consecutive columns are adjacent (stride 1), so this is not column-major. The row stride 65536 = 2^16 is consistent with TMEM addressing, where the lane sits in the upper 16 bits of the address and the column in the lower 16 bits: row m maps to TMEM lane m, column n to TMEM column n.

### PV A-operand SMEM Layout

```
sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
sP shape: ((128, 16), 1, (4, 2), 1)
```

Interpretation: four modes, (128,16), 1, (4,2), 1. Reading the strides, mode 0 looks like (row, 16-wide K chunk) with strides (64, 1); mode 2 steps through K in chunks of 16 (stride 16, four chunks per 64 columns) and then jumps 8192 = 128 × 64 elements to a second 128×64 block. Total size is 128 × 16 × 4 × 2 = 16384 = 128 × 128 elements. No swizzle is visible in this print.

## Failed Approaches

### 1. `make_tiled_copy_C` (helpers are a trap)

- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)`
- `partition_S` and `partition_D` produce incompatible tensors:
  - Source partition: 65536 elements, rank 4
  - Destination partition: 2097152 elements, rank 5 (32× mismatch)
- Confirmed: the helpers assume compatible tiling, but these layouts are incompatible

### 2. Manual Copy Attempts

**Attempt 1:** Get thread coordinates with `cute.coord(tStS0)`
→ `AttributeError: module 'cutlass.cute' has no attribute 'coord'`

**Attempt 2:** Simple test pattern with dynamic indexing
→ `DSLRuntimeError: object cannot be interpreted as an integer`

- CuTeDSL JIT requires compile-time constants or vectorized loops
- Dynamic SMEM offset computation is not trivial

## Key Unknowns Blocking Progress

1. **Coordinate Systems:** How do we get a thread's logical position in the QK C-fragment partition?
2. **Mapping Formula:** Given a QK C-fragment coordinate (i,j), what is the corresponding PV A-operand coordinate?
3. **Offset Computation:** How do we compute an SMEM offset from a logical coordinate, given a layout?
4. **Manual Pattern:** What is the CuTeDSL pattern for a manual layout conversion?
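Unknowns 2 and 3 can be explored on the host from the printed layouts alone. A CuTe layout maps a coordinate to an offset by taking the inner product of the (nested) coordinate with the (nested) stride; C++ CuTe exposes this as `crd2idx`, and we have not verified the exact CuTeDSL spelling. The pure-Python sketch below uses hypothetical helpers (`crd2offset`, `sp_coord`, `tmem_offset`, `sp_offset`), not a CuTeDSL API. It assumes mode 0 of `sP` is one MMA atom's (M, K) tile and mode 2 enumerates 16-wide K chunks, so a column n splits colexicographically (left mode fastest, CuTe's default order) as `(n % 16, (n // 16) % 4, n // 64)`. The offsets are pre-swizzle: if `sP` carries a swizzle that this print does not show, the physical addresses differ.

```python
# Host-side emulation of CuTe's coordinate -> offset rule (hypothetical helpers,
# not part of CuTeDSL). Offset = inner product of nested coordinate and stride.

def crd2offset(coord, stride):
    """Inner product of a nested coordinate with a matching nested stride."""
    if isinstance(coord, tuple):
        assert isinstance(stride, tuple) and len(coord) == len(stride)
        return sum(crd2offset(c, s) for c, s in zip(coord, stride))
    return coord * stride


# Strides copied from the debug prints above.
TSTS0_STRIDE = ((65536, 1), 0, 0)         # shape ((128,128),1,1)
SP_STRIDE = ((64, 1), 0, (16, 8192), 0)   # shape ((128,16),1,(4,2),1)


def sp_coord(m, n):
    """ASSUMPTION: column n splits colexicographically over the K modes (16, (4, 2))."""
    return ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)


def tmem_offset(m, n):
    return crd2offset(((m, n), 0, 0), TSTS0_STRIDE)


def sp_offset(m, n):
    return crd2offset(sp_coord(m, n), SP_STRIDE)


# TMEM: row m lands in the upper 16 bits (lane), column n in the lower 16 bits.
print(hex(tmem_offset(3, 5)))                                # 0x30005
# SMEM (pre-swizzle): two 128x64 blocks of 8192 elements, 64 elements per row.
print(sp_offset(0, 63), sp_offset(0, 64), sp_offset(1, 0))   # 63 8192 64

# Every (m, n) gets a distinct offset covering exactly [0, 16384).
offsets = sorted(sp_offset(m, n) for m in range(128) for n in range(128))
print(offsets == list(range(128 * 128)))                     # True
```

The last check only shows that, under these assumptions, the `sP` layout gives every P element its own slot; it does not show that the PV MMA interprets slot `sp_offset(m, n)` as P[m, n].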
## Hypotheses

### Hypothesis A: Layouts represent the same 128×128 data

- QK C-fragment: tiled for the QK MMA (C-fragment partitioning)
- PV A-operand: tiled for the PV MMA (A-operand partitioning)
- Need a coordinate transformation: QK tile → PV tile mapping

### Hypothesis B: `tStS0` is a view of the whole 128×128 matrix

- Not the thread's slice, but the entire matrix in TMEM
- Each thread accesses its portion via its thread coordinates
- Need the thread's position within the QK C-fragment partition

### Hypothesis C: Manual transpose required

- Implement a 128×128 transpose between the two tiling patterns
- Could brute-force with nested loops over 128×128
- But each thread owns only a subset (partitioned by the QK MMA)

## Questions for CUTLASS LLM

1. **Thread Coordinates:** How do we get a thread's logical coordinates within a layout partition?
   - `cute.coord()` doesn't exist
   - Need: thread position in the QK C-fragment partition
   - Possibly: `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or from the MMA?
2. **Layout Mapping:** Given source layout L1 and destination layout L2 for the same logical space (128×128 matrix), how do we compute the coordinate transform?
   - L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment)
   - L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand)
   - Need mapping: L1 coordinate → L2 coordinate
3. **Offset API:** Is there a `cute.offset(layout, coord)`? How do we compute a physical offset from a logical coordinate given a layout?
4. **Manual Pattern:** Is there an example of a manual copy between incompatible layouts in CuTeDSL?
   - Source: tensor with layout L1
   - Destination: tensor with layout L2
   - Manual element-wise copy with address computation
5. **MMA Partitioning:** How does the QK MMA partition the 128×128 C-fragment among threads?
   - Each thread owns a subset of elements
   - Need to know which elements a thread owns

## Found Clue in Vendored Code

In `moe_torch_scaled_grouped_mm.py`:

```python
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
```

- Uses the block index to compute the MMA tile coordinate
- For tcgen05 tiled MMAs, `tiled_mma.thr_id` appears to index the CTAs cooperating on one MMA (size 1, or 2 for 2-CTA MMAs), so `mma_tile_coord_v` is the CTA's rank in that pair rather than a per-thread partition coordinate
- But we need the thread coordinate within the warp/CTA, not the block

## Current Code State

In `fmha.py`, the SMEM-P section is:

```python
else:
    # SMEM-P: manual addressing (helpers are a trap)
    # Python print() runs while the kernel is traced/compiled, not on the GPU
    print("[SMEM-P MANUAL] Starting manual P write to SMEM")
    print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
    print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
    # TODO: Implement proper mapping
    # For now, zero sP (wrong but compiles)
    for j in cutlass.range(cute.size(sP), vectorize=True):
        sP[j] = BFloat16(0.0)
```

## Test Status

- `test_fmha_v3_stage_d1.py` fails: cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512
- Zeroed sP causes the PV MMA to read zeros → garbage output
- The TMEM-P path works for hd=64 but fails for hd>64 due to a TMEM layout mismatch

## Progress Update (2026-05-23 19:35 UTC)

**CUTLASS LLM responded!** Got a complete solution:

### Key Solution:

1. Use `make_identity_tensor(tStS0.shape)` for the coordinate tensor
2. Partition the coordinate tensor the same way as the data tensor
3. Mapping formula: QK `((m, n), 0, 0)` → PV `((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)`
4. Use tensor indexing `sP[dst_coord] = value`, not manual offsets

### Implementation Progress:

- ✅ Coordinate mapping function implemented and works (`qk_to_pv_coord`)
- ✅ Tensor indexing with a coordinate works (`sP[test_coord] = value`)
- ❌ Need to implement the full 128-value mapping per thread
- ❌ Need to get the QK coordinates for each of the thread's 128 P values

### Implementation Status (2026-05-23 19:55 UTC)

✅ **Implemented full SMEM-P with coordinate mapping**

- Created coordinate tensor `cS` (already existed for row_max)
- Partitioned it as `tTMEM_LOADcS_frg`, matching the P value fragments
- In the softmax loop, for each `(k,j)`:
  - Get the QK coordinate `(m,n)` from `tTMEM_LOADcS_frg[k,j]`
  - Map it to the PV SMEM coordinate using the formula
  - Write the P value (or a test pattern) to `sP[pv_coord]`

❌ **Result:** cosine ~0.02 (near-zero correlation)

- Kernel compiles and runs
- PV reads something from SMEM (output is non-zero)
- But the mapping appears wrong (output essentially uncorrelated with the reference)
- Output scale is huge (280k vs. reference 0.2)

### Possible Issues:

1. **Coordinate mapping formula wrong:** the PV A-operand layout might differ
2. **SMEM swizzle mismatch:** tensor indexing might not handle the swizzle correctly
3. **Thread collisions:** multiple threads writing the same SMEM location
4. **P value normalization:** unnormalized P values cause scaling issues

### Debug Attempts:

1. Test pattern `(k+j)*0.01` → cosine 0.02
2. Linear index `m*128+n` → cosine 0.006 (huge output, as expected)
3. Both point to writes landing in SMEM at the wrong locations; a low cosine alone does not show whether the mapping is bijective

### Next Actions:

1. Verify the coordinate mapping by computing the SMEM offset manually
2. Check whether PV expects a transposed P matrix
3. Examine the PV MMA tiler and SMEM layout generation
4. Consider an alternative: fix TMEM layout generation instead

## Progress Update (2026-05-23 21:30 UTC)

**The TMEM-P path now works at hd=64 (cos 0.973).** The root cause of the NaN/zeros was a missing TMEM column offset on `tOrP0`: the PV MMA was reading from column 0 (where S is) instead of column 32 (where the softmax warps store P). Fixed with a `const_expr` conditional.

**SMEM-P remains unsolved.** The `make_tiled_copy_C` approach gives a rank mismatch. Manual coordinate mapping compiles but produces a near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (a 12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64.

**Current D1 status:**

- hd=64 (TMEM-P): cos 0.973 ✅
- hd=256 (SMEM-P stub): FAIL (zeros)
- hd=512 (SMEM-P stub): FAIL (zeros)

**Workaround for hd>64:** The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker: the full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P.

**Key discoveries:**

1. `make_tiled_copy_C(store_atom, qk_mma)` creates a copy that partitions threads by the QK C-fragment layout, but the source and destination partitions have incompatible ranks (4 vs 3). This is a fundamental layout incompatibility.
2. Manual coordinate mapping (`qk_to_pv_coord`) compiles and runs but produces wrong results. The mapping formula may be incorrect, or the SMEM swizzle may interfere with tensor indexing.
3. The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally.

**Recommended next steps for SMEM-P:**

1. Try a 12-warp layout (like the CUTLASS reference) to avoid SMEM-P entirely
2. Or: use `make_tiled_copy` with `pv_mma` (not `qk_mma`) for the copy, since the PV MMA knows the SMEM layout
3. Or: implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM
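One caveat on the linear-index pattern under Debug Attempts: BF16 keeps 8 significant bits, so only integers up to 256 are exact, and `m*128+n` (up to 16383) collapses many coordinates onto the same stored value. If `sP` is ever dumped to a global buffer to see where each write landed (not something these notes do yet), a pattern has to survive the BF16 conversion to be decodable. The host-side sketch below emulates BF16 round-to-nearest-even with `struct` (our own approximation for finite inputs, not a CuTeDSL call) and compares the linear pattern with a hypothetical alternative: two runs, one writing `m` and one writing `n`, which is exact because both stay below 128.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float32 value to bfloat16 (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit (ties go to even), then drop the low 16 bits.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def distinct_values(encode):
    """Number of distinct BF16 values a test pattern produces over the 128x128 tile."""
    return len({to_bf16(float(encode(m, n))) for m in range(128) for n in range(128)})


# Pattern from the debug attempt: one value encodes both coordinates.
linear = distinct_values(lambda m, n: m * 128 + n)

# Hypothetical alternative: two runs, one writing the row, one the column.
rows = distinct_values(lambda m, n: m)
cols = distinct_values(lambda m, n: n)

print(to_bf16(257.0), to_bf16(16383.0))   # 256.0 16384.0
print(linear, "distinct values for", 128 * 128, "coordinates")  # 1025 ... 16384
print(rows, cols)                         # 128 128
```

With only 1025 distinct stored values, a dump of the linear pattern cannot identify most source coordinates, whereas the two-run variant recovers every `(m, n)` exactly.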