auto: pre-test commit
@@ -132,16 +132,36 @@ else:
- ❌ Need to implement the full 128-value mapping per thread
- ❌ Need to get the QK coordinates for each of the thread's 128 P values

### Next Steps:
1. Create a coordinate tensor: `cP_qk = cute.make_identity_tensor(tStS0.shape)`
2. Partition it the same way as `rP_bf16` (through `tTMEM_LOADcP`)
3. In the softmax loop, for each fragment `j` and element `k`:
   - Get the P value from `rP_bf16_frg` or directly from `tTMEM_LOADrS_frg`
   - Get the coordinate from the partitioned coordinate tensor
   - Map it to a PV coordinate using `qk_to_pv_coord`
   - Write to SMEM: `sP[dst_coord] = value`
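
The steps above can be prototyped on the host before touching the kernel. The sketch below is plain Python, not CuTe DSL. It assumes a 128x128 tile, 128 threads, and a hypothetical ownership rule where thread `t` holds row `t`; the real TMEM-load partitioning may differ. `qk_to_pv_coord` here is a placeholder (identity) standing in for the real mapping, which these notes do not spell out.

```python
TILE_M, TILE_N, NUM_THREADS = 128, 128, 128  # assumed tile and thread counts


def make_identity_coords(m, n):
    """Host-side stand-in for an identity tensor: entry (i, j) holds (i, j)."""
    return [[(i, j) for j in range(n)] for i in range(m)]


def partition_rows(tensor, thread_idx):
    """ASSUMPTION: thread t owns row t. Swap in the real partitioning."""
    return tensor[thread_idx]


def qk_to_pv_coord(m, n):
    """Placeholder mapping (identity). Replace with the real QK -> PV formula."""
    return (m, n)


def scatter_p_to_smem(p_values):
    """Scatter every thread's P values into a dict that models sP."""
    coords = make_identity_coords(TILE_M, TILE_N)
    smem = {}
    for t in range(NUM_THREADS):
        frag_vals = partition_rows(p_values, t)
        # Partition the coordinates exactly like the values so they stay aligned.
        frag_coords = partition_rows(coords, t)
        for value, (m, n) in zip(frag_vals, frag_coords):
            smem[qk_to_pv_coord(m, n)] = value
    return smem


# Fill P with a recognizable pattern and scatter it.
p = [[float(i * TILE_N + j) for j in range(TILE_N)] for i in range(TILE_M)]
sP = scatter_p_to_smem(p)
```

The point of the model is the invariant in step 2: as long as the coordinate tensor goes through the same partitioning as the values, each thread can recover `(m, n)` for every value it holds without knowing the partitioning itself.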

### Implementation Status (2026-05-23 19:55 UTC)
✅ **Implemented full SMEM-P with coordinate mapping**
- Reused the coordinate tensor `cS` (it already existed for the row_max computation)
- Partitioned it as `tTMEM_LOADcS_frg`, matching the P value fragments
- In the softmax loop, for each `(k,j)`:
  - Get the QK coordinate `(m,n)` from `tTMEM_LOADcS_frg[k,j]`
  - Map it to a PV SMEM coordinate using the mapping formula
  - Write the P value (or a test pattern) to `sP[pv_coord]`

### Time Pressure:
- Got a working coordinate mapping
- Need to implement the full mapping (~15-30 min)
- Then test and debug

❌ **Result:** Cosine similarity ~0.02 (near-zero correlation with the reference)
- Kernel compiles and runs
- PV reads something from SMEM (output is non-zero)
- But the mapping appears to be wrong (correlation looks random)
- Output magnitude is far too large (280k vs reference 0.2)
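
For reading these numbers, it helps to remember that cosine similarity ignores overall scale. The sketch below assumes the metric is the standard cosine similarity over flattened outputs (the notes do not define it) and uses made-up vectors. It suggests the huge magnitude and the near-zero cosine are separate symptoms: a pure scaling error would leave the cosine near 1, while misplaced elements lower it.

```python
import math


def cosine_similarity(a, b):
    """Standard cosine similarity between two equal-length sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for all-zero inputs
    return dot / (norm_a * norm_b)


# Illustrative data only, not taken from the kernel.
reference = [0.2, -0.1, 0.05, 0.3]
scaled = [x * 1.4e6 for x in reference]  # right layout, wrong scale
permuted = [reference[2], reference[3], reference[0], reference[1]]  # wrong layout

cos_scaled = cosine_similarity(reference, scaled)
cos_permuted = cosine_similarity(reference, permuted)
```

Here `cos_scaled` stays close to 1.0 while `cos_permuted` drops well below it, so a cosine of ~0.02 points at element placement rather than normalization alone.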

### Possible Issues:
1. **Coordinate mapping formula wrong**: the PV A-operand layout might differ from what the formula assumes
2. **SMEM swizzle mismatch**: tensor indexing might not handle the swizzle correctly
3. **Thread collisions**: multiple threads may be writing to the same SMEM location
4. **P value normalization**: unnormalized P values would cause scaling issues
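
Issue 3 can be checked on the host without running the kernel. The sketch below enumerates every source coordinate, applies a candidate mapping, and reports destinations hit more than once. Both mapping functions are hypothetical examples, not the formula used in the kernel.

```python
from collections import defaultdict


def find_collisions(mapping, m_extent, n_extent):
    """Return {destination: [sources]} for destinations written more than once."""
    writers = defaultdict(list)
    for m in range(m_extent):
        for n in range(n_extent):
            writers[mapping(m, n)].append((m, n))
    return {dst: srcs for dst, srcs in writers.items() if len(srcs) > 1}


def transpose_map(m, n):
    """Hypothetical collision-free mapping (plain transpose)."""
    return (n, m)


def broken_map(m, n):
    """Hypothetical buggy mapping: rows 64..127 alias rows 0..63."""
    return (m % 64, n)


ok = find_collisions(transpose_map, 128, 128)
bad = find_collisions(broken_map, 128, 128)
```

An empty result only rules out collisions. It does not show that the destinations are the ones PV expects.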

### Debug Attempts:
1. Test pattern `(k+j)*0.01` → cosine 0.02
2. Linear index `m*128+n` → cosine 0.006 (huge output, as expected)
3. Both suggest the mapping is bijective but writes to the wrong locations
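
One caveat with attempt 2, assuming `sP` stores bf16 (as the name `rP_bf16` suggests): bf16 keeps only 8 significant bits, so most values of `m*128+n` are rounded on the way in and cannot be decoded back to an exact `(m, n)`. The sketch below emulates float32 to bf16 rounding (round to nearest even) on the host to show the effect. The helper names are illustrative.

```python
import struct


def to_bf16(x):
    """Round a float to the nearest bf16 value (ties to even), returned as float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add a bias so that truncating the low 16 bits rounds to nearest even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def survives_bf16(m, n, width=128):
    """True if the linear index m*width+n is exactly representable in bf16."""
    value = float(m * width + n)
    return to_bf16(value) == value


exact_count = sum(survives_bf16(m, n) for m in range(128) for n in range(128))
total = 128 * 128
```

If this applies, one option is to write `m` and `n` in two separate runs: values up to 127 fit in bf16 exactly, so each run can be decoded without loss.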

### Next Actions:
1. Verify the coordinate mapping by computing the SMEM offset manually
2. Check whether PV expects a transposed P matrix
3. Examine the PV MMA tiler and SMEM layout generation
4. Consider an alternative: fix the TMEM layout generation instead
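
For action 1, a small host-side calculator makes the manual check repeatable. The XOR swizzle below follows the general pattern of CuTe's `Swizzle<B, M, S>` as commonly described (XOR `B` bits starting at bit `M+S` into the bits starting at `M`). The parameters `(3, 4, 3)` on byte offsets, the row-major base layout, the 64-element row, and the 2-byte element size are all assumptions for illustration, not values read from this kernel.

```python
def swizzle(offset, bits=3, base=4, shift=3):
    """XOR-swizzle an offset: fold `bits` high bits into the bits starting at `base`."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def smem_byte_offset(row, col, row_elems=64, elem_bytes=2):
    """ASSUMED layout: row-major, 64 bf16 elements per row, swizzled on bytes."""
    linear = (row * row_elems + col) * elem_bytes
    return swizzle(linear)


# Offsets of the first element of the first few rows under these assumptions.
row_starts = [smem_byte_offset(r, 0) for r in range(4)]
```

Comparing offsets computed this way against the addresses the kernel actually writes (once the real layout parameters are substituted) would separate a wrong mapping formula from a swizzle mismatch.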

**Time spent:** ~45 minutes. SMEM-P now runs end to end but produces incorrect output.