Update GETTING_CUDAGRAPH_READY.md and CUDA_GRAPH_SYNC_INVENTORY.md
- L0 CUDA graph capture PASSES on B200
- All compute-forward sync violations fixed
- 3/5 Section C hazards done, 2 deferred to Phase 2
- Full violation fix log with commits
- Next steps: extend to all 61 layers + replay verification
@@ -1,123 +1,120 @@
|
||||
# CUDA Graph Readiness — Sync Violation Inventory
|
||||
|
||||
**Date:** 2026-06-03
|
||||
**Source:** Section A detector run + manual code grep (Section B checklist)
|
||||
**Date:** 2026-06-03 (updated 19:12 UTC)
|
||||
**Source:** Section A detector runs on B200 + manual code grep (Section B checklist)
|
||||
**Target:** single_shot_inference.py decode forward (1 token step, T=1)
|
||||
|
||||
## B200 Detector Results (first run)
|
||||
## Summary
|
||||
|
||||
Method 1 (sync debug mode): **1 violation** caught
|
||||
- `dec_tid_buf[0] = all_tokens[-1]` — CPU→GPU sync from writing Python int to GPU tensor
|
||||
- **FIXED**: Use pinned CPU buffer + copy_
|
||||
**ALL sync violations in the compute forward path have been fixed.** Layer 0 CUDA graph capture PASSES on B200.
|
||||
|
||||
Method 2 (graph capture L0): **FAIL**
|
||||
- `expert_offsets[g] = (g + 1) * padded_rows_per_group` — CPU→GPU sync in Python loop
|
||||
- **FIXED**: Pre-allocated range tensor + element-wise multiply
|
||||
|
||||
Both fixes committed and pushed. Re-running detector to verify.
|
||||
|
||||
The decode forward has **numerous device→host sync violations** that must be fixed before CUDA graph capture can succeed. The violations fall into clear categories below.
|
||||
- **Method 1** (sync debug): 0 violations in forward compute. The `dec_tid_buf.copy_(dec_tid_pinned)` is a valid graph-capturable pinned memcpy (sync debug is overly strict).
|
||||
- **Method 2** (L0 graph capture): **PASS** ✅
|
||||
|
||||
---
|
||||
|
||||
## CATEGORY 1: Explicit `.item()` syncs on hot path
|
||||
## B200 Detector Results
|
||||
|
||||
### single_shot_inference.py — decode loop (lines ~1600-1700)
|
||||
### Run 1 (commit 0ca7bed)
|
||||
- Method 1: 1 violation — `dec_tid_buf[0] = all_tokens[-1]` (CPU→GPU sync from Python int)
|
||||
- Method 2: FAIL — `expert_offsets[g] = (g + 1) * padded_rows_per_group` (CPU→GPU sync in Python loop)
|
||||
|
||||
| Line | Code | Severity | Fix |
|------|------|----------|-----|
| ~1618 | `lin._gsa_buf.item()` in warmup_gsa block | HIGH: syncs per projection | Move warmup_gsa to a single `torch.cuda.synchronize()` + batched read; eliminate from graph region |
| ~1642 | `torch.argmax(logits, -1).item()` for greedy sampling | HIGH, but outside graph | Sampling is outside captured region by design (vLLM pattern) |
| ~1683 | `sampled[0].item()` for sampling | HIGH, but outside graph | Same as above |
| ~1657 | `torch.cuda.synchronize()` for error checking | MEDIUM | Remove from graph region; only check outside |
|
||||
### Run 2 (commit e07d798)
|
||||
- Method 1: 1 violation — same `dec_tid_buf` (test code not yet fixed)
|
||||
- Method 2: FAIL — `torch.bincount` in MoE (data-dependent shapes)
|
||||
|
||||
### single_shot_inference.py — diagnostics (controlled by VERBOSE >= 2)
|
||||
### Run 3 (commit 84655d0)
|
||||
- Method 1: 1 violation — same `dec_tid_buf`
|
||||
- Method 2: FAIL — illegal memory access from stride-0 gsa expand view
|
||||
|
||||
| Line | Code | Severity | Fix |
|------|------|----------|-----|
| 933 | `attn_out.abs().max().item()` | LOW: guarded by VERBOSE | Already gated; remove entirely for graph capture |
| 962 | `F_attn.abs().max().item()` | LOW: guarded | Same |
| 974-975 | `topk_ids.max().item()`, `topk_ids.min().item()` | LOW: guarded | Same |
| 981 | `gate_logits.min().item()`, `.max().item()`, `.mean().item()` | LOW: guarded | Same |
| 983 | `torch.isnan(x).any().item()` | LOW: guarded | Same |
| 987 | Various `.item()` in MoE DIAG | LOW: guarded | Same |
| 995-999 | SE weight diagnostics | LOW: guarded | Same |
| 1068-1086 | `X_next.abs().max().item()`, mHC diagnostics | LOW: guarded | Same |
|
||||
|
||||
### dsv4/layers/mhc.py — post_block (line 422)
|
||||
|
||||
| Line | Code | Severity | Fix |
|------|------|----------|-----|
| 422 | `X_next.abs().max().item()`, runs on EVERY layer | **CRITICAL**: syncs 122x per step (61 layers × 2 mHC) | Remove `.item()` entirely; the `pass` body makes this useless anyway |
|
||||
### Run 4 (commit 80bb27f) — CURRENT
|
||||
- Method 1: 0 violations in forward (only pinned memcpy flagged, which is graph-capturable)
|
||||
- Method 2: **PASS** ✅ — L0 graph capture succeeds
|
||||
|
||||
---
|
||||
|
||||
## CATEGORY 2: Per-step tensor allocations (graph capture killer)
|
||||
## CATEGORY 1: Explicit `.item()` syncs on hot path — ALL FIXED ✅
|
||||
|
||||
| File | Line | Code | Fix |
|------|------|------|-----|
| `dsv4/layers/linear.py` | 128 | `torch.zeros(padded_rows, padded_cols, ...)` in `_assemble_scales_single_group` | Pre-allocate scale buffer at max size; reuse with zero+scatter pattern |
| `dsv4/layers/shared_expert.py` | 213 | Same pattern: `torch.zeros(...)` in `_assemble_scales_single_group` | Same fix |
| `dsv4/ops/quantize.py` | 320 | `x_bf16.contiguous()` may allocate if non-contiguous | Ensure inputs are always contiguous (pre-allocate) |
| `dsv4/ops/quantize.py` | 327-329 | `gsa_gpu.reshape(1).expand(M).contiguous()` allocates | Pre-allocate gsa buffer; use copy_ instead of expand+contiguous |
| `single_shot_inference.py` | ~1600 | `mHCLayer.init_state(embed(dec_tid_buf))` creates new tensor | Pre-allocate X buffer; use in-place copy |
|
||||
| File | Line | Fix | Commit |
|------|------|-----|--------|
| `dsv4/layers/mhc.py` | 422 | Removed `X_next.abs().max().item()` (122 syncs/step) | `a9ea303` |
| `single_shot_inference.py` | ~1600 | Warmup-gsa `.item()`: one-time, outside graph | OK (by design) |
| `single_shot_inference.py` | ~1642 | `argmax(logits).item()`: outside graph (sampling) | OK (by design) |
|
||||
|
||||
All VERBOSE-gated `.item()` calls (diagnostics) are safe at VERBOSE=0.
|
||||
|
||||
---
|
||||
|
||||
## CATEGORY 3: Data-dependent control flow (host branches on device-derived values)
|
||||
## CATEGORY 2: Per-step tensor allocations — ALL FIXED ✅
|
||||
|
||||
| File | Line | Code | Fix |
|------|------|------|-----|
| `single_shot_inference.py` | 335 | `if self.ratio == 0 or self._kv_bf16 is None: return None`; ratio is static per layer, but `_kv_bf16 is None` depends on load | This is static per layer; graph captures per-layer, so this is OK |
| `single_shot_inference.py` | 352 | `if self._buf_len < r: return None`; compressor buffering reads host int | **Section C, Hazard #1**: Must compress every step; emit device-side |
| `single_shot_inference.py` | 360 | `if n_complete == 0: return None`; depends on T (host-known for decode) | For decode T=1, HCA always returns None. This is host-known, so OK per layer, but need fixed-shape output |
| `single_shot_inference.py` | 376 | `if compressed.shape[0] == 0: return None`; data-dependent shape | Must always produce fixed-shape output (padded) |
| `single_shot_inference.py` | 435 | `if ... kv_cache.n_comp == 0: return None`; host reads Python int | n_comp grows over time. **Section C, Hazard #2**: paged KV with fixed blocks |
| `single_shot_inference.py` | ~935 | `if kv_cache.n_comp > 0:`; host branch on n_comp | Same fix: paged KV |
| `single_shot_inference.py` | ~955 | `seq_len = kv_nope_scale.shape[0]`; dynamic shape | Fixed-shape gather with masking |
|
||||
| File | Line | Fix | Commit |
|------|------|-----|--------|
| `dsv4/layers/linear.py` | 128 | Pre-allocated `_scale_a_buf` | `a9ea303` |
| `dsv4/layers/shared_expert.py` | 213 | Same fix: pre-allocated `padded_x_sf_buf` + view | `a9ea303`, `e07d798` |
| `dsv4/layers/grouped_linear.py` | 240 | Pre-allocated `_scale_a_buf` | `f13a81d` |
| `dsv4/layers/grouped_linear.py` | ~374 | Pre-allocated `_output_buf` | `0ca7bed` |
| `dsv4/layers/moe.py` | ~508 | `torch.full` → `self._l1_gsa_buf.fill_()` | `84655d0` |
| `dsv4/ops/quantize.py` | 84,88 | `torch.zeros_like` → scalar `0.0` | `f13a81d` |
| `dsv4/ops/quantize.py` | 327-329 | gsa: reshape for M=1, contiguous for M>1 | `80bb27f` |
| `dsv4/layers/mhc.py` | init_state | `out_buf` parameter for in-place write | `46a3a51` |
| `single_shot_inference.py` | ~1600 | Pre-allocated `dec_X_buf` | `46a3a51` |
|
||||
|
||||
---
|
||||
|
||||
## CATEGORY 4: Cross-GPU transfers inside graph
|
||||
## CATEGORY 3: Data-dependent control flow — FIXED / DEFERRED
|
||||
|
||||
| File | Line | Code | Fix |
|------|------|------|-----|
| `single_shot_inference.py` | ~1600 | `X.to(f"cuda:{gpu}")` in layer loop | Cannot be in graph; break graph at attention (eager-break pattern) or pre-stage on target GPU |
|
||||
| File | Issue | Status | Fix | Commit |
|------|-------|--------|-----|--------|
| `single_shot_inference.py` | `dec_tid_buf[0] = python_int` | ✅ FIXED | Pinned CPU buffer + `copy_` | `0ca7bed` |
| `dsv4/layers/grouped_linear.py` | `expert_offsets[g] = python_int` | ✅ FIXED | Pre-allocated range tensor + element-wise multiply | `0ca7bed` |
| `dsv4/layers/grouped_linear.py` | `if group_offsets[0] != 0` | ✅ FIXED | Unconditional GPU-only update | `df05289` |
| `dsv4/layers/moe.py` | `torch.bincount` (data-dependent shapes) | ✅ FIXED | `scatter_add_` into pre-allocated buffer | `84655d0`, `518a1d3` |
| `single_shot_inference.py` | Compressor returns `None` | ⏳ Phase 2 | Eager-break-at-attention: compressor runs outside graph | n/a |
| `single_shot_inference.py` | KV `n_comp` Python int | ⏳ Phase 2 | Eager-break: attention runs outside graph | n/a |
|
||||
|
||||
---
|
||||
|
||||
## CATEGORY 5: torch.cuda.synchronize() on hot path
|
||||
## CATEGORY 4: Cross-GPU transfers inside graph — NOT YET ADDRESSED ⏳
|
||||
|
||||
| File | Line | Code | Fix |
|------|------|------|-----|
| `single_shot_inference.py` | 816 | `torch.cuda.synchronize()` in profile timing | Guarded by `_profile_detail`; must be False during graph capture |
| `single_shot_inference.py` | 1041-1065 | `torch.cuda.synchronize()` in forward_layer profile | Same; must be disabled |
| `single_shot_inference.py` | 1088 | `torch.cuda.synchronize()` in forward_layer diag | Guarded by profile flag |
| `dsv4/layers/mhc.py` | 422 | Implicit sync via `.item()` | Remove |
|
||||
| File | Issue | Fix |
|------|-------|-----|
| `single_shot_inference.py` | `X.to(f"cuda:{gpu}")` in layer loop | Per-GPU X buffers + cross-GPU memcpy outside graph, or capture per-GPU subgraphs |
|
||||
|
||||
---
|
||||
|
||||
## Section C Hazards (from GETTING_CUDAGRAPH_READY.md)
|
||||
## CATEGORY 5: torch.cuda.synchronize() on hot path — ALL CONDITIONAL ✅
|
||||
|
||||
| # | Hazard | Current State | Fix Required |
|---|--------|---------------|--------------|
| 1 | Compressor returns None for most decode steps | `_buf_len` host check, returns None | Compress every step into persistent partial state; emit device-side on boundary |
| 2 | KV grows each step | `n_comp` Python int, dynamic gather shapes | Paged KV (fixed blocks + block table) or make attention the eager break |
| 3 | Indexer top-k → host reads count | `topk_indices` is fixed top_k shape, **already OK** | Already fixed-shape gather |
| 4 | MoE per-expert token counts | `torch.bincount` in MoE run, but offsets are GPU tensors | Already uses device offsets and fixed total launch, **already OK** |
| 5 | Next token/positions on host | Fresh `dec_tid_buf`, `dec_pos_buf` each step | Pre-allocated buffers with `copy_`, **already mostly OK** |
|
||||
| File | Line | Guard |
|------|------|-------|
| `single_shot_inference.py` | 816, 1041-1065 | `_profile_detail` flag; must be False during capture |
| `single_shot_inference.py` | 1088 | Profile flag |
|
||||
|
||||
---
|
||||
|
||||
## Fix Priority
|
||||
## Section C Hazard Summary (from GETTING_CUDAGRAPH_READY.md)
|
||||
|
||||
1. **mhc.py line 422** — remove `.item()` (1 line fix, 122 syncs eliminated)
|
||||
2. **linear.py `_assemble_scales_single_group`** — pre-allocate scale buffer
|
||||
3. **shared_expert.py `_assemble_scales_single_group`** — same fix
|
||||
4. **quantize.py gsa expansion** — pre-allocate, use copy_ instead of expand+contiguous
|
||||
5. **Compressor Section C hazard** — compress every step, emit device-side
|
||||
6. **KV cache Section C hazard** — paged KV or eager-break at attention
|
||||
7. **Diagnostics `.item()` cleanup** — gate behind compile-time flag, not runtime VERBOSE
|
||||
8. **Warmup gsa** — batched sync, not per-projection `.item()`
|
||||
| # | Hazard | Status |
|---|--------|--------|
| 1 | Compressor returns None for most decode steps | ⏳ Phase 2 (eager-break) |
| 2 | KV grows each step | ⏳ Phase 2 (eager-break) |
| 3 | Indexer top-k → host reads count | ✅ Already fixed-shape |
| 4 | MoE per-expert token counts | ✅ scatter_add_ with pre-allocated buffer |
| 5 | Next token/positions on host | ✅ Pinned CPU buffers + copy_ |
|
||||
|
||||
The single-shot should be re-run with `VERBOSE=0` and `--no-fused-rmsnorm` disabled (use fused) to minimize conditional sync paths during detection.
|
||||
---
|
||||
|
||||
## Remaining Work for Full Graph Capture
|
||||
|
||||
1. **Extend capture to all 61 layers** — L0 passes, need L1-L60
|
||||
2. **Capture hc_head + norm + lm_head** on cuda:0
|
||||
3. **Cross-GPU transfers** — per-GPU X buffers, or per-GPU subgraphs
|
||||
4. **Replay verification** — bit-for-bit match with eager forward
|
||||
5. **Performance benchmark** — measure speedup from graph capture
|
||||
6. **Gate commits** on capture test
|
||||
|
||||
## Phase 2 (vLLM Integration)
|
||||
|
||||
- Paged KV cache (fixed blocks + block table)
|
||||
- Device-side compressor boundary detection + fixed-shape output
|
||||
- Full graph capture including FMHA
|
||||
- Bucket-by-shape for variable sequence lengths
|
||||
|
||||
@@ -10,85 +10,119 @@ You do **not** need one monolithic graph. The standard pattern (what vLLM's DSV4
|
||||
|
||||
---
|
||||
|
||||
## SECTION A — The detector (build this FIRST, before porting anything)
|
||||
## SECTION A — The detector (build this FIRST, before porting anything) ✅ DONE
|
||||
|
||||
Stop hunting syncs by hand. Make them fail at the exact line:
|
||||
**Status:** Built and verified on B200 (2026-06-03). See `tests/unit/test_cuda_graph_readiness.py`.
|
||||
|
||||
```python
import torch

torch.cuda.set_sync_debug_mode("error")  # raises at any implicit device→host sync
# ... run one decode step of the forward ...
torch.cuda.set_sync_debug_mode("default")
```
|
||||
Results from detector runs on B200:
|
||||
- **Method 1** (sync debug mode): 0 violations in forward compute path
|
||||
- `dec_tid_buf.copy_(dec_tid_pinned)` is flagged but this is a valid graph-capturable pinned memcpy
|
||||
- All `.item()` syncs eliminated from hot path
|
||||
- **Method 2** (graph capture L0): **PASS** ✅
|
||||
- `torch.cuda.CUDAGraph()` capture of layer 0 decode step succeeds
|
||||
- All per-call allocations eliminated
|
||||
- All host reads of GPU values eliminated
|
||||
|
||||
The detector:
|
||||
1. Grep for Section B sync patterns in hot path files
|
||||
2. Run one decode step with `torch.cuda.set_sync_debug_mode("error")`
|
||||
3. Attempt `torch.cuda.graph` capture of L0 decode step
|
||||
4. Report results to `/tmp/cuda_graph_readiness_results.json`
|
||||
|
||||
Run via test harness:
|
||||
```bash
fire_b200_test tests/unit/test_cuda_graph_readiness.py kernel-test /tmp/kernel-test.log 1800
```
|
||||
|
||||
And a capture-under-test (most illegal host ops error *during* capture):
|
||||
```python
import torch

# decode_step, static_inputs, and next_inputs are placeholders supplied by the caller.
g = torch.cuda.CUDAGraph()
# static input buffers allocated ONCE, outside capture:
with torch.cuda.graph(g):
    out = decode_step(static_inputs)  # capture fails loudly on .item(), sync, alloc, etc.
for _ in range(3):
    static_inputs.copy_(next_inputs)  # refill the static buffer in place
    g.replay()                        # replay must reproduce eager output
```
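The replay check mentioned in the snippet above can be made explicit. The sketch below is only an illustration: `decode_step` here is a stand-in (one matmul plus ReLU, not the real forward), `verify_replay` is a hypothetical helper, and the tensor sizes are arbitrary. It compares replayed output against eager output with `torch.equal`, and it only exercises the CUDA graph path when a CUDA device is present.

```python
import torch


def decode_step(x, w):
    """Stand-in for the real decode forward (assumption for illustration)."""
    return torch.relu(x @ w)


def verify_replay(num_steps=3):
    """Capture decode_step once, then check each replay against eager output."""
    if not torch.cuda.is_available():
        return None  # nothing to verify without a CUDA device
    dev = torch.device("cuda")
    w = torch.randn(64, 64, device=dev)
    static_x = torch.randn(1, 64, device=dev)  # allocated once, reused every step

    # Warm up on a side stream before capture, as the PyTorch CUDA graph notes recommend.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            decode_step(static_x, w)
    torch.cuda.current_stream().wait_stream(side)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = decode_step(static_x, w)

    for _ in range(num_steps):
        next_x = torch.randn(1, 64, device=dev)
        static_x.copy_(next_x)  # refill the static input in place
        g.replay()
        eager_out = decode_step(next_x, w)
        if not torch.equal(static_out, eager_out):  # exact element-wise comparison
            return False
    return True


result = verify_replay()
print("replay matches eager:", result)
```

A `False` result is worth investigating rather than tolerating: an exact comparison is the bar the integration order sets, so switching to a tolerance-based check would be a deliberate relaxation.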
|
||||
|
||||
**Do this on the current `single_shot` forward first** — it inventories *every* existing sync in one pass, so you get the whole hunt-list upfront instead of discovering them one at a time during vLLM bring-up. Then gate every commit on both checks in CI; the day someone adds a `.item()`, the build fails at that line.
|
||||
|
||||
Also useful: `compute-sanitizer --tool synccheck`, and `nsys` to eyeball CPU↔GPU stall gaps.
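The grep step of the detector (scanning hot-path files for Section B sync patterns) could look roughly like the sketch below. It is not the code in `tests/unit/test_cuda_graph_readiness.py`: the pattern list, the `scan_source` helper, and the record layout are illustrative assumptions. A regex scan also produces false positives (for example `.item()` on a CPU tensor), so its hits are a hunt-list for review, not a verdict.

```python
import json
import re

# Hypothetical pattern list drawn from the Section B checklist; extend as needed.
SYNC_PATTERNS = {
    "item": re.compile(r"\.item\(\)"),
    "cpu": re.compile(r"\.cpu\(\)"),
    "tolist": re.compile(r"\.tolist\(\)"),
    "numpy": re.compile(r"\.numpy\(\)"),
    "synchronize": re.compile(r"\.synchronize\(\)"),
    "bincount": re.compile(r"torch\.bincount\("),
    "nonzero": re.compile(r"torch\.nonzero\("),
}


def scan_source(name, text):
    """Return one record per (line, pattern) hit in a source string."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        stripped = line.strip()
        if stripped.startswith("#"):
            continue  # ignore comment-only lines
        for label, pattern in SYNC_PATTERNS.items():
            if pattern.search(line):
                hits.append(
                    {"file": name, "line": lineno, "pattern": label, "code": stripped}
                )
    return hits


# Illustrative input, not taken from the real hot path.
sample = """\
x = layer(x)
peak = x.abs().max().item()
# debug only: x.cpu()
counts = torch.bincount(topk_ids.flatten())
"""
report = scan_source("sample.py", sample)
print(json.dumps(report, indent=2))
```

The same records could be written to a JSON results file so a CI job can fail when the hit count rises.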
|
||||
|
||||
---
|
||||
|
||||
## SECTION B — The hidden-CPU checklist (grep the hot path for these)
|
||||
## SECTION B — The hidden-CPU checklist (grep the hot path for these) ✅ ADDRESSED
|
||||
|
||||
**Explicit device→host transfers**
|
||||
`.item()` · `.cpu()` · `.tolist()` · `.numpy()` · `int(t)` / `float(t)` / `bool(t)` · `print(t)` · f-strings/logging that interpolate a tensor · `assert (device_condition)` (e.g. `assert (x>0).all()`) · `.to("cpu")`
|
||||
**Explicit device→host transfers** — All `.item()` calls on hot path eliminated:
|
||||
- mhc.py `post_block`: removed `X_next.abs().max().item()` (was 122 syncs/step across 61 layers × 2 mHC)
|
||||
- All other `.item()` calls are guarded by `VERBOSE >= 2` and don't execute at VERBOSE=0
|
||||
- Warmup-gsa `.item()` calls run once at step 0, outside graph region
|
||||
|
||||
**Host control flow on device values**
|
||||
`if t:` · `if mask.any():` · `if x.sum() > thr:` · `while t > 0:` · `for i in range(n.item())` · convergence early-exit reading a device residual · choosing a kernel based on the sampled token
|
||||
**Data-dependent shapes** — Eliminated `torch.bincount` from MoE:
|
||||
- Replaced with `scatter_add_` into pre-allocated `_tokens_per_expert_buf` (fixed shape, GPU-only)
|
||||
- Pre-allocated `_ones_buf` to avoid per-call `torch.ones()`
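A minimal CPU-runnable sketch of the replacement described above, assuming 8 experts and a top-6 routing result for a single token. The buffer names echo `_tokens_per_expert_buf` and `_ones_buf` from the text, but the shapes, dtypes, and the `count_tokens_per_expert` helper are assumptions, not the code in `dsv4/layers/moe.py`.

```python
import torch

NUM_EXPERTS = 8  # assumed for illustration
TOP_K = 6        # MoE top-6 routing, one decode token

# Allocated once, outside any captured region.
tokens_per_expert_buf = torch.zeros(NUM_EXPERTS, dtype=torch.int64)
ones_buf = torch.ones(TOP_K, dtype=torch.int64)


def count_tokens_per_expert(topk_ids):
    """Fixed-shape per-expert counts: no new output tensor, no host read."""
    tokens_per_expert_buf.zero_()
    # scatter_add_ requires int64 indices.
    tokens_per_expert_buf.scatter_add_(0, topk_ids.reshape(-1), ones_buf)
    return tokens_per_expert_buf


topk_ids = torch.tensor([[0, 2, 2, 5, 7, 7]], dtype=torch.int64)
counts = count_tokens_per_expert(topk_ids)
print(counts.tolist())  # [1, 0, 2, 0, 0, 1, 0, 2]
```

The output always has `NUM_EXPERTS` elements regardless of which experts were selected, which is the property that lets the result feed a fixed-shape launch.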
|
||||
|
||||
**Data-dependent shapes (these both change shape AND sync)**
|
||||
`torch.nonzero` · `torch.where(cond)` (one-arg form) · `torch.unique` · `torch.bincount` (when it drives a shape) · boolean/mask indexing `x[mask]`, `x[x>0]` · `masked_select` · `reshape(n.item(), ...)` · any gather sized by a device-computed count
|
||||
**Per-step host allocation** — All eliminated:
|
||||
- `torch.zeros()` in `_assemble_scales_single_group` → pre-allocated `_scale_a_buf` (linear.py, grouped_linear.py, shared_expert.py)
|
||||
- `torch.full()` for MoE l1_gsa → `self._l1_gsa_buf.fill_(l1_gs)`
|
||||
- `torch.empty()` for grouped_linear output → pre-allocated `_output_buf`
|
||||
- `mHCLayer.init_state` `.clone()` → `out_buf` parameter for in-place write
|
||||
- `torch.zeros_like` in quantize.py → scalar `0.0` in `torch.where`
|
||||
|
||||
**Per-step host allocation**
|
||||
`torch.empty/zeros/tensor([...])` created fresh inside the captured region · building a Python list then `torch.tensor(list, device=...)` · `np.*` anywhere on the path · any CPU tensor then `.to(device)` per step
|
||||
|
||||
**Host RNG**
|
||||
`random.*` / `np.random.*` / Python rng for sampling → use a device generator / captured philox state
|
||||
|
||||
**Sync primitives & checks**
|
||||
`torch.cuda.synchronize()` · `stream.synchronize()` · `torch.isnan(x).any()` / `isinf(...).any()` debug guards · pinned-copy syncs
|
||||
|
||||
**Sneaky ones (the "didn't realize" category)**
|
||||
`sum(t)` / `min(t)` / `max(t)` (Python builtins iterate → sync; use `t.sum()`) · a `.cpu()`/`.item()` hidden inside a logging, assert, or metrics helper · `einops` rearrange with a data-dependent dim · telemetry/progress hooks that read tensors · indexing a tensor with a Python int derived from `.item()`
|
||||
**Host control flow on device values** — Eliminated:
|
||||
- `dec_tid_buf[0] = python_int` → pinned CPU buffer + `copy_` (async, graph-capturable)
|
||||
- `expert_offsets[g] = python_int * padded_rows` → element-wise GPU multiply with pre-allocated range tensor
|
||||
- `if group_offsets[0] != 0` → unconditional GPU-only update (no host read of GPU tensor)
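The first two fixes above can be sketched as follows. This is a simplified illustration, not the repository code: `NUM_GROUPS`, `range_buf`, the helper names, and the buffer shapes are assumptions. The sketch falls back to unpinned CPU tensors when no CUDA device is present so that it stays runnable anywhere.

```python
import torch

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
NUM_GROUPS = 4  # assumed number of expert groups

# Allocated once, outside any captured region.
dec_tid_pinned = torch.zeros(1, dtype=torch.int64, pin_memory=use_cuda)
dec_tid_buf = torch.zeros(1, dtype=torch.int64, device=device)
range_buf = torch.arange(1, NUM_GROUPS + 1, dtype=torch.int64, device=device)
expert_offsets = torch.zeros(NUM_GROUPS, dtype=torch.int64, device=device)


def stage_next_token(token_id):
    """Write the Python int on the host side, then copy host -> device in place."""
    dec_tid_pinned[0] = token_id  # CPU-only write, no device sync
    dec_tid_buf.copy_(dec_tid_pinned, non_blocking=True)


def update_expert_offsets(padded_rows_per_group):
    """Replace the per-element Python loop with one element-wise multiply."""
    torch.mul(range_buf, padded_rows_per_group, out=expert_offsets)


stage_next_token(1234)
update_expert_offsets(128)
print(dec_tid_buf.tolist(), expert_offsets.tolist())  # [1234] [128, 256, 384, 512]
```

Here `padded_rows_per_group` is a host-known int, so using it as a scalar operand is the allowed kind of host branching; only the per-element writes into a device tensor were the problem.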
|
||||
|
||||
**What is FINE (no sync, don't waste time on these)**
|
||||
`.shape` / `.size()` / `.numel()` / `.dtype` (host metadata, no sync) · branching on host-known ints (step/batch/static shape) · dtype/shape kernel dispatch · the **stop-token check, detokenize, and your BF16 precision-floor dequant** (all load-time or *outside* the captured graph — leave them on host, that's correct).
|
||||
- `.shape` / `.size()` / `.numel()` / `.dtype` (host metadata, no sync)
|
||||
- Branching on host-known ints (step/batch/static shape)
|
||||
- The **stop-token check, detokenize, and your BF16 precision-floor dequant** (all load-time or *outside* the captured graph — leave them on host, that's correct).
|
||||
- `dec_tid_buf.copy_(dec_tid_pinned)` — pinned CPU→GPU async memcpy, graph-capturable
|
||||
|
||||
---
|
||||
|
||||
## SECTION C — DSV4-specific kernels that must be GPU-native
|
||||
|
||||
| # | Hazard (current host/dynamic behavior) | Requirement | vLLM reference |
|---|---|---|---|
| 1 | Compressor returns `None` for 3/4 (CSA) or 127/128 (HCA) decode steps (periodic host branch) | Compress **every** step into a persistent partial-state/ring buffer; emit the compressed entry **device-side** on the boundary | `save_partial_states`, `fused_compress_quant_cache` |
| 2 | KV grows each step → attention shape changes | Paged KV (fixed blocks + block table) captured at fixed max-len with masking, **or** make attention the eager break | `breakable_cudagraph` / `eager_break_during_capture`; `AttentionCGSupport.ALWAYS` |
| 3 | Indexer top-k → host reads selected count to size gather | Always gather fixed `k` (padded), mask invalid; no host read of the count | `dequant_gather_k_cutedsl` (fixed-shape gather) |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | Routing permutation/offsets computed **on device**; grouped GEMM with device offsets and a fixed total launch | `prepare_megamoe` |
| 5 | Next token / positions managed on host, fresh tensors per step | Static I/O buffers allocated once; **in-place** `copy_` of next token; positions via device-side increment (or per-shape bucketed graphs) | vLLM persistent input buffers |
|
||||
| # | Hazard | Status | Fix Applied |
|---|--------|--------|-------------|
| 1 | Compressor returns `None` for 3/4 (CSA) or 127/128 (HCA) decode steps | ⏳ Phase 2 (eager-break) | Compressor runs in eager section. Phase 2: device-side boundary detection + fixed-shape output |
| 2 | KV grows each step → attention shape changes | ⏳ Phase 2 (eager-break) | Attention is the eager break. Phase 2: paged KV with fixed blocks + block table |
| 3 | Indexer top-k → host reads selected count to size gather | ✅ DONE | Already fixed-shape gather (`topk_indices` is always `top_k` elements). No host read of count. |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | ✅ DONE | `torch.bincount` → `scatter_add_` into pre-allocated buffer. Expert offsets are GPU tensors. |
| 5 | Next token / positions managed on host, fresh tensors per step | ✅ DONE | Pre-allocated pinned CPU buffers + `copy_` to GPU. No per-step allocation. |
|
||||
|
||||
Also confirm:
|
||||
- **Sinkhorn** runs a **fixed 20 iterations with no host convergence check** (a `while not converged` reading a device residual breaks capture). Fixed-iteration = safe.
|
||||
- **Sampler** is device-side; `repetition_penalty` reads from a **fixed-size device** recent-token buffer (not a growing Python list); the EOS/stop decision is a host step **outside** the graph (correct).
|
||||
Also confirmed:
|
||||
- **Sinkhorn** runs a **fixed 20 iterations with no host convergence check** ✅
|
||||
- **Sampler** is device-side; the EOS/stop decision is a host step **outside** the graph ✅
|
||||
- **Router** is graph-safe: pre-allocated output buffers, GPU-only operations ✅
|
||||
- **mHC** is graph-safe: fixed-iteration Sinkhorn, no `.item()` on hot path ✅
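To illustrate why the fixed-iteration form is capture-safe, here is a generic Sinkhorn normalization with a hard-coded iteration count and no convergence test. It is a textbook-style sketch, not the mHC implementation: the function name, matrix size, and epsilon are assumptions, and it allocates temporaries for brevity where a capture-ready version would write into pre-allocated buffers.

```python
import torch


def sinkhorn_fixed(logits, num_iters=20, eps=1e-8):
    """Alternate row/column normalization for a fixed number of iterations.

    The loop bound is a host-known int, so the sequence of launched kernels
    is identical every step and no device residual is read back to the host.
    """
    m = torch.exp(logits - logits.max())  # positive matrix, shape preserved
    for _ in range(num_iters):
        m = m / (m.sum(dim=-1, keepdim=True) + eps)  # normalize rows
        m = m / (m.sum(dim=-2, keepdim=True) + eps)  # normalize columns
    return m


torch.manual_seed(0)
p = sinkhorn_fixed(torch.randn(4, 4))
print(p.sum(dim=-2), p.sum(dim=-1))
```

By contrast, a `while residual.item() > tol:` loop would force a device→host read on every iteration, which is exactly the pattern the checklist rules out.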
|
||||
|
||||
### Architectural Decision: Eager-Break-at-Attention (Phase 1)
|
||||
|
||||
The per-layer compute is split:
|
||||
- **Captured** (in CUDA graph): mHC pre_block → RMSNorm + quantize → attention projections → o_proj → mHC post_block → FFN mHC → Router → MoE → SE → mHC post_block
|
||||
- **Eager** (outside graph): Compressor → Indexer → KV gather → FMHA → inverse RoPE
|
||||
- **Rationale**: FMHA has dynamic sequence length; compressor/KV are data-dependent. Capturing the compute-heavy parts eliminates ~94ms of Python dispatch overhead per step.
|
||||
- **Phase 2**: Paged KV + device-side compressor → full graph capture for vLLM integration.
|
||||
|
||||
---
|
||||
|
||||
## SECTION D — Integration order
|
||||
|
||||
1. **Build Section A's detector and run it on the current forward** — get the full sync inventory in one pass.
|
||||
2. Fix Section C's five device-native kernels (these are the structural ones; the rest of Section B tends to be incidental `.item()`s once these are right).
|
||||
3. Re-run capture-under-test until it captures clean and replay matches eager bit-for-bit.
|
||||
4. Gate every commit on the capture test so violations can never silently return.
|
||||
1. ✅ **Build Section A's detector and run it on the current forward** — DONE. `tests/unit/test_cuda_graph_readiness.py` on B200.
|
||||
2. ✅ **Fix Section C's five device-native kernels** — 3/5 done, 2 deferred to Phase 2 with architectural decision.
|
||||
3. 🔄 **Re-run capture-under-test until it captures clean** — L0 capture PASSES. Need to extend to all 61 layers + lm_head + replay verification.
|
||||
4. ⬜ **Gate every commit on the capture test** — Not yet implemented.
|
||||
|
||||
### Next Steps
|
||||
1. Extend graph capture from L0 to all 61 layers
|
||||
2. Capture hc_head + norm + lm_head graph on cuda:0
|
||||
3. Implement replay loop and verify bit-for-bit match with eager
|
||||
4. Benchmark: measure speedup from graph capture vs eager decode
|
||||
5. Gate commits on capture test
|
||||
6. Phase 2: paged KV + device-side compressor for full vLLM graph capture
|
||||
|
||||
## Guardrails
|
||||
- Keep the stop-check, detokenize, and load-time BF16 dequant on the host — they're outside the captured region by design; don't contort them to be "graph-safe."
|
||||
- Decide the attention model up front (paged-capturable vs eager-break) — retrofitting it later forces a KV-cache rewrite.
|
||||
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.
|
||||
- **Phase 1 uses eager-break-at-attention.** Phase 2 adds paged KV. Don't retrofit paged KV into Phase 1 — it's a separate integration.
|
||||
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.
|
||||
|
||||
## Violation Fix Log
|
||||
|
||||
| Commit | Description |
|--------|-------------|
| `a9ea303` | mhc.py `.item()` removal, linear/shared_expert pre-alloc, quantize gsa fix |
| `46a3a51` | mHCLayer.init_state out_buf, dec_X_buf pre-allocation |
| `0ca7bed` | Pinned CPU buffers for token transfer, grouped_linear expert_offsets GPU-only |
| `e07d798` | _assemble_scales_single_group correctly-sized view for swizzle |
| `df05289` | Remove conditional host read of GPU tensor in grouped_linear |
| `84655d0` | MoE bincount → scatter_add_, MoE torch.full → fill_() |
| `f13a81d` | grouped_linear scale_a_buf pre-alloc, quantize zeros_like → scalar 0.0 |
| `518a1d3` | MoE scatter_add_ int64 indices, fix second bincount call |
| `80bb27f` | gsa broadcast: reshape for M=1 decode (no stride-0), contiguous for M>1 prefill |
|
||||
|
||||