diff --git a/STAGE_D1.3.md b/STAGE_D1.3.md new file mode 100644 index 00000000..a4c1bad9 --- /dev/null +++ b/STAGE_D1.3.md @@ -0,0 +1,127 @@ +# STAGE_D1.3.md — SMEM-P Manual Addressing Notes + +## Problem Statement + +Copy P (attention probabilities) from TMEM to SMEM with layout conversion: +- **Source:** TMEM layout from QK C-fragment +- **Destination:** SMEM layout for PV A-operand +- Both represent same 128×128 P matrix, different tiling + +## Layouts Observed (from debug prints) + +### QK C-fragment Layout (TMEM) +``` +tStS0 layout: ((128,128),1,1):((65536,1),0,0) +tStS0 shape: ((128, 128), 1, 1) +``` +Interpretation: Shape (128,128) with stride (65536, 1) — column-major? Stride between rows = 65536, between columns = 1. + +### PV A-operand SMEM Layout +``` +sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0) +sP shape: ((128, 16), 1, (4, 2), 1) +``` +Interpretation: Complex tiling with modes: (128,16), 1, (4,2), 1 + +## Failed Approaches + +### 1. `make_tiled_copy_C` (Helpers are a trap) +- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)` +- `partition_S` and `partition_D` produce incompatible tensors +- Source partition: 65536 elements, rank 4 +- Destination partition: 2097152 elements, rank 5 (32× mismatch) +- Confirmed: helpers assume compatible tiling, but layouts are incompatible + +### 2. Manual Copy Attempts +**Attempt 1:** Get thread coordinates with `cute.coord(tStS0)` → `AttributeError: module 'cutlass.cute' has no attribute 'coord'` + +**Attempt 2:** Simple test pattern with dynamic indexing → `DSLRuntimeError: object cannot be interpreted as an integer` +- CuTeDSL JIT requires compile-time constants or vectorized loops +- Dynamic SMEM offset computation not trivial + +## Key Unknowns Blocking Progress + +1. **Coordinate Systems:** How to get thread's logical position in QK C-fragment partition? +2. **Mapping Formula:** Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate? +3. **Offset Computation:** How to compute SMEM offset from logical coordinate given a layout? +4. **Manual Pattern:** What's the CuTeDSL pattern for manual layout conversion? + +## Hypotheses + +### Hypothesis A: Layouts represent same 128×128 data +- QK C-fragment: tiles for QK MMA (C-fragment partitioning) +- PV A-operand: tiles for PV MMA (A-operand partitioning) +- Need coordinate transformation: QK tile → PV tile mapping + +### Hypothesis B: `tStS0` is whole 128×128 matrix view +- Not thread's slice but entire matrix in TMEM +- Each thread accesses portion via its thread coordinates +- Need thread's position within QK C-fragment partition + +### Hypothesis C: Manual transpose required +- Implement 128×128 transpose between two tiling patterns +- Could brute-force with nested loops over 128×128 +- But thread owns only subset (partitioned by QK MMA) + +## Questions for CUTLASS LLM + +1. **Thread Coordinates:** How to get thread's logical coordinates within a layout partition? + - `cute.coord()` doesn't exist + - Need: thread position in QK C-fragment partition + - Possibly: `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or from MMA? + +2. **Layout Mapping:** Given source layout L1 and destination layout L2 for same logical space (128×128 matrix), how to compute coordinate transform? + - L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment) + - L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand) + - Need mapping: L1 coordinate → L2 coordinate + +3. **Offset API:** Is there `cute.offset(layout, coord)`? How to compute physical offset from logical coordinate given a layout? + +4. **Manual Pattern:** Example of manual copy between incompatible layouts in CuTeDSL? + - Source: tensor with layout L1 + - Destination: tensor with layout L2 + - Manual element-wise copy with address computation + +5. **MMA Partitioning:** How does QK MMA partition the 128×128 C-fragment among threads? + - Each thread owns subset of elements + - Need to know which elements thread owns + +## Found Clue in vendored Code +In `moe_torch_scaled_grouped_mm.py`: +```python +bidx, bidy, bidz = cute.arch.block_idx() +mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) +``` +- Uses block index to compute MMA tile coordinate +- `tiled_mma.thr_id.shape` gives thread partition shape +- But we need thread coordinate within warp/CTA, not block + +## Current Code State + +In `fmha.py`, SMEM-P section is: +```python +else: + # SMEM-P: Manual addressing (helpers are a trap) + print(f"[SMEM-P MANUAL] Starting manual P write to SMEM") + print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}") + print(f"[SMEM-P MANUAL] sP layout: {sP.layout}") + + # TODO: Implement proper mapping + # For now, zero sP (wrong but compiles) + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = BFloat16(0.0) +``` + +## Test Status +- `test_fmha_v3_stage_d1.py` fails with cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512 +- Zeroed sP causes PV MMA to read zeros → garbage output +- TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch + +## Blocked Until +1. Answer from CUTLASS LLM on coordinate mapping +2. Or alternative approach suggested + +## Time Pressure +- Been debugging for ~20 minutes +- Manual addressing proving much harder than anticipated +- Without mapping formula, guessing is low-probability \ No newline at end of file