# SMEM-P Copy Problem — CUTLASS CuTeDSL Guidance Request
## Summary
I need to write a register→SMEM copy that writes P (attention weight) values from softmax warp registers to a swizzled SMEM buffer (`sP`), so the MMA warp can read P as the A-operand of a PV GEMM. The source register layout comes from a TMEM load partition (`Ld32x32bOp`), and the destination is a swizzled SMEM layout from `make_smem_layout_a` for the PV MMA's A-operand. I cannot find a `make_tiled_copy*` variant that correctly handles both the source and destination layouts simultaneously.
## Architecture
- **6-warp FMHA kernel** on Blackwell (SM100)
  - Warps 0-3: Softmax + Epilogue (128 threads)
  - Warp 4: MMA (32 threads)
  - Warp 5: TMA (32 threads)
- **QK GEMM**: `(128, 128) @ (128, 128) → S(128, 128)` — A,B from SMEM, C to TMEM
- **PV GEMM**: `P(128, 128) @ V(128, N) → O(128, N)` — for hd>64, P from SMEM (A-operand), V from SMEM (B-operand), C to TMEM
## The Data Flow
1. QK GEMM writes S to TMEM (C-fragment layout)
2. Softmax warps read S from TMEM via `Ld32x32bOp(Repetition(32))` into registers
3. Softmax computes `P = exp2(S * scale_log2 - row_max)` in registers
4. **[THIS IS THE PROBLEM]** Softmax warps write P to `sP` (SMEM, swizzled, PV A-operand layout)
5. MMA warp reads P from `sP` via `pv_mma.make_fragment_A(sP)` (standard SMEM A-operand load)
6. MMA warp executes `P @ V → O` GEMM
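
Step 3 can be sanity-checked on the host with a small sketch. It assumes `scale_log2 = scale * log2(e)` and that `row_max` is the row maximum of the already scaled scores `S * scale_log2`; neither definition is stated above, and the helper name `softmax_numerators` is hypothetical.

```python
import math

def softmax_numerators(row, scale):
    """Base-2 form of exp(scale * (s - max(row))) for one row of S."""
    scale_log2 = scale * math.log2(math.e)      # assumed definition
    row_max = max(s * scale_log2 for s in row)  # assumed: max of the scaled scores
    return [2.0 ** (s * scale_log2 - row_max) for s in row]

row = [0.5, -1.0, 2.0, 0.25]
scale = 0.125
p = softmax_numerators(row, scale)

# Natural-exponent reference for comparison.
ref = [math.exp(scale * (s - max(row))) for s in row]
```

If `row_max` were instead the maximum of the unscaled `S`, it would need to be multiplied by `scale_log2` before the subtraction for the two forms to agree.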
## Layout Details (from diagnostics on B200)
### QK C-fragment (TMEM accumulator):
```
tStS shape: ((128, 128), 1, 1)
tStS layout: ((128,128),1,1):((65536,1),0,0)
```
### Softmax TMEM load partition (what each thread owns):
```
tTMEM_LOADtS shape: (((32, 32), 1), 4, 1, 1) -- source partition
tTMEM_LOADcS shape: ((32, 1), 4, 1, 1) -- coordinate partition
tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0)
```
Each softmax thread owns 128 S/P coordinates (4 fragments of 32 elements each).
Total: 128 threads × 128 elements = 16384 = 128×128 ✓
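
The ownership pattern implied by the coordinate layout can be modeled in plain Python. This is a sketch under one assumption that the printed layout does not show: thread `t` owns row `t` (one TMEM lane per thread). The column stride of 1 within a fragment and 32 between fragments is read from `((32,1),4,1,1):((1@1,0),32@1,0,0)`. The helper name `thread_coords` is hypothetical.

```python
def thread_coords(tid, frag_cols=32, num_frags=4):
    """(row, col) pairs owned by one softmax thread, in register order."""
    row = tid  # assumption: thread t reads TMEM lane (row) t
    return [(row, f * frag_cols + i)
            for f in range(num_frags)   # stride 32@1 between fragments
            for i in range(frag_cols)]  # stride 1@1 within a fragment

owned = [thread_coords(t) for t in range(128)]
covered = {rc for coords in owned for rc in coords}
```

Under this model each thread holds one full row of P, with its columns in ascending order.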
### sP (PV A-operand SMEM, swizzled):
```
sP shape: ((128, 16), 1, (4, 2), 1)
sP_2d shape: (((128, 16), 1, (4, 2)), 1) rank=2
sP_2d layout: (((128,16),1,(4,2)),1):(((64,1),0,(16,8192)),0)
Swizzle: S<3,4,3>
```
Total elements: 128 × 16 × 4 × 2 = 16384 = 128×128 ✓
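
To make the printed `sP_2d` layout easier to read, here is a minimal sketch of the unswizzled offset it implies. It assumes the K index decomposes as `k = k16 + 16 * kb4 + 64 * kb2`, matching the `(16, (4, 2))` K modes with strides `(1, (16, 8192))`. The function name `sp_offset` is hypothetical, and the swizzle is deliberately left out.

```python
def sp_offset(m, k):
    """Unswizzled element offset of P[m, k] under the printed sP_2d layout."""
    k16, kb4, kb2 = k % 16, (k // 16) % 4, k // 64
    return m * 64 + k16 * 1 + kb4 * 16 + kb2 * 8192

offsets = [sp_offset(m, k) for m in range(128) for k in range(128)]
```

Read this way, `sP` is two 128×64 K-major panels placed 8192 elements apart, and each panel row is 64 BF16 values (128 bytes).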
### PV A-operand fragments (what MMA warp reads from sP):
For SMEM-P, the MMA warp uses:
- `tCrP = pv_mma.make_fragment_A(sP)` — SMEM load fragment
- `tOrP = pv_thr.make_fragment_A(sP)[(None,None,None,0)]` — TMEM/SMEM read fragment
## What I've Tried
### Attempt 1: `make_tiled_copy_C(copy_atom, qk_mma)`
This creates a copy where the thread partition follows the QK MMA's C-fragment thread mapping.
**Result:** Source and destination partitions both have rank 3 (matching!):
```
partition_S(qk_C_2d) shape: ((8, (16, 128)), 128, 1) rank=3
partition_D(sP_2d) shape: ((8, (16, 128)), (64, 2), 1) rank=3
```
**Problem:** The `CopyUniversalOp` with 128-bit vectorization fails:
```
'cute.copy' op cannot vectorize copy to 8 elements (static strides must be 1)
Source strides: ((1,(8,128)),(16384,0)) -- NOT contiguous
Dest strides: ((64,(512,0)),((1,8192),0))
```
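The source tensor has stride 8 in one dimension, so its elements are not fully contiguous, and `CopyUniversalOp` requires stride-1 contiguity across the elements of a vectorized copy. Note that the reported destination strides begin with 64 rather than 1, so the destination side may fail the same check.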
**Root cause:** The QK C-fragment register layout (from `make_fragment_C`) has non-contiguous strides when viewed as a register tensor. The `make_tiled_copy_C` creates the partition based on the MMA's C-fragment thread mapping, which works for the destination (sP with swizzle) but the source partition inherits non-contiguous strides.
**Key insight:** The softmax threads have P values in the TMEM load register layout (`rP_bf16` shares layout with `tTMEM_LOADrS`), NOT the QK C-fragment register layout. These are different layouts with different stride patterns.
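
The stride-1 requirement can be modeled on the host to see which side of the copy violates it. This sketch is not the check `cute.copy` performs. It pairs the value-mode shape `(8, (16, 128))` from the partition prints with the strides from the error message, which is an assumption, and the helper names are hypothetical.

```python
def leading_offsets(shape, stride, count):
    """First `count` offsets of a flat layout, leftmost mode fastest (CuTe order)."""
    out = []
    for idx in range(count):
        off, rem = 0, idx
        for extent, step in zip(shape, stride):
            off += (rem % extent) * step
            rem //= extent
        out.append(off)
    return out

def can_vectorize(shape, stride, width=8):
    """True if the first `width` elements sit at offsets 0..width-1."""
    return leading_offsets(shape, stride, width) == list(range(width))

# Value modes flattened from the partition shapes and error strides above.
src_ok = can_vectorize((8, 16, 128), (1, 8, 128))
dst_ok = can_vectorize((8, 16, 128), (64, 512, 0))
```

With these inputs the first eight source elements are contiguous, while the destination steps by 64 elements (one row of a panel), so it may be worth confirming which operand the error actually refers to.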
### Attempt 2: `get_smem_store_op` + `make_tiled_copy_D`
This uses the CUTLASS `blackwell_helpers` pattern from the FMHA correction epilogue:
```python
smem_store_atom = get_smem_store_op(c_layout, c_dtype, acc_dtype, tiled_tmem_load)
tiled_smem_store = cute.make_tiled_copy_D(smem_store_atom, tiled_tmem_load)
```
**Problem:** Size mismatch in mode-1:
```
partition_S(tTMEM_LOADtS) → source has 4 in mode-1
partition_D(sP) → destination has 8 in mode-1
```
**Root cause:** The TMEM load partition and the `sP` SMEM layout tile the K dimension differently. `tTMEM_LOADtS` has shape `(((32, 32), 1), 4, 1, 1)`, which splits the 128 columns into 4 fragments of 32, while `sP` has shape `((128, 16), 1, (4, 2), 1)`, which splits them into 4×2 = 8 sub-tiles of 16. The partitions produced by `make_tiled_copy_D` therefore disagree in mode-1 (4 on the source side, 8 on the destination side).
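
The 4-versus-8 mismatch can be illustrated with a short index model. It assumes both views cover the same K range `0..127` in ascending order; the function names are hypothetical.

```python
def src_fragment(k):
    """TMEM-load view: 4 fragments of 32 columns."""
    return k // 32

def dst_subtile(k):
    """sP view: (4, 2) sub-tiles of 16 columns, as (index in 4, index in 2)."""
    return ((k // 16) % 4, k // 64)

# For each source fragment, collect the destination sub-tiles it touches.
subtiles_per_fragment = {
    f: sorted({dst_subtile(k) for k in range(128) if src_fragment(k) == f})
    for f in range(4)
}
```

In this model every 32-column source fragment covers exactly two 16-column destination sub-tiles, which suggests the two sides could agree if the source mode were regrouped as 8 pieces of 16 columns.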
## The Core Question
**How do I create a `TiledCopy` that:**
1. **Source:** Uses the TMEM load partition's thread mapping (128 softmax threads, each owning 128 P values in `((32, 1), 4, 1, 1)` coordinate layout)
2. **Destination:** Writes to `sP` in the PV A-operand SMEM layout (swizzled with `S<3,4,3>`)
3. **Copy atom:** Handles the swizzled SMEM layout correctly
**OR: Is there a completely different approach I'm missing?**
The CUTLASS FMHA reference (12-warp layout) uses TMEM-P exclusively and doesn't have this problem. But our 6-warp layout requires SMEM-P for head_dim > 64 because TMEM can't hold S, P, and O simultaneously.
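
One possible direction, offered as a host-side sketch rather than a known CuTeDSL recipe: if each softmax thread owns a full row of P, it could write that row as 16 stores of 128 bits each along K. The model below only computes addresses. It assumes thread `t` owns row `t`, that `S<3,4,3>` acts on byte offsets, and that `sP` is two 128×64 K-major BF16 panels; all function names are hypothetical.

```python
ELEM_BYTES = 2  # BF16

def swizzle(offset, bits=3, base=4, shift=3):
    """CuTe-style Swizzle<bits, base, shift>: XOR bits
    [base+shift, base+shift+bits) into bits [base, base+bits)."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)

def sp_byte_address(m, k):
    """Swizzled byte address of P[m, k]; assumes the swizzle acts on bytes."""
    elem = m * 64 + (k % 64) + (k // 64) * 8192
    return swizzle(elem * ELEM_BYTES)

def row_store_plan(m):
    """16 chunk base addresses: one 128-bit store per 8 consecutive k values."""
    return [sp_byte_address(m, k0) for k0 in range(0, 128, 8)]

# Base addresses of every 16-byte chunk written by the 128 softmax threads.
bases = [b for m in range(128) for b in row_store_plan(m)]
```

Because a swizzle with base 4 leaves the low four address bits untouched, each group of 8 consecutive BF16 values stays contiguous after swizzling, so the 128-bit chunks remain aligned and do not overlap in this model. Whether a `TiledCopy` can express this mapping directly is still the open question.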
## Environment
- CuTeDSL (Python DSL for CUTLASS) on Blackwell (B200, SM100)
- `cutlass.utils.blackwell_helpers`: `get_tmem_load_op`, `get_smem_store_op`
- `cutlass.utils.gemm.sm100`: `epilogue_tmem_copy_and_partition`, `epilogue_smem_copy_and_partition`, `make_smem_layout_a`, `make_trivial_tiled_mma`
- `tcgen05.copy`: `Ld32x32bOp`, `St32x32bOp`
- `cute.nvgpu.CopyUniversalOp`
## Additional Context
The P matrix is (128, 128) — same dimensions as the QK attention matrix. The first 128 is the M-dimension (query heads), and the second 128 is the K-dimension of the PV GEMM (which equals n_kv_tokens, the number of key/value positions). For a single KV tile (n_kv=128), the P matrix is the full 128×128.
For head_dim=64 (TMEM-P, working), P is stored in TMEM at column offset `tmem_p0_offset=32` using the register bridge pattern (FP32 `exp2` values stored as BF16 via `recast_ptr`). For head_dim>64, P must go to SMEM instead.
The `sP` SMEM buffer is allocated with `smem.allocate_tensor(element_type=BF16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner)` and has the PV MMA's A-operand SMEM layout, which is designed for `pv_mma.make_fragment_A(sP)` to read correctly during the PV GEMM.