STAGE_D1.3.md — SMEM-P Manual Addressing Notes
Problem Statement
Copy P (attention probabilities) from TMEM to SMEM with layout conversion:
- Source: TMEM layout from QK C-fragment
- Destination: SMEM layout for PV A-operand
- Both represent same 128×128 P matrix, different tiling
Layouts Observed (from debug prints)
QK C-fragment Layout (TMEM)
tStS0 layout: ((128,128),1,1):((65536,1),0,0)
tStS0 shape: ((128, 128), 1, 1)
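Interpretation: shape (128,128) with stride (65536, 1). Consecutive columns are 1 apart and consecutive rows are 65536 = 2^16 apart, so the column index moves fastest. In CuTe terms this is row-major, not column-major. A row stride of 2^16 matches TMEM addressing, where the lane (row) occupies the upper 16 bits of the address and the column occupies the lower 16 bits.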
PV A-operand SMEM Layout
sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
sP shape: ((128, 16), 1, (4, 2), 1)
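Interpretation:
- Mode 0 `(128,16):(64,1)` is 128 rows at stride 64, each with 16 contiguous columns.
- Mode 2 `(4,2):(16,8192)` places four such 16-column blocks at stride 16, which fills a 64-column row (128 bytes of bf16). A second 128×64 half follows 8192 = 128×64 elements later.
- The two size-1 modes carry no data.

Taken together, P(m, n) sits at offset `m*64 + n%64 + 8192*(n//64)`. This is a K-major layout for the PV A-operand. No swizzle appears in this print.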
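Both prints can be checked with a small pure-Python model of CuTe's coordinate-to-offset rule, where the offset is the sum of coordinate × stride over every leaf of a nested shape. This is a sketch written for these notes, not CuTeDSL code. `crd2idx` and `all_coords` are hypothetical helpers. CuTe C++ has a `crd2idx` with the same role, but this is not its implementation. Swizzling is not modeled.

```python
from itertools import product

# Layouts copied from the debug prints: (shape, stride) as nested tuples.
TSTS0 = (((128, 128), 1, 1), ((65536, 1), 0, 0))
SP = (((128, 16), 1, (4, 2), 1), ((64, 1), 0, (16, 8192), 0))


def crd2idx(coord, shape, stride):
    """Offset of a hierarchical coordinate: sum of coord * stride per leaf mode."""
    if isinstance(shape, int):
        if not 0 <= coord < shape:
            raise IndexError(f"coordinate {coord} out of range for extent {shape}")
        return coord * stride
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))


def all_coords(shape):
    """Every hierarchical coordinate of `shape`, with the same nesting."""
    if isinstance(shape, int):
        return list(range(shape))
    return list(product(*(all_coords(s) for s in shape)))


# tStS0: element (m, n) of S lands at m * 65536 + n.
print(crd2idx(((1, 0), 0, 0), *TSTS0))  # 65536 (next row)
print(crd2idx(((0, 1), 0, 0), *TSTS0))  # 1 (next column)

# sP: enumerate all 128*128 coordinates and check that the offsets form a
# compact, collision-free range 0..16383.
sp_offsets = [crd2idx(c, *SP) for c in all_coords(SP[0])]
print(len(sp_offsets), len(set(sp_offsets)), max(sp_offsets))  # 16384 16384 16383
```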
Failed Approaches
1. make_tiled_copy_C (Helpers are a trap)
- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)`
- `partition_S` and `partition_D` produce incompatible tensors:
  - Source partition: 65536 elements, rank 4
  - Destination partition: 2097152 elements, rank 5 (32× mismatch)
- Confirmed: the helpers assume compatible tiling, but these layouts are incompatible
2. Manual Copy Attempts
Attempt 1: Get thread coordinates with `cute.coord(tStS0)` → `AttributeError: module 'cutlass.cute' has no attribute 'coord'`
Attempt 2: Simple test pattern with dynamic indexing → `DSLRuntimeError: object cannot be interpreted as an integer`
- The CuTeDSL JIT requires compile-time constants or vectorized loops
- Computing SMEM offsets from runtime values is not straightforward
Key Unknowns Blocking Progress
- Coordinate Systems: How to get thread's logical position in QK C-fragment partition?
- Mapping Formula: Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate?
- Offset Computation: How to compute SMEM offset from logical coordinate given a layout?
- Manual Pattern: What's the CuTeDSL pattern for manual layout conversion?
Hypotheses
Hypothesis A: Layouts represent same 128×128 data
- QK C-fragment: tiles for QK MMA (C-fragment partitioning)
- PV A-operand: tiles for PV MMA (A-operand partitioning)
- Need coordinate transformation: QK tile → PV tile mapping
Hypothesis B: tStS0 is whole 128×128 matrix view
- Not the thread's slice but the entire matrix in TMEM
- Each thread accesses its portion via its thread coordinates
- Need the thread's position within the QK C-fragment partition
Hypothesis C: Manual transpose required
- Implement 128×128 transpose between two tiling patterns
- Could brute-force with nested loops over all 128×128 elements
- But each thread owns only a subset (partitioned by the QK MMA)
Questions for CUTLASS LLM
- Thread Coordinates: How do we get a thread's logical coordinates within a layout partition?
  - `cute.coord()` doesn't exist
  - Need: the thread's position in the QK C-fragment partition
  - Possibly `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or something from the MMA?
- Layout Mapping: Given source layout L1 and destination layout L2 for the same logical space (a 128×128 matrix), how do we compute the coordinate transform?
  - L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment)
  - L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand)
  - Need mapping: L1 coordinate → L2 coordinate
- Offset API: Is there a `cute.offset(layout, coord)`? How do we compute a physical offset from a logical coordinate given a layout?
- Manual Pattern: Is there an example of a manual copy between incompatible layouts in CuTeDSL?
  - Source: tensor with layout L1
  - Destination: tensor with layout L2
  - Manual element-wise copy with address computation
- MMA Partitioning: How does the QK MMA partition the 128×128 C-fragment among threads?
  - Each thread owns a subset of elements
  - Need to know which elements a thread owns
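Found Clue in Vendored Code
In `moe_torch_scaled_grouped_mm.py`:
```python
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
```
- Uses the block index to compute an MMA tile coordinate
- `tiled_mma.thr_id.shape` is the MMA's thread-ID shape. For tcgen05 MMAs it appears to count participating CTAs (1 or 2), so `mma_tile_coord_v` would be the CTA's rank within a 2-CTA pair, not a thread coordinate.
- But we need the thread coordinate within the warp/CTA, not the block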
Current Code State
In `fmha.py`, the SMEM-P section is:
```python
# (TMEM-P branch above, not shown)
else:
    # SMEM-P: Manual addressing (helpers are a trap)
    # Python print() runs while the kernel is traced, so it shows static layouts
    print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
    print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
    print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
    # TODO: Implement proper mapping
    # For now, zero sP (wrong but compiles)
    for j in cutlass.range(cute.size(sP), vectorize=True):
        sP[j] = BFloat16(0.0)
```
Test Status
- `test_fmha_v3_stage_d1.py` fails with cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512
- Zeroed sP causes the PV MMA to read zeros → garbage output
- TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch
Progress Update (2026-05-23 19:35 UTC)
CUTLASS LLM responded! Got complete solution:
Key Solution:
- Use `make_identity_tensor(tStS0.shape)` for the coordinate tensor
- Partition the coordinate tensor the same way as the data tensor
- Mapping formula: QK `((m, n), 0, 0)` → PV `((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)`
- Use tensor indexing `sP[dst_coord] = value`, not manual offsets
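The identity-tensor idea can be modeled on the host. An identity tensor stores each element's own coordinate, so partitioning it exactly like the data tensor gives every thread the `(m, n)` of each value it holds. In the sketch below, the partition (thread `t` owns row `t`, one value per column) is a hypothetical stand-in for the real QK C-fragment / TMEM-load partition. `identity_tensor` and `partition` are illustrative helpers, not CuTeDSL APIs.

```python
M, N = 128, 128
NUM_THREADS = 128


def identity_tensor(rows, cols):
    """Dense stand-in for make_identity_tensor: element (m, n) holds (m, n)."""
    return [[(m, n) for n in range(cols)] for m in range(rows)]


def partition(tensor, tid):
    """HYPOTHETICAL thread partition: thread `tid` gets row `tid`, and the
    fragment index v is the column. The real one comes from the MMA/TMEM copy."""
    return [tensor[tid][v] for v in range(len(tensor[tid]))]


cS = identity_tensor(M, N)
P = [[0.001 * (m * N + n) for n in range(N)] for m in range(M)]  # fake P data

seen = set()
for t in range(NUM_THREADS):
    tP = partition(P, t)    # the thread's values
    tcS = partition(cS, t)  # the same slots of the identity tensor
    for v in range(len(tP)):
        m, n = tcS[v]       # logical coordinate of value tP[v]
        assert P[m][n] == tP[v]
        seen.add((m, n))

# 16384 distinct coordinates from 16384 (t, v) slots: each element is owned once.
print(len(seen))  # 16384
```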
Implementation Progress:
- ✅ Coordinate mapping function (`qk_to_pv_coord`) implemented and works
- ✅ Tensor indexing with a coordinate works (`sP[test_coord] = value`)
- ❌ Need to implement the full 128-value mapping per thread
- ❌ Need to get QK coordinates for each of the thread's 128 P values
Implementation Status (2026-05-23 19:55 UTC)
✅ Implemented full SMEM-P with coordinate mapping
- Created coordinate tensor `cS` (already existed for row_max)
- Partitioned it as `tTMEM_LOADcS_frg`, matching the P value fragments
- In the softmax loop, for each `(k,j)`:
  - Get QK coordinate `(m,n)` from `tTMEM_LOADcS_frg[k,j]`
  - Map it to a PV SMEM coordinate using the formula
  - Write the P value (or test pattern) to `sP[pv_coord]`
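The write loop above can be checked for collisions on the host. This sketch reuses the mapping formula from the CUTLASS LLM answer and the unswizzled sP layout from the debug print. The row-per-thread ownership is a hypothetical stand-in for the real `tTMEM_LOADcS_frg` coordinates.

```python
SP_SHAPE = ((128, 16), 1, (4, 2), 1)
SP_STRIDE = ((64, 1), 0, (16, 8192), 0)


def crd2idx(coord, shape, stride):
    """Offset of a hierarchical coordinate (no swizzle)."""
    if isinstance(shape, int):
        assert 0 <= coord < shape
        return coord * stride
    return sum(crd2idx(c, s, d) for c, s, d in zip(coord, shape, stride))


def qk_to_pv_coord(m, n):
    """QK ((m, n), 0, 0) -> PV A-operand coordinate, per the mapping formula."""
    return ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)


smem = [None] * 16384   # flat stand-in for sP
writes = [0] * 16384    # how many times each offset is written

for t in range(128):            # HYPOTHETICAL: thread t owns row t
    for v in range(128):        # fragment index (k, j) flattened to v
        m, n = t, v             # in the kernel: tTMEM_LOADcS_frg[k, j]
        off = crd2idx(qk_to_pv_coord(m, n), SP_SHAPE, SP_STRIDE)
        smem[off] = (m, n)
        writes[off] += 1

print(max(writes), min(writes))  # 1 1: no collisions, no holes
```

Under these assumptions, the formula maps P one-to-one onto offsets `m*64 + n%64 + 8192*(n//64)`, so collisions would not come from the formula itself. That leaves the swizzle and the way the PV MMA consumes sP as the open suspects.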
❌ Result: Cosine ~0.02 (near zero correlation)
- Kernel compiles and runs
- PV reads SOMETHING from SMEM (output non-zero)
- But the mapping appears wrong (correlation no better than random)
- Output magnitude is huge (280k vs. 0.2 in the reference)
Possible Issues:
- Coordinate mapping formula wrong — PV A-operand layout might differ
- SMEM swizzle mismatch — tensor indexing might not handle swizzle correctly
- Thread collisions — Multiple threads writing same SMEM location
- P value normalization — Unnormalized P values cause scaling issues
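The swizzle hypothesis can be made concrete. For S > 0, CuTe's `Swizzle<B, M, S>` XORs the B bits starting at bit M+S into the B bits starting at bit M. A 128-byte swizzle is `Swizzle<3,4,3>` on byte addresses, which for 2-byte bf16 corresponds to `Swizzle<3,3,3>` on element offsets. The open questions are whether sP actually carries this swizzle and whether `sP[coord]` applies it. The sketch only shows what the symptom would look like if writes used unswizzled offsets while the MMA read swizzled ones. `sw128_bf16` and `linear_offset` are illustrative helpers.

```python
def swizzle(offset, b, m, s):
    """CuTe-style Swizzle<B, M, S> for S > 0: bits [M+S, M+S+B) are XORed
    into bits [M, M+B)."""
    yyy_mask = ((1 << b) - 1) << (m + s)
    return offset ^ ((offset & yyy_mask) >> s)


def sw128_bf16(elem_offset):
    """128-byte swizzle in bf16 element units (Swizzle<3,4,3> on bytes)."""
    return swizzle(elem_offset, 3, 3, 3)


def linear_offset(m, n):
    """Unswizzled sP offset of P[m, n], from the printed sP layout."""
    return m * 64 + n % 64 + (n // 64) * 8192


for m in range(3):
    lin = linear_offset(m, 0)
    print(m, lin, sw128_bf16(lin))
# Rows with m % 8 == 0 are unchanged. Other rows keep their data in the same
# row but have their 8-element (16-byte) chunks permuted.
```

A diagnostic consequence: if only rows with `m % 8 == 0` come out right, the swizzle is the likely culprit.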
Debug Attempts:
- Test pattern `(k+j)*0.01` → cosine 0.02
- Linear index `m*128+n` → cosine 0.006 (huge output, as expected)
- Both indicate the mapping is bijective but targets the wrong locations
Next Actions:
- Verify coordinate mapping by computing SMEM offset manually
- Check if PV expects transposed P matrix
- Examine PV MMA tiler and SMEM layout generation
- Consider alternative: fix TMEM layout generation instead
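The first two checks (manual offsets, transposed P) can share one host-side harness. It writes the `m*128+n` pattern with one mapping, reads it back with another, and counts how many positions decode to the coordinate the reader expected. Feeding it a real dump of sP would require the kernel to copy sP to global memory, which is an assumption here. The sketch runs purely on the unswizzled model, and `count_correct` is a hypothetical helper.

```python
M = N = 128


def pv_offset(m, n):
    """Unswizzled sP offset of P[m, n] (from the printed sP layout)."""
    return m * 64 + n % 64 + (n // 64) * 8192


def count_correct(write_offset, read_offset):
    """Write value m*N + n at write_offset(m, n). Then, for each logical (i, j),
    read at read_offset(i, j) and check that the value decodes back to (i, j)."""
    smem = [None] * (M * N)
    for m in range(M):
        for n in range(N):
            smem[write_offset(m, n)] = m * N + n
    return sum(
        divmod(smem[read_offset(i, j)], N) == (i, j)
        for i in range(M)
        for j in range(N)
    )


print(count_correct(pv_offset, pv_offset))                    # 16384: consistent
print(count_correct(pv_offset, lambda i, j: pv_offset(j, i)))  # 128: only the diagonal
```

A real dump scoring near 128 would point to a transposed read. Other partial scores would call for inspecting which positions mismatch.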
Progress Update (2026-05-23 21:30 UTC)
TMEM-P path now works at hd=64 (cos 0.973). The root cause of NaN/zeros was a missing TMEM column offset on `tOrP0`: the PV MMA was reading from column 0 (where S is) instead of column 32 (where the softmax warps store P). Fixed with a `const_expr` conditional.
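The fix is easier to see in TMEM address terms. A tcgen05 TMEM address is 32 bits, with the lane in bits [31:16] and the column in bits [15:0]. That encoding is also why tStS0 shows a row stride of 65536. Moving `tOrP0` from S to P therefore adds 32 to the column field, not the lane field. `tmem_addr` is an illustrative helper. The column numbers are the ones quoted above.

```python
TMEM_LANES = 128
TMEM_COLS = 512


def tmem_addr(lane, col):
    """Pack a TMEM (lane, column) pair: lane in bits [31:16], column in [15:0]."""
    if not (0 <= lane < TMEM_LANES and 0 <= col < TMEM_COLS):
        raise ValueError("lane/column out of range")
    return (lane << 16) | col


S_COL = 0   # where S lives (per the notes)
P_COL = 32  # where the softmax warps store P (per the notes)

lane = 5
print(hex(tmem_addr(lane, S_COL)))  # 0x50000
print(hex(tmem_addr(lane, P_COL)))  # 0x50020
```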
SMEM-P remains unsolved. The `make_tiled_copy_C` approach gives a rank mismatch. Manual coordinate mapping compiles but produces near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64.
Current D1 status:
- hd=64 (TMEM-P): cos 0.973 ✅
- hd=256 (SMEM-P stub): FAIL (zeros)
- hd=512 (SMEM-P stub): FAIL (zeros)
Workaround for hd>64: The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker. The full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P.
Key discoveries:
- `make_tiled_copy_C(store_atom, qk_mma)` creates a copy that partitions threads by the QK C-fragment layout, but the source and destination have incompatible ranks (4 vs 3). This is a fundamental layout incompatibility.
- Manual coordinate mapping (`qk_to_pv_coord`) compiles and runs but produces wrong results. The mapping formula may be incorrect, or the SMEM swizzle may interfere with tensor indexing.
- The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally.
Recommended next steps for SMEM-P:
- Try 12-warp layout (like CUTLASS reference) to avoid SMEM-P entirely
- OR: Use `make_tiled_copy` with `pv_mma` (not `qk_mma`) for the copy, since the PV MMA knows the SMEM layout
- OR: Implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM
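The two-stage option could be prototyped on the host first. In this sketch, stage 1 writes P row-major into a staging buffer in QK-fragment order. Stage 2 walks sP in its own offset order and inverts the unswizzled PV layout to find each source element. Both thread partitions are hypothetical, swizzle is ignored, and `pv_coord_of_offset` / `two_stage` are illustrative helpers. A real kernel would need a barrier between the stages and an extra 128×128 bf16 staging buffer (32 KB of SMEM), which may strain the SMEM budget.

```python
M = N = 128
SP_SIZE = M * N


def pv_offset(m, n):
    """Unswizzled sP offset of P[m, n] (from the printed sP layout)."""
    return m * 64 + n % 64 + (n // 64) * 8192


def pv_coord_of_offset(d):
    """Inverse of pv_offset: which P[m, n] belongs at sP offset d."""
    half, r = divmod(d, 8192)
    m, col = divmod(r, 64)
    return m, half * 64 + col


def two_stage(P, num_threads=128):
    # Stage 1 (HYPOTHETICAL QK-side partition: thread t holds row t):
    # plain row-major staging buffer at offset m*N + n.
    stage = [None] * SP_SIZE
    for t in range(num_threads):
        for n in range(N):
            stage[t * N + n] = P[t][n]
    # ... a CTA-wide barrier would separate the stages in a real kernel ...
    # Stage 2 (HYPOTHETICAL PV-side partition: thread t fills a contiguous
    # run of sP offsets, so its SMEM stores are sequential).
    sP = [None] * SP_SIZE
    per_thread = SP_SIZE // num_threads
    for t in range(num_threads):
        for d in range(t * per_thread, (t + 1) * per_thread):
            m, n = pv_coord_of_offset(d)
            sP[d] = stage[m * N + n]
    return sP


P = [[m * N + n for n in range(N)] for m in range(M)]
direct = [None] * SP_SIZE
for m in range(M):
    for n in range(N):
        direct[pv_offset(m, n)] = P[m][n]

print(two_stage(P) == direct)  # True
```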