auto: pre-test commit

This commit is contained in:
2026-05-23 19:29:29 +00:00
parent dee046287e
commit 97e97b63ea

127
STAGE_D1.3.md Normal file
View File

@@ -0,0 +1,127 @@
# STAGE_D1.3.md — SMEM-P Manual Addressing Notes
## Problem Statement
Copy P (attention probabilities) from TMEM to SMEM with layout conversion:
- **Source:** TMEM layout from QK C-fragment
- **Destination:** SMEM layout for PV A-operand
- Both represent same 128×128 P matrix, different tiling
## Layouts Observed (from debug prints)
### QK C-fragment Layout (TMEM)
```
tStS0 layout: ((128,128),1,1):((65536,1),0,0)
tStS0 shape: ((128, 128), 1, 1)
```
Interpretation: Shape (128,128) with stride (65536, 1) — column-major? Stride between rows = 65536, between columns = 1.
### PV A-operand SMEM Layout
```
sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
sP shape: ((128, 16), 1, (4, 2), 1)
```
Interpretation: Complex tiling with modes: (128,16), 1, (4,2), 1
## Failed Approaches
### 1. `make_tiled_copy_C` (Helpers are a trap)
- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)`
- `partition_S` and `partition_D` produce incompatible tensors
- Source partition: 65536 elements, rank 4
- Destination partition: 2097152 elements, rank 5 (32× mismatch)
- Confirmed: helpers assume compatible tiling, but layouts are incompatible
### 2. Manual Copy Attempts
**Attempt 1:** Get thread coordinates with `cute.coord(tStS0)``AttributeError: module 'cutlass.cute' has no attribute 'coord'`
**Attempt 2:** Simple test pattern with dynamic indexing → `DSLRuntimeError: object cannot be interpreted as an integer`
- CuTeDSL JIT requires compile-time constants or vectorized loops
- Dynamic SMEM offset computation not trivial
## Key Unknowns Blocking Progress
1. **Coordinate Systems:** How to get thread's logical position in QK C-fragment partition?
2. **Mapping Formula:** Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate?
3. **Offset Computation:** How to compute SMEM offset from logical coordinate given a layout?
4. **Manual Pattern:** What's the CuTeDSL pattern for manual layout conversion?
## Hypotheses
### Hypothesis A: Layouts represent same 128×128 data
- QK C-fragment: tiles for QK MMA (C-fragment partitioning)
- PV A-operand: tiles for PV MMA (A-operand partitioning)
- Need coordinate transformation: QK tile → PV tile mapping
### Hypothesis B: `tStS0` is whole 128×128 matrix view
- Not thread's slice but entire matrix in TMEM
- Each thread accesses portion via its thread coordinates
- Need thread's position within QK C-fragment partition
### Hypothesis C: Manual transpose required
- Implement 128×128 transpose between two tiling patterns
- Could brute-force with nested loops over 128×128
- But thread owns only subset (partitioned by QK MMA)
## Questions for CUTLASS LLM
1. **Thread Coordinates:** How to get thread's logical coordinates within a layout partition?
- `cute.coord()` doesn't exist
- Need: thread position in QK C-fragment partition
- Possibly: `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or from MMA?
2. **Layout Mapping:** Given source layout L1 and destination layout L2 for same logical space (128×128 matrix), how to compute coordinate transform?
- L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment)
- L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand)
- Need mapping: L1 coordinate → L2 coordinate
3. **Offset API:** Is there `cute.offset(layout, coord)`? How to compute physical offset from logical coordinate given a layout?
4. **Manual Pattern:** Example of manual copy between incompatible layouts in CuTeDSL?
- Source: tensor with layout L1
- Destination: tensor with layout L2
- Manual element-wise copy with address computation
5. **MMA Partitioning:** How does QK MMA partition the 128×128 C-fragment among threads?
- Each thread owns subset of elements
- Need to know which elements thread owns
## Found Clue in vendored Code
In `moe_torch_scaled_grouped_mm.py`:
```python
bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
```
- Uses block index to compute MMA tile coordinate
- `tiled_mma.thr_id.shape` gives thread partition shape
- But we need thread coordinate within warp/CTA, not block
## Current Code State
In `fmha.py`, SMEM-P section is:
```python
else:
# SMEM-P: Manual addressing (helpers are a trap)
print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
# TODO: Implement proper mapping
# For now, zero sP (wrong but compiles)
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = BFloat16(0.0)
```
## Test Status
- `test_fmha_v3_stage_d1.py` fails with cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512
- Zeroed sP causes PV MMA to read zeros → garbage output
- TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch
## Blocked Until
1. Answer from CUTLASS LLM on coordinate mapping
2. Or alternative approach suggested
## Time Pressure
- Been debugging for ~20 minutes
- Manual addressing proving much harder than anticipated
- Without mapping formula, guessing is low-probability