README: add fire_b200_test docs, update multi-tile blocker with real findings
32
README.md
@@ -199,6 +199,23 @@ bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
bash tests/check_log.sh
```

### `fire_b200_test` — One-command local test runner

The script lives at `~/.openclaw/workspace/fire_b200_test`. It is NOT in the repo because it is project-specific tooling.

```bash
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
```

What it does:
1. Auto-commits and pushes any local changes.
2. SSHes into the B200, pulls, and starts `run_test.sh` in a `screen` session.
3. Polls every 15 s until the `screen` session exits.
4. Dumps the full test log to your terminal.

**This is strictly for the DSV4 NVFP4 kernel project.** It hardcodes the B200 IP, repo paths, and git remote.
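The polling step (step 3) can be sketched in Python as follows. This is only an illustration of the idea, not the actual script: the function names, the `host` and `session` parameters, and the `screen -ls | grep -q` check are all assumptions, since the real tool hardcodes its own values.

```python
import subprocess
import time
from typing import Callable, Optional


def screen_is_running(host: str, session: str) -> bool:
    """Return True while a screen session matching `session` is listed on `host`.

    `host` and `session` are hypothetical placeholders. grep -q exits with
    status 0 only when the session name appears in the `screen -ls` listing.
    """
    remote_cmd = f"screen -ls | grep -q {session}"
    return subprocess.run(["ssh", host, remote_cmd]).returncode == 0


def wait_for_exit(
    is_running: Callable[[], bool],
    interval_s: float = 15.0,
    sleep: Callable[[float], None] = time.sleep,
    max_polls: Optional[int] = None,
) -> int:
    """Sleep `interval_s` between checks until `is_running()` returns False.

    Returns the number of sleeps performed. Raises TimeoutError if the
    session is still running after `max_polls` sleeps.
    """
    polls = 0
    while is_running():
        if max_polls is not None and polls >= max_polls:
            raise TimeoutError(f"still running after {polls} polls")
        sleep(interval_s)
        polls += 1
    return polls


# Hypothetical usage (not executed here):
#   wait_for_exit(lambda: screen_is_running("user@b200-host", "fmha_test"))
```

Passing `is_running` and `sleep` as arguments keeps the loop testable without a real SSH connection.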

---

## Stage C: Online Softmax — SINGLE-TILE ONLY
@@ -210,13 +227,20 @@ bash tests/check_log.sh
### Multi-Tile Blocker: TMA GMEM Tile Indexing

The original TMA partition slices `tBgK` with `(None, 0, None, 0)`, which **hardcodes the GMEM iteration dimension to tile 0**. As a result, TMA always loads K/V from the first 128 tokens regardless of `kt`, and the output is identical for all n>128.
The TMA partition slices `tBgK`/`tVgV` with `(None, 0, None, 0)`. The free mode that remains after slicing is the GMEM iteration dimension, and a `kv_coord` variable is used to index it. **Problem: the `kv_coord` increment is not propagating to the TMA at runtime.**

**Why you can't just index with `kt`:** CuTeDSL's TMA copy API accepts pipeline state values (like `kh.count`) as TMA coordinates but does NOT accept a Python int from `range()`. Indexing with `kt` fails at operation creation.
**Evidence (May 22):**
- `kv_coord = Int32(0)` + `kv_coord += 1` in a `cutlass.range` loop → all multi-tile outputs are identical (TMA loads from tile 0 on every iteration)
- `kv_coord = 0` (plain Python int) + `kv_coord += 1` → same broken result
- `kv_coord = Int32(1)` hardcoded → output **changes** (TMA CAN load from tile 1; the coordinate just isn't being updated dynamically)
- Pipeline handle `.count` also doesn't work (it is opaque pipeline state, not a GMEM coordinate)
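The symptom above (output that matches the single-tile result) can be reproduced with a toy host-side model. The sketch below is plain Python, not kernel code: the function name, the `tile_index` callback, the tile size of 2 (standing in for 128), and the sample data are all made up for illustration. It updates the online-softmax state per key rather than per tile and omits the softmax scale, neither of which changes the conclusion.

```python
import math


def tiled_attention_row(q, keys, values, tile, tile_index):
    """Online-softmax attention for one query row, iterating over KV tiles.

    tile_index(kt) returns the tile that is actually loaded in iteration kt.
    A correct kernel uses tile_index(kt) == kt; the stuck case returns 0.
    """
    num_tiles = len(keys) // tile
    m = -math.inf                    # running maximum of the scores
    l = 0.0                          # running sum of exp(score - m)
    acc = [0.0] * len(values[0])     # running weighted sum of value rows
    for kt in range(num_tiles):
        start = tile_index(kt) * tile
        for k, v in zip(keys[start:start + tile], values[start:start + tile]):
            s = sum(a * b for a, b in zip(q, k))
            m_new = max(m, s)
            scale = math.exp(m - m_new)   # rescales old state; exp(-inf) is 0.0
            p = math.exp(s - m_new)
            l = l * scale + p
            acc = [a * scale + p * x for a, x in zip(acc, v)]
            m = m_new
    return [a / l for a in acc]


q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0], [-3.0, 2.0]]

correct = tiled_attention_row(q, keys, values, 2, lambda kt: kt)
stuck = tiled_attention_row(q, keys, values, 2, lambda kt: 0)
single_tile = tiled_attention_row(q, keys[:2], values[:2], 2, lambda kt: kt)
```

Reloading tile 0 only duplicates its keys and values, which leaves the softmax-weighted average unchanged, so `stuck` matches `single_tile` up to rounding while `correct` differs. A comparison of this kind may be a useful host-side check when testing a candidate fix.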

**Fix (Mike):** Use a combined K+V barrier: one `acquire_and_advance` per `kt`, with two `cute.copy` calls sharing `kvh.barrier`. With no interleaving, `kvh.count` naturally equals `kt` and stays a first-class pipeline state value. See `fmha_v3_stage_c_example2.py`.
**Root cause:** CuTeDSL's JIT appears to constant-fold the `kv_coord += 1` increment, or otherwise fails to propagate it to the TMA descriptor at runtime. The CUTLASS reference uses the same pattern with a Python int `kv_coord`; it is unclear why it works there but not here (possibly a different CuTeDSL version or loop structure).

**Current status of fix:** Compiles but deadlocks at runtime (even at n=128). The 3-way sync between `acc_pipe`, `softmax_done_bar`, and `final_o_bar` needs debugging. Fallback: use `kh.count // 2` in the original interleaved kernel (CuTeDSL's `Int32` overloads `__floordiv__` in recent versions).
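The arithmetic behind the `kh.count // 2` fallback can be checked with a few lines of plain Python. This assumes that in the interleaved kernel a single counter advances once for each K load and once for each V load; the generator below is a hypothetical stand-in for that counter, not the pipeline API.

```python
def interleaved_counts(num_tiles):
    """Yield (kt, operand, count) for a counter that advances on every K or V load."""
    count = 0
    for kt in range(num_tiles):
        for operand in ("K", "V"):
            yield kt, operand, count
            count += 1


# Under the assumption above, floor division by 2 recovers the tile index
# for both the K load and the V load of every iteration.
for kt, operand, count in interleaved_counts(4):
    assert count // 2 == kt
```

With the combined K+V barrier there is one advance per `kt`, so no division would be needed.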
**Debug shape info:**
- `tBgK` before slice: `(((64, 128), 1), Int32(?), Int32(?), Int32(?))`; modes 1, 2, and 3 are all dynamic
- `tVgV` before slice: `(((64, 128), 1), 1, N, 1)`; mode 2 grows with n (confirmed to be the GMEM iteration mode)
- After `(None, 0, None, 0)`: both become `(((64, 128), 1), N_or_Int32(?))`, i.e. 2D
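The effect of the slice on the shapes listed above can be modelled on plain tuples. This is a shape-only sketch: `slice_shape` is a hypothetical helper, not a CuTeDSL function, and `N = 4` is an arbitrary tile count.

```python
def slice_shape(shape, coord):
    """Keep the modes whose coordinate is None; drop the modes that are indexed."""
    if len(shape) != len(coord):
        raise ValueError("shape and coord must have the same number of modes")
    return tuple(mode for mode, c in zip(shape, coord) if c is None)


N = 4  # hypothetical number of KV tiles
tVgV_shape = (((64, 128), 1), 1, N, 1)
sliced = slice_shape(tVgV_shape, (None, 0, None, 0))
print(sliced)  # (((64, 128), 1), 4)
```

Mode 0 (the TMA box) and mode 2 (the GMEM iteration mode) survive the slice, which is why the remaining free mode is the one a per-tile coordinate has to index.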

### Files