diff --git a/SMEM_P_GUIDANCE_REQUEST.md b/SMEM_P_GUIDANCE_REQUEST.md new file mode 100644 index 00000000..cbac7d63 --- /dev/null +++ b/SMEM_P_GUIDANCE_REQUEST.md @@ -0,0 +1,115 @@ +# SMEM-P Copy Problem — CUTLASS CuTeDSL Guidance Request + +## Summary +I need to write a register→SMEM copy that writes P (attention weight) values from softmax warp registers to a swizzled SMEM buffer (`sP`), so the MMA warp can read P as the A-operand of a PV GEMM. The source register layout comes from a TMEM load partition (`Ld32x32bOp`), and the destination is a swizzled SMEM layout from `make_smem_layout_a` for the PV MMA's A-operand. I cannot find a `make_tiled_copy*` variant that correctly handles both the source and destination layouts simultaneously. + +## Architecture +- **6-warp FMHA kernel** on Blackwell (SM100) + - Warps 0-3: Softmax + Epilogue (128 threads) + - Warp 4: MMA (32 threads) + - Warp 5: TMA (32 threads) +- **QK GEMM**: `(128, 128) @ (128, 128) → S(128, 128)` — A,B from SMEM, C to TMEM +- **PV GEMM**: `P(128, 128) @ V(128, N) → O(128, N)` — for hd>64, P from SMEM (A-operand), V from SMEM (B-operand), C to TMEM + +## The Data Flow +1. QK GEMM writes S to TMEM (C-fragment layout) +2. Softmax warps read S from TMEM via `Ld32x32bOp(Repetition(32))` into registers +3. Softmax computes `P = exp2(S * scale_log2 - row_max)` in registers +4. **[THIS IS THE PROBLEM]** Softmax warps write P to `sP` (SMEM, swizzled, PV A-operand layout) +5. MMA warp reads P from `sP` via `pv_mma.make_fragment_A(sP)` (standard SMEM A-operand load) +6. MMA warp executes `P @ V → O` GEMM + +## Layout Details (from diagnostics on B200) + +### QK C-fragment (TMEM accumulator): +``` +tStS shape: ((128, 128), 1, 1) +tStS layout: ((128,128),1,1):((65536,1),0,0) +``` + +### Softmax TMEM load partition (what each thread owns): +``` +tTMEM_LOADtS shape: (((32, 32), 1), 4, 1, 1) -- source partition +tTMEM_LOADcS shape: ((32, 1), 4, 1, 1) -- coordinate partition +tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0) +``` +Each softmax thread has 128 S/P coordinates (32 × 4 fragments). +Total: 128 threads × 128 elements = 16384 = 128×128 ✓ + +### sP (PV A-operand SMEM, swizzled): +``` +sP shape: ((128, 16), 1, (4, 2), 1) +sP_2d shape: (((128, 16), 1, (4, 2)), 1) rank=2 +sP_2d layout: (((128,16),1,(4,2)),1):(((64,1),0,(16,8192)),0) +Swizzle: S<3,4,3> +``` +Total elements: 128 × 16 × 4 × 2 = 16384 = 128×128 ✓ + +### PV A-operand fragments (what MMA warp reads from sP): +For SMEM-P, the MMA warp uses: +- `tCrP = pv_mma.make_fragment_A(sP)` — SMEM load fragment +- `tOrP = pv_thr.make_fragment_A(sP)[(None,None,None,0)]` — TMEM/SMEM read fragment + +## What I've Tried + +### Attempt 1: `make_tiled_copy_C(copy_atom, qk_mma)` +This creates a copy where the thread partition follows the QK MMA's C-fragment thread mapping. + +**Result:** Source and destination partitions both have rank 3 (matching!): +``` +partition_S(qk_C_2d) shape: ((8, (16, 128)), 128, 1) rank=3 +partition_D(sP_2d) shape: ((8, (16, 128)), (64, 2), 1) rank=3 +``` + +**Problem:** The `CopyUniversalOp` with 128-bit vectorization fails: +``` +'cute.copy' op cannot vectorize copy to 8 elements (static strides must be 1) +Source strides: ((1,(8,128)),(16384,0)) -- NOT contiguous +Dest strides: ((64,(512,0)),((1,8192),0)) +``` +The source tensor has stride 8 in one dimension, meaning elements are not contiguous. The `CopyUniversalOp` expects stride-1 contiguity for vectorized copies. + +**Root cause:** The QK C-fragment register layout (from `make_fragment_C`) has non-contiguous strides when viewed as a register tensor. The `make_tiled_copy_C` creates the partition based on the MMA's C-fragment thread mapping, which works for the destination (sP with swizzle) but the source partition inherits non-contiguous strides. + +**Key insight:** The softmax threads have P values in the TMEM load register layout (`rP_bf16` shares layout with `tTMEM_LOADrS`), NOT the QK C-fragment register layout. These are different layouts with different stride patterns. + +### Attempt 2: `get_smem_store_op` + `make_tiled_copy_D` +This uses the CUTLASS blackwell_helpers pattern from the FMHA correction epilog: +```python +smem_store_atom = get_smem_store_op(c_layout, c_dtype, acc_dtype, tiled_tmem_load) +tiled_smem_store = cute.make_tiled_copy_D(smem_store_atom, tiled_tmem_load) +``` + +**Problem:** Size mismatch in mode-1: +``` +partition_S(tTMEM_LOADtS) → source has 4 in mode-1 +partition_D(sP) → destination has 8 in mode-1 +``` + +**Root cause:** The TMEM load partition and the sP SMEM layout have different tiling in the K-dimension. The `tTMEM_LOADtS` has shape `(((32, 32), 1), 4, 1, 1)` (4 fragments), while `sP` has shape `((128, 16), 1, (4, 2), 1)` (4×2 sub-tiles). The `make_tiled_copy_D` creates a partition that doesn't match between source (4 fragments) and destination (8 sub-tiles). + +## The Core Question + +**How do I create a `TiledCopy` that:** +1. **Source:** Uses the TMEM load partition's thread mapping (128 softmax threads, each owning 128 P values in `((32, 1), 4, 1, 1)` coordinate layout) +2. **Destination:** Writes to `sP` in the PV A-operand SMEM layout (swizzled with `S<3,4,3>`) +3. **Copy atom:** Handles the swizzled SMEM layout correctly + +**OR: Is there a completely different approach I'm missing?** + +The CUTLASS FMHA reference (12-warp layout) uses TMEM-P exclusively and doesn't have this problem. But our 6-warp layout requires SMEM-P for head_dim > 64 because TMEM can't hold S, P, and O simultaneously. + +## Environment +- CuTeDSL (Python DSL for CUTLASS) on Blackwell (B200, SM100) +- `cutlass.utils.blackwell_helpers`: `get_tmem_load_op`, `get_smem_store_op` +- `cutlass.utils.gemm.sm100`: `epilogue_tmem_copy_and_partition`, `epilogue_smem_copy_and_partition`, `make_smem_layout_a`, `make_trivial_tiled_mma` +- `tcgen05.copy`: `Ld32x32bOp`, `St32x32bOp` +- `cute.nvgpu.CopyUniversalOp` + +## Additional Context + +The P matrix is (128, 128) — same dimensions as the QK attention matrix. The first 128 is the M-dimension (query heads), and the second 128 is the K-dimension of the PV GEMM (which equals n_kv_tokens, the number of key/value positions). For a single KV tile (n_kv=128), the P matrix is the full 128×128. + +For head_dim=64 (TMEM-P, working), P is stored in TMEM at column offset `tmem_p0_offset=32` using the register bridge pattern (FP32 `exp2` values stored as BF16 via recast_ptr). For head_dim>64, P must go to SMEM instead. + +The `sP` SMEM buffer is allocated with `smem.allocate_tensor(element_type=BF16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner)` and has the PV MMA's A-operand SMEM layout, which is designed for `pv_mma.make_fragment_A(sP)` to read correctly during the PV GEMM.