auto: pre-test commit
This commit is contained in:
127
STAGE_D1.3.md
Normal file
127
STAGE_D1.3.md
Normal file
@@ -0,0 +1,127 @@
|
||||
# STAGE_D1.3.md — SMEM-P Manual Addressing Notes
|
||||
|
||||
## Problem Statement
|
||||
|
||||
Copy P (attention probabilities) from TMEM to SMEM with layout conversion:
|
||||
- **Source:** TMEM layout from QK C-fragment
|
||||
- **Destination:** SMEM layout for PV A-operand
|
||||
- Both represent same 128×128 P matrix, different tiling
|
||||
|
||||
## Layouts Observed (from debug prints)
|
||||
|
||||
### QK C-fragment Layout (TMEM)
|
||||
```
|
||||
tStS0 layout: ((128,128),1,1):((65536,1),0,0)
|
||||
tStS0 shape: ((128, 128), 1, 1)
|
||||
```
|
||||
Interpretation: Shape (128,128) with stride (65536, 1) — column-major? Stride between rows = 65536, between columns = 1.
|
||||
|
||||
### PV A-operand SMEM Layout
|
||||
```
|
||||
sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
|
||||
sP shape: ((128, 16), 1, (4, 2), 1)
|
||||
```
|
||||
Interpretation: Complex tiling with modes: (128,16), 1, (4,2), 1
|
||||
|
||||
## Failed Approaches
|
||||
|
||||
### 1. `make_tiled_copy_C` (Helpers are a trap)
|
||||
- Created `tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)`
|
||||
- `partition_S` and `partition_D` produce incompatible tensors
|
||||
- Source partition: 65536 elements, rank 4
|
||||
- Destination partition: 2097152 elements, rank 5 (32× mismatch)
|
||||
- Confirmed: helpers assume compatible tiling, but layouts are incompatible
|
||||
|
||||
### 2. Manual Copy Attempts
|
||||
**Attempt 1:** Get thread coordinates with `cute.coord(tStS0)` → `AttributeError: module 'cutlass.cute' has no attribute 'coord'`
|
||||
|
||||
**Attempt 2:** Simple test pattern with dynamic indexing → `DSLRuntimeError: object cannot be interpreted as an integer`
|
||||
- CuTeDSL JIT requires compile-time constants or vectorized loops
|
||||
- Dynamic SMEM offset computation not trivial
|
||||
|
||||
## Key Unknowns Blocking Progress
|
||||
|
||||
1. **Coordinate Systems:** How to get thread's logical position in QK C-fragment partition?
|
||||
2. **Mapping Formula:** Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate?
|
||||
3. **Offset Computation:** How to compute SMEM offset from logical coordinate given a layout?
|
||||
4. **Manual Pattern:** What's the CuTeDSL pattern for manual layout conversion?
|
||||
|
||||
## Hypotheses
|
||||
|
||||
### Hypothesis A: Layouts represent same 128×128 data
|
||||
- QK C-fragment: tiles for QK MMA (C-fragment partitioning)
|
||||
- PV A-operand: tiles for PV MMA (A-operand partitioning)
|
||||
- Need coordinate transformation: QK tile → PV tile mapping
|
||||
|
||||
### Hypothesis B: `tStS0` is whole 128×128 matrix view
|
||||
- Not thread's slice but entire matrix in TMEM
|
||||
- Each thread accesses portion via its thread coordinates
|
||||
- Need thread's position within QK C-fragment partition
|
||||
|
||||
### Hypothesis C: Manual transpose required
|
||||
- Implement 128×128 transpose between two tiling patterns
|
||||
- Could brute-force with nested loops over 128×128
|
||||
- But thread owns only subset (partitioned by QK MMA)
|
||||
|
||||
## Questions for CUTLASS LLM
|
||||
|
||||
1. **Thread Coordinates:** How to get thread's logical coordinates within a layout partition?
|
||||
- `cute.coord()` doesn't exist
|
||||
- Need: thread position in QK C-fragment partition
|
||||
- Possibly: `cute.arch.thread_idx()`, `cute.arch.lane_idx()`, or from MMA?
|
||||
|
||||
2. **Layout Mapping:** Given source layout L1 and destination layout L2 for same logical space (128×128 matrix), how to compute coordinate transform?
|
||||
- L1: `((128,128),1,1):((65536,1),0,0)` (QK C-fragment)
|
||||
- L2: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)` (PV A-operand)
|
||||
- Need mapping: L1 coordinate → L2 coordinate
|
||||
|
||||
3. **Offset API:** Is there `cute.offset(layout, coord)`? How to compute physical offset from logical coordinate given a layout?
|
||||
|
||||
4. **Manual Pattern:** Example of manual copy between incompatible layouts in CuTeDSL?
|
||||
- Source: tensor with layout L1
|
||||
- Destination: tensor with layout L2
|
||||
- Manual element-wise copy with address computation
|
||||
|
||||
5. **MMA Partitioning:** How does QK MMA partition the 128×128 C-fragment among threads?
|
||||
- Each thread owns subset of elements
|
||||
- Need to know which elements thread owns
|
||||
|
||||
## Found Clue in vendored Code
|
||||
In `moe_torch_scaled_grouped_mm.py`:
|
||||
```python
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
|
||||
```
|
||||
- Uses block index to compute MMA tile coordinate
|
||||
- `tiled_mma.thr_id.shape` gives thread partition shape
|
||||
- But we need thread coordinate within warp/CTA, not block
|
||||
|
||||
## Current Code State
|
||||
|
||||
In `fmha.py`, SMEM-P section is:
|
||||
```python
|
||||
else:
|
||||
# SMEM-P: Manual addressing (helpers are a trap)
|
||||
print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
|
||||
print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
|
||||
print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
|
||||
|
||||
# TODO: Implement proper mapping
|
||||
# For now, zero sP (wrong but compiles)
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
```
|
||||
|
||||
## Test Status
|
||||
- `test_fmha_v3_stage_d1.py` fails with cosine NaN for hd=64, -0.047 for hd=256, 0.011 for hd=512
|
||||
- Zeroed sP causes PV MMA to read zeros → garbage output
|
||||
- TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch
|
||||
|
||||
## Blocked Until
|
||||
1. Answer from CUTLASS LLM on coordinate mapping
|
||||
2. Or alternative approach suggested
|
||||
|
||||
## Time Pressure
|
||||
- Been debugging for ~20 minutes
|
||||
- Manual addressing proving much harder than anticipated
|
||||
- Without mapping formula, guessing is low-probability
|
||||
Reference in New Issue
Block a user