Remove obsolete STAGE_D1.3.md and SMEM_P_GUIDANCE_REQUEST.md
This commit is contained in:
@@ -1,115 +0,0 @@
|
||||
# SMEM-P Copy Problem — CUTLASS CuTeDSL Guidance Request
|
||||
|
||||
## Summary
|
||||
I need to write a register→SMEM copy that writes P (attention weight) values from softmax warp registers to a swizzled SMEM buffer (`sP`), so the MMA warp can read P as the A-operand of a PV GEMM. The source register layout comes from a TMEM load partition (`Ld32x32bOp`), and the destination is a swizzled SMEM layout from `make_smem_layout_a` for the PV MMA's A-operand. I cannot find a `make_tiled_copy*` variant that correctly handles both the source and destination layouts simultaneously.
|
||||
|
||||
## Architecture
|
||||
- **6-warp FMHA kernel** on Blackwell (SM100)
|
||||
- Warps 0-3: Softmax + Epilogue (128 threads)
|
||||
- Warp 4: MMA (32 threads)
|
||||
- Warp 5: TMA (32 threads)
|
||||
- **QK GEMM**: `(128, 128) @ (128, 128) → S(128, 128)` — A,B from SMEM, C to TMEM
|
||||
- **PV GEMM**: `P(128, 128) @ V(128, N) → O(128, N)` — for hd>64, P from SMEM (A-operand), V from SMEM (B-operand), C to TMEM
|
||||
|
||||
## The Data Flow
|
||||
1. QK GEMM writes S to TMEM (C-fragment layout)
|
||||
2. Softmax warps read S from TMEM via `Ld32x32bOp(Repetition(32))` into registers
|
||||
3. Softmax computes `P = exp2(S * scale_log2 - row_max)` in registers
|
||||
4. **[THIS IS THE PROBLEM]** Softmax warps write P to `sP` (SMEM, swizzled, PV A-operand layout)
|
||||
5. MMA warp reads P from `sP` via `pv_mma.make_fragment_A(sP)` (standard SMEM A-operand load)
|
||||
6. MMA warp executes `P @ V → O` GEMM
|
||||
|
||||
## Layout Details (from diagnostics on B200)
|
||||
|
||||
### QK C-fragment (TMEM accumulator):
|
||||
```
|
||||
tStS shape: ((128, 128), 1, 1)
|
||||
tStS layout: ((128,128),1,1):((65536,1),0,0)
|
||||
```
|
||||
|
||||
### Softmax TMEM load partition (what each thread owns):
|
||||
```
|
||||
tTMEM_LOADtS shape: (((32, 32), 1), 4, 1, 1) -- source partition
|
||||
tTMEM_LOADcS shape: ((32, 1), 4, 1, 1) -- coordinate partition
|
||||
tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0)
|
||||
```
|
||||
Each softmax thread has 128 S/P coordinates (32 × 4 fragments).
|
||||
Total: 128 threads × 128 elements = 16384 = 128×128 ✓
|
||||
|
||||
### sP (PV A-operand SMEM, swizzled):
|
||||
```
|
||||
sP shape: ((128, 16), 1, (4, 2), 1)
|
||||
sP_2d shape: (((128, 16), 1, (4, 2)), 1) rank=2
|
||||
sP_2d layout: (((128,16),1,(4,2)),1):(((64,1),0,(16,8192)),0)
|
||||
Swizzle: S<3,4,3>
|
||||
```
|
||||
Total elements: 128 × 16 × 4 × 2 = 16384 = 128×128 ✓
|
||||
|
||||
### PV A-operand fragments (what MMA warp reads from sP):
|
||||
For SMEM-P, the MMA warp uses:
|
||||
- `tCrP = pv_mma.make_fragment_A(sP)` — SMEM load fragment
|
||||
- `tOrP = pv_thr.make_fragment_A(sP)[(None,None,None,0)]` — TMEM/SMEM read fragment
|
||||
|
||||
## What I've Tried
|
||||
|
||||
### Attempt 1: `make_tiled_copy_C(copy_atom, qk_mma)`
|
||||
This creates a copy where the thread partition follows the QK MMA's C-fragment thread mapping.
|
||||
|
||||
**Result:** Source and destination partitions both have rank 3 (matching!):
|
||||
```
|
||||
partition_S(qk_C_2d) shape: ((8, (16, 128)), 128, 1) rank=3
|
||||
partition_D(sP_2d) shape: ((8, (16, 128)), (64, 2), 1) rank=3
|
||||
```
|
||||
|
||||
**Problem:** The `CopyUniversalOp` with 128-bit vectorization fails:
|
||||
```
|
||||
'cute.copy' op cannot vectorize copy to 8 elements (static strides must be 1)
|
||||
Source strides: ((1,(8,128)),(16384,0)) -- NOT contiguous
|
||||
Dest strides: ((64,(512,0)),((1,8192),0))
|
||||
```
|
||||
The source tensor has stride 8 in one dimension, meaning elements are not contiguous. The `CopyUniversalOp` expects stride-1 contiguity for vectorized copies.
|
||||
|
||||
**Root cause:** The QK C-fragment register layout (from `make_fragment_C`) has non-contiguous strides when viewed as a register tensor. The `make_tiled_copy_C` creates the partition based on the MMA's C-fragment thread mapping, which works for the destination (sP with swizzle) but the source partition inherits non-contiguous strides.
|
||||
|
||||
**Key insight:** The softmax threads have P values in the TMEM load register layout (`rP_bf16` shares layout with `tTMEM_LOADrS`), NOT the QK C-fragment register layout. These are different layouts with different stride patterns.
|
||||
|
||||
### Attempt 2: `get_smem_store_op` + `make_tiled_copy_D`
|
||||
This uses the CUTLASS blackwell_helpers pattern from the FMHA correction epilog:
|
||||
```python
|
||||
smem_store_atom = get_smem_store_op(c_layout, c_dtype, acc_dtype, tiled_tmem_load)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_store_atom, tiled_tmem_load)
|
||||
```
|
||||
|
||||
**Problem:** Size mismatch in mode-1:
|
||||
```
|
||||
partition_S(tTMEM_LOADtS) → source has 4 in mode-1
|
||||
partition_D(sP) → destination has 8 in mode-1
|
||||
```
|
||||
|
||||
**Root cause:** The TMEM load partition and the sP SMEM layout have different tiling in the K-dimension. The `tTMEM_LOADtS` has shape `(((32, 32), 1), 4, 1, 1)` (4 fragments), while `sP` has shape `((128, 16), 1, (4, 2), 1)` (4×2 sub-tiles). The `make_tiled_copy_D` creates a partition that doesn't match between source (4 fragments) and destination (8 sub-tiles).
|
||||
|
||||
## The Core Question
|
||||
|
||||
**How do I create a `TiledCopy` that:**
|
||||
1. **Source:** Uses the TMEM load partition's thread mapping (128 softmax threads, each owning 128 P values in `((32, 1), 4, 1, 1)` coordinate layout)
|
||||
2. **Destination:** Writes to `sP` in the PV A-operand SMEM layout (swizzled with `S<3,4,3>`)
|
||||
3. **Copy atom:** Handles the swizzled SMEM layout correctly
|
||||
|
||||
**OR: Is there a completely different approach I'm missing?**
|
||||
|
||||
The CUTLASS FMHA reference (12-warp layout) uses TMEM-P exclusively and doesn't have this problem. But our 6-warp layout requires SMEM-P for head_dim > 64 because TMEM can't hold S, P, and O simultaneously.
|
||||
|
||||
## Environment
|
||||
- CuTeDSL (Python DSL for CUTLASS) on Blackwell (B200, SM100)
|
||||
- `cutlass.utils.blackwell_helpers`: `get_tmem_load_op`, `get_smem_store_op`
|
||||
- `cutlass.utils.gemm.sm100`: `epilogue_tmem_copy_and_partition`, `epilogue_smem_copy_and_partition`, `make_smem_layout_a`, `make_trivial_tiled_mma`
|
||||
- `tcgen05.copy`: `Ld32x32bOp`, `St32x32bOp`
|
||||
- `cute.nvgpu.CopyUniversalOp`
|
||||
|
||||
## Additional Context
|
||||
|
||||
The P matrix is (128, 128) — same dimensions as the QK attention matrix. The first 128 is the M-dimension (query heads), and the second 128 is the K-dimension of the PV GEMM (which equals n_kv_tokens, the number of key/value positions). For a single KV tile (n_kv=128), the P matrix is the full 128×128.
|
||||
|
||||
For head_dim=64 (TMEM-P, working), P is stored in TMEM at column offset `tmem_p0_offset=32` using the register bridge pattern (FP32 `exp2` values stored as BF16 via recast_ptr). For head_dim>64, P must go to SMEM instead.
|
||||
|
||||
The `sP` SMEM buffer is allocated with `smem.allocate_tensor(element_type=BF16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner)` and has the PV MMA's A-operand SMEM layout, which is designed for `pv_mma.make_fragment_A(sP)` to read correctly during the PV GEMM.
|
||||
188
STAGE_D1.3.md
188
STAGE_D1.3.md
@@ -1,188 +0,0 @@
|
||||
# STAGE_D1.3.md — SMEM-P Manual Addressing Notes
|
||||
|
||||
## Problem Statement
|
||||
|
||||
Copy P (attention probabilities) from TMEM to SMEM with layout conversion:
|
||||
- **Source:** TMEM layout from QK C-fragment
|
||||
- **Destination:** SMEM layout for PV A-operand
|
||||
- Both represent same 128×128 P matrix, different tiling
|
||||
|
||||
## Layouts Observed (from debug prints)
|
||||
|
||||
### QK C-fragment Layout (TMEM)
|
||||
```
|
||||
tStS0 layout: ((128,128),1,1):((65536,1),0,0)
|
||||
tStS0 shape: ((128, 128), 1, 1)
|
||||
```
|
||||
Interpretation: Shape (128,128) with stride (65536, 1) — column-major? Stride between rows = 65536, between columns = 1.
|
||||
|
||||
### PV A-operand SMEM Layout
|
||||
```
|
||||
sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
|
||||
sP shape: ((128, 16), 1, (4, 2), 1)
|
||||
```
|
||||
Interpretation: Complex tiling with modes: (128,16), 1, (4,2), 1
|
||||
|
||||
## Failed Approaches
|
||||
|
||||
### 1. `make_tiled_copy_C` (Helpers are a trap)
|
||||
- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)`
|
||||
- `partition_S` and `partition_D` produce incompatible tensors
|
||||
- Source partition: 65536 elements, rank 4
|
||||
- Destination partition: 2097152 elements, rank 5 (32× mismatch)
|
||||
- Confirmed: helpers assume compatible tiling, but layouts are incompatible
|
||||
|
||||
### 2. Manual Copy Attempts
|
||||
**Attempt 1:** Get thread coordinates with `cute.coord(tStS0)` → `AttributeError: module 'cutlass.cute' has no attribute 'coord'`
|
||||
|
||||
**Attempt 2:** Simple test pattern with dynamic indexing → `DSLRuntimeError: object cannot be interpreted as an integer`
|
||||
- CuTeDSL JIT requires compile-time constants or vectorized loops
|
||||
- Dynamic SMEM offset computation not trivial
|
||||
|
||||
## Key Unknowns Blocking Progress
|
||||
|
||||
1. **Coordinate Systems:** How to get thread's logical position in QK C-fragment partition?
|
||||
2. **Mapping Formula:** Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate?
|
||||
3. **Offset Computation:** How to compute SMEM offset from logical coordinate given a layout?
|
||||
4. **Manual Pattern:** What's the CuTeDSL pattern for manual layout conversion?
|
||||
|
||||
## Hypotheses
|
||||
|
||||
### Hypothesis A: Layouts represent same 128×128 data
|
||||
- QK C-fragment: tiles for QK MMA (C-fragment partitioning)
|
||||
- PV A-operand: tiles for PV MMA (A-operand partitioning)
|
||||
- Need coordinate transformation: QK tile → PV tile mapping
|
||||
|
||||
### Hypothesis B: `tStS0` is whole 128×128 matrix view
|
||||
- Not thread's slice but entire matrix in TMEM
|
||||
- Each thread accesses portion via its thread coordinates
|
||||
- Need thread's position within QK C-fragment partition
|
||||
|
||||
### Hypothesis C: Manual transpose required
|
||||
- Implement 128×128 transpose between two tiling patterns
|
||||
- Could brute-force with nested loops over 128×128
|
||||
- But thread owns only subset (partitioned by QK MMA)
|
||||
|
||||
## Questions for CUTLASS LLM
|
||||
|
||||
1. **Thread Coordinates:** How to get thread's logical coordinates within a layout partition?
|
||||
- `cute.coord()` doesn't exist
|
||||
- Need: thread position in QK C-fragment partition
|
||||
- Possibly: `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or from MMA?
|
||||
|
||||
2. **Layout Mapping:** Given source layout L1 and destination layout L2 for same logical space (128×128 matrix), how to compute coordinate transform?
|
||||
- L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment)
|
||||
- L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand)
|
||||
- Need mapping: L1 coordinate → L2 coordinate
|
||||
|
||||
3. **Offset API:** Is there `cute.offset(layout, coord)`? How to compute physical offset from logical coordinate given a layout?
|
||||
|
||||
4. **Manual Pattern:** Example of manual copy between incompatible layouts in CuTeDSL?
|
||||
- Source: tensor with layout L1
|
||||
- Destination: tensor with layout L2
|
||||
- Manual element-wise copy with address computation
|
||||
|
||||
5. **MMA Partitioning:** How does QK MMA partition the 128×128 C-fragment among threads?
|
||||
- Each thread owns subset of elements
|
||||
- Need to know which elements thread owns
|
||||
|
||||
## Found Clue in vendored Code
|
||||
In `moe_torch_scaled_grouped_mm.py`:
|
||||
```python
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
||||
```
|
||||
- Uses block index to compute MMA tile coordinate
|
||||
- `tiled_mma.thr_id.shape` gives thread partition shape
|
||||
- But we need thread coordinate within warp/CTA, not block
|
||||
|
||||
## Current Code State
|
||||
|
||||
In `fmha.py`, SMEM-P section is:
|
||||
```python
|
||||
else:
|
||||
# SMEM-P: Manual addressing (helpers are a trap)
|
||||
print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
|
||||
print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
|
||||
print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
|
||||
|
||||
# TODO: Implement proper mapping
|
||||
# For now, zero sP (wrong but compiles)
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
```
|
||||
|
||||
## Test Status
|
||||
- `test_fmha_v3_stage_d1.py` fails with cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512
|
||||
- Zeroed sP causes PV MMA to read zeros → garbage output
|
||||
- TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch
|
||||
|
||||
## Progress Update (2026-05-23 19:35 UTC)
|
||||
**CUTLASS LLM responded!** Got complete solution:
|
||||
|
||||
### Key Solution:
|
||||
1. Use `make_identity_tensor(tStS0.shape)` for coordinate tensor
|
||||
2. Partition coordinate tensor same way as data tensor
|
||||
3. Mapping formula: QK `((m, n), 0, 0)` → PV `((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)`
|
||||
4. Use tensor indexing `sP[dst_coord] = value`, not manual offsets
|
||||
|
||||
### Implementation Progress:
|
||||
- ✅ Coordinate mapping function implemented and works (`qk_to_pv_coord`)
|
||||
- ✅ Tensor indexing with coordinate works (`sP[test_coord] = value`)
|
||||
- ❌ Need to implement full 128-value mapping per thread
|
||||
- ❌ Need to get QK coordinates for each of thread's 128 P values
|
||||
|
||||
### Implementation Status (2026-05-23 19:55 UTC)
|
||||
✅ **Implemented full SMEM-P with coordinate mapping**
|
||||
- Created coordinate tensor `cS` (already existed for row_max)
|
||||
- Partitioned as `tTMEM_LOADcS_frg` matching P value fragments
|
||||
- In softmax loop, for each `(k,j)`:
|
||||
- Get QK coordinate `(m,n)` from `tTMEM_LOADcS_frg[k,j]`
|
||||
- Map to PV SMEM coordinate using formula
|
||||
- Write P value (or test pattern) to `sP[pv_coord]`
|
||||
|
||||
❌ **Result:** Cosine ~0.02 (near zero correlation)
|
||||
- Kernel compiles and runs
|
||||
- PV reads SOMETHING from SMEM (output non-zero)
|
||||
- But mapping appears wrong (random correlation)
|
||||
- Output scaling huge (280k vs reference 0.2)
|
||||
|
||||
### Possible Issues:
|
||||
1. **Coordinate mapping formula wrong** — PV A-operand layout might differ
|
||||
2. **SMEM swizzle mismatch** — tensor indexing might not handle swizzle correctly
|
||||
3. **Thread collisions** — Multiple threads writing same SMEM location
|
||||
4. **P value normalization** — Unnormalized P values cause scaling issues
|
||||
|
||||
### Debug Attempts:
|
||||
1. Test pattern `(k+j)*0.01` → cosine 0.02
|
||||
2. Linear index `m*128+n` → cosine 0.006 (huge output as expected)
|
||||
3. Both show mapping is bijective but wrong locations
|
||||
|
||||
### Next Actions:
|
||||
1. Verify coordinate mapping by computing SMEM offset manually
|
||||
2. Check if PV expects transposed P matrix
|
||||
3. Examine PV MMA tiler and SMEM layout generation
|
||||
4. Consider alternative: fix TMEM layout generation instead
|
||||
|
||||
## Progress Update (2026-05-23 21:30 UTC)
|
||||
|
||||
**TMEM-P path now works at hd=64 (cos 0.973).** The root cause of NaN/zeros was a missing TMEM column offset on `tOrP0` — PV MMA was reading from column 0 (where S is) instead of column 32 (where P is stored by softmax warps). Fixed with `const_expr` conditional.
|
||||
|
||||
**SMEM-P remains unsolved.** The `make_tiled_copy_C` approach gives rank mismatch. Manual coordinate mapping compiles but produces near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64.
|
||||
|
||||
**Current D1 status:**
|
||||
- hd=64 (TMEM-P): cos 0.973 ✅
|
||||
- hd=256 (SMEM-P stub): FAIL (zeros)
|
||||
- hd=512 (SMEM-P stub): FAIL (zeros)
|
||||
|
||||
**Workaround for hd>64:** The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker. The full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P.
|
||||
|
||||
**Key discoveries:**
|
||||
1. `make_tiled_copy_C(store_atom, qk_mma)` creates a copy that partitions threads by QK C-fragment layout, but the source and destination have incompatible ranks (4 vs 3). This is a fundamental layout incompatibility.
|
||||
2. Manual coordinate mapping (`qk_to_pv_coord`) compiles and runs but produces wrong results. The mapping formula may be incorrect, or SMEM swizzle may interfere with tensor indexing.
|
||||
3. The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally.
|
||||
|
||||
**Recommended next steps for SMEM-P:**
|
||||
1. Try 12-warp layout (like CUTLASS reference) to avoid SMEM-P entirely
|
||||
2. OR: Use `make_tiled_copy` with `pv_mma` (not `qk_mma`) for the copy, since PV MMA knows the SMEM layout
|
||||
3. OR: Implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM
|
||||
Reference in New Issue
Block a user