DSV4 → vLLM: CUDA-Graph Safety / GPU-Native Requirements (PART 2 companion)
Goal: the per-step decode forward must be fully GPU-native so vLLM can capture and replay it. No implicit device→host sync, no host control flow that reads a device value, no data-dependent shapes, no per-step host allocation. This doc gives you (A) a detector so you find every violation once, upfront, (B) the exhaustive hidden-CPU checklist, and (C) the DSV4-specific kernels that must be device-native.
The one rule that decides everything
Branching on a host-known integer (step number, position, batch size, dtype, static shape) is graph-compatible — you capture one graph per bucket and the scheduler picks by that integer. Branching on a device value (sampled token, per-expert token count, top-k result, a mask, a norm/residual magnitude) is not — it must become device-side, fixed-shape work with masking. Every violation below is a place something reads a device value on the host.
You do not need one monolithic graph. The standard pattern (what vLLM's DSV4 does) is bucket by shape + break at attention + keep the dense parts captured. Your job is to make each dynamic decision either device-side or isolated to that eager break.
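As a minimal sketch of the "bucket by host-known integer" idea, the selection logic can be as small as a sorted list and a binary search. The bucket sizes and the name `pick_bucket` are hypothetical and not taken from vLLM or this codebase.

```python
import bisect

# Hypothetical bucket sizes; a real deployment would choose its own.
CAPTURE_BATCH_BUCKETS = [1, 2, 4, 8, 16, 32]

def pick_bucket(batch_size, buckets=CAPTURE_BATCH_BUCKETS):
    """Return the smallest captured bucket that fits batch_size.

    Returns None when the batch exceeds every bucket, which the caller
    can treat as a signal to fall back to the eager path.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    i = bisect.bisect_left(buckets, batch_size)
    return buckets[i] if i < len(buckets) else None

# batch_size is known on the host, so this branch never reads a device
# value and introduces no sync.
for n in (1, 3, 8, 33):
    print(n, "->", pick_bucket(n))
```

The batch is then padded up to the chosen bucket so the replayed graph always sees the shape it was captured with.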
⚠️ CRITICAL MULTI-GPU REQUIREMENT (learned 2026-06-06)
PyTorch CUDA graphs on non-default GPUs REQUIRE explicit torch.cuda.Stream(device=device) for capture AND replay. Using torch.cuda.set_device() alone causes:
- GPU 0: Empty graph (warning: "The CUDA Graph is empty")
- GPU 1+: Graph replays with stale capture-time data, ignoring updated input buffers
The fix:
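    import torch

    device = torch.device("cuda:1")  # example: any non-default GPU
    input_buf = torch.zeros(4, device=device)   # static buffers, allocated once
    output_buf = torch.zeros(4, device=device)

    # CAPTURE: pass an explicit stream created on the target device
    s = torch.cuda.Stream(device=device)
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, stream=s):
        output_buf.copy_(input_buf * 2.0)

    # REPLAY: on the same stream that was used for capture
    with torch.cuda.stream(s):
        g.replay()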
Stream synchronization between graph and eager paths:
- Graph A/B run on per-device streams
- Eager attention (between Graph A and Graph B) runs on the default stream
- Use torch.cuda.Event + record() + wait_event() for sync
- Do NOT use torch.cuda.synchronize(): it blocks the host until every stream on the device has drained (too heavy)
This was the root cause of the "all-zeros replay" bug that took an entire session to diagnose. The minimal reproduction test is in tests/unit/test_cuda_graph_stream.py. Read this test if you ever see zero-output graph replay again.
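The event-based hand-off between a graph stream and the default stream might look roughly like the sketch below. The helper name `replay_then_run_eager` and its arguments are hypothetical. It assumes `graph` was captured on `graph_stream` and that `eager_fn` reads buffers the graph writes.

```python
import torch

def replay_then_run_eager(graph, graph_stream, eager_fn):
    """Replay `graph` on its own stream, then run `eager_fn` on the
    default stream of the same device, ordered by events only."""
    device = graph_stream.device
    default_stream = torch.cuda.default_stream(device)

    # 1. The graph stream waits for work already queued on the default
    #    stream (for example, writes into the graph's static inputs).
    inputs_ready = torch.cuda.Event()
    inputs_ready.record(default_stream)
    graph_stream.wait_event(inputs_ready)

    # 2. Replay on the same explicit stream that was used for capture.
    with torch.cuda.stream(graph_stream):
        graph.replay()

    # 3. The default stream waits for the replay before eager work
    #    reads the graph's output buffers.
    graph_done = torch.cuda.Event()
    graph_done.record(graph_stream)
    default_stream.wait_event(graph_done)

    with torch.cuda.device(device):
        return eager_fn()
```

Both waits are queued on the GPU, so the host thread never blocks. The events are created per call here for brevity.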
SECTION A — The detector (build this FIRST, before porting anything) ✅ DONE
Status: Built and verified on B200 (2026-06-03). See tests/unit/test_cuda_graph_readiness.py.
Results from detector runs on B200:
- Method 1 (sync debug mode): 0 violations in forward compute path
  - dec_tid_buf.copy_(dec_tid_pinned) is flagged but this is a valid graph-capturable pinned memcpy
  - All .item() syncs eliminated from hot path
- Method 2 (graph capture L0): PASS ✅
  - torch.cuda.CUDAGraph() capture of layer 0 decode step succeeds
  - All per-call allocations eliminated
  - All host reads of GPU values eliminated
The detector:
- Grep for Section B sync patterns in hot path files
- Run one decode step with torch.cuda.set_sync_debug_mode("error")
- Attempt torch.cuda.graph capture of L0 decode step
- Report results to /tmp/cuda_graph_readiness_results.json
Run via test harness:
fire_b200_test tests/unit/test_cuda_graph_readiness.py kernel-test /tmp/kernel-test.log 1800
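The grep step of the detector could be approximated with a small source scanner like the one below. It is a sketch and not the contents of tests/unit/test_cuda_graph_readiness.py. The pattern list and the name `scan_source` are illustrative assumptions.

```python
import re

# Patterns that usually imply a device-to-host sync or a data-dependent
# shape. The list is illustrative, not exhaustive.
SYNC_PATTERNS = {
    "item": re.compile(r"\.item\(\)"),
    "tolist": re.compile(r"\.tolist\(\)"),
    "cpu": re.compile(r"\.cpu\(\)"),
    "nonzero": re.compile(r"\.nonzero\(\)|torch\.nonzero\("),
    "bincount": re.compile(r"torch\.bincount\("),
}

def scan_source(text):
    """Return (line_number, pattern_name, line) for every suspicious line."""
    hits = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        code = line.split("#", 1)[0]  # naive: drop trailing comments
        for name, pattern in SYNC_PATTERNS.items():
            if pattern.search(code):
                hits.append((lineno, name, line.strip()))
    return hits

sample = """\
def post_block(x):
    peak = x.abs().max().item()   # host read of a device value
    counts = torch.bincount(idx)  # data-dependent output shape
    return x.shape[0]             # host metadata, no sync
"""
for hit in scan_source(sample):
    print(hit)
```

A text scan only finds explicit calls. Implicit syncs, such as `if tensor:` or printing a tensor, still need the sync-debug-mode and capture steps to surface.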
SECTION B — The hidden-CPU checklist (grep the hot path for these) ✅ ADDRESSED
Explicit device→host transfers — All .item() calls on hot path eliminated:
- mhc.py post_block: removed X_next.abs().max().item() (122 syncs/step across 61 layers × 2 mHC)
- All other .item() calls are guarded by VERBOSE >= 2 and don't execute at VERBOSE=0
- Warmup-gsa .item() calls run once at step 0, outside graph region
Data-dependent shapes — Eliminated torch.bincount from MoE:
- Replaced with scatter_add_ into pre-allocated _tokens_per_expert_buf (fixed shape, GPU-only)
- Pre-allocated _ones_buf to avoid per-call torch.ones()
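The bincount replacement can be illustrated with the sketch below. It runs on CPU so it can be tried anywhere, whereas the real buffers would live on the GPU. `NUM_EXPERTS`, the buffer sizes, and the function name are hypothetical. The motivation is that the output length of `torch.bincount` depends on the largest index present, while `scatter_add_` writes into a buffer whose shape is fixed ahead of time.

```python
import torch

NUM_EXPERTS = 8  # hypothetical; use the model's real expert count

# Allocated once, outside the decode loop.
tokens_per_expert_buf = torch.zeros(NUM_EXPERTS, dtype=torch.int64)
ones_buf = torch.ones(16, dtype=torch.int64)  # sized for the largest index count

def count_tokens_per_expert(expert_ids):
    """Histogram of expert_ids, written in place into a fixed-shape buffer."""
    tokens_per_expert_buf.zero_()
    # The slice length comes from .numel(), which is host metadata.
    src = ones_buf[: expert_ids.numel()]
    tokens_per_expert_buf.scatter_add_(0, expert_ids, src)
    return tokens_per_expert_buf

# scatter_add_ requires int64 indices.
expert_ids = torch.tensor([0, 3, 3, 5, 0, 3], dtype=torch.int64)
counts = count_tokens_per_expert(expert_ids)
print(counts.tolist())
```

Experts that receive no tokens simply keep a zero count, so downstream code sees the same shape on every step.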
Per-step host allocation — All eliminated:
- torch.zeros() in _assemble_scales_single_group → pre-allocated _scale_a_buf (linear.py, grouped_linear.py, shared_expert.py)
- torch.full() for MoE l1_gsa → self._l1_gsa_buf.fill_(l1_gs)
- torch.empty() for grouped_linear output → pre-allocated _output_buf
- mHCLayer.init_state.clone() → out_buf parameter for in-place write
- torch.zeros_like in quantize.py → scalar 0.0 in torch.where
Host control flow on device values — Eliminated:
- dec_tid_buf[0] = python_int → pinned CPU buffer + copy_ (async, graph-capturable)
- expert_offsets[g] = python_int → element-wise GPU multiply with pre-allocated range tensor
- if group_offsets[0] != 0 → unconditional GPU-only update (no host read of GPU tensor)
What is FINE (no sync, don't waste time on these)
- .shape / .size() / .numel() / .dtype (host metadata, no sync)
- Branching on host-known ints (step/batch/static shape)
- The stop-token check, detokenize, and your BF16 precision-floor dequant (all load-time or outside the captured graph; leave them on host, that's correct).
- dec_tid_buf.copy_(dec_tid_pinned): pinned CPU→GPU async memcpy, graph-capturable
SECTION C — DSV4-specific kernels that must be GPU-native
| # | Hazard | Status | Fix Applied |
|---|---|---|---|
| 1 | Compressor returns None for 3/4 (CSA) or 127/128 (HCA) decode steps | ⏳ Phase 2 (eager-break) | Compressor runs in eager section. Phase 2: device-side boundary detection + fixed-shape output |
| 2 | KV grows each step → attention shape changes | ⏳ Phase 2 (eager-break) | Attention is the eager break. Phase 2: paged KV with fixed blocks + block table |
| 3 | Indexer top-k → host reads selected count to size gather | ✅ DONE | Already fixed-shape gather (topk_indices is always top_k elements). No host read of count. |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | ✅ DONE | torch.bincount → scatter_add_ into pre-allocated buffer. Expert offsets are GPU tensors. |
| 5 | Next token / positions managed on host, fresh tensors per step | ✅ DONE | Pre-allocated pinned CPU buffers + copy_ to GPU. No per-step allocation. |
Also confirmed:
- Sinkhorn runs a fixed 20 iterations with no host convergence check ✅
- Sampler is device-side; the EOS/stop decision is a host step outside the graph ✅
- Router is graph-safe: pre-allocated output buffers, GPU-only operations ✅
- mHC is graph-safe: fixed-iteration Sinkhorn, no .item() on hot path ✅
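To show why a fixed iteration count matters, here is a generic Sinkhorn normalization sketch. It is not the mHC implementation. The function name and the 4×4 input are assumptions, and only the default of 20 iterations comes from the notes above. A loop of the form `while err.item() > tol` would read a device value on the host every iteration, whereas a `for` loop over a host-known count does not.

```python
import torch

def sinkhorn_fixed(logits, n_iters=20):
    """Alternate row and column normalization for a fixed number of steps.

    No convergence test is evaluated, so nothing is read back to the host.
    """
    # Subtracting the max keeps exp() in range; the max stays a tensor.
    m = torch.exp(logits - logits.max())
    for _ in range(n_iters):                  # host-known trip count
        m = m / m.sum(dim=-1, keepdim=True)   # make each row sum to 1
        m = m / m.sum(dim=-2, keepdim=True)   # make each column sum to 1
    return m

torch.manual_seed(0)
p = sinkhorn_fixed(torch.randn(4, 4))
print(p.sum(dim=-2))  # column sums after the final normalization
```

Because the loop unrolls to the same sequence of kernels every time, the whole routine captures into a graph as straight-line work.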
Architectural Decision: Eager-Break-at-Attention (Phase 1) — UPDATED 2026-06-06
The per-layer compute is split into two graph-captured regions with eager attention in between:
- Graph A (captured): mHC pre_block(attn) + fused RMSNorm + quantize + q_a + q_a_norm + q_b + kv projections
  - Outputs written to pre-allocated buffers: x_normed, q_heads, kv_3d, ctx_a_B, ctx_a_C, X_mid
- Eager (NOT captured): Compressor → Indexer → KV gather → FMHA → inverse RoPE → o_a + o_b → F_attn
  - Dynamic shapes (FMHA seq_len, compressor returns None) → cannot be captured
  - forward_attention() accepts optional q_heads/kv_3d to skip projections when called from graph replay
- Graph B (captured): mHC post_block(attn) + FFN mHC + RMSNorm + quantize + Router + MoE + SE + mHC post_block(ffn)
  - Reads F_attn from pre-allocated buffer (written by eager attention)
  - Writes X_next to pre-allocated output buffer
Rationale: FMHA has dynamic sequence length; compressor/KV are data-dependent. Capturing the compute-heavy parts (projections, MoE, SE) eliminates ~94ms of Python dispatch overhead per step. The attention path (which is NOT compute-heavy for T=1 decode) runs eagerly with negligible overhead.
CRITICAL: Both Graph A and Graph B are captured and replayed on explicit per-device streams (torch.cuda.Stream(device=device)). The eager attention path runs on the default stream. Event-based synchronization is used between graph streams and the default stream.
Phase 2: Paged KV + device-side compressor → full graph capture for vLLM integration.
SECTION D — Integration order
- ✅ Build Section A's detector and run it on the current forward: DONE. tests/unit/test_cuda_graph_readiness.py on B200.
- ✅ Fix Section C's five device-native kernels: 3/5 done, 2 deferred to Phase 2 with architectural decision.
- ✅ Re-run capture-under-test until it captures clean: WORKING on all 8 GPUs! Root cause: multi-GPU requires explicit torch.cuda.Stream(device=device).
- ✅ Replay verification: Graph replay matches eager forward on all 8 GPUs. Logit range [-26.5, 15.0] matches.
- ✅ Benchmark: 0.28-0.30s/token with CUDA graphs (vs 0.55s/token eager, roughly a 1.8-2.0x speedup).
- ⬜ Gate every commit on the capture test: Not yet implemented.
- ⬜ Optimize stream sync: Current implementation uses torch.cuda.Event + wait_event() / synchronize(). Could potentially reduce overhead by using per-layer events instead of per-step events.
- ⬜ Phase 2: Paged KV + device-side compressor for full vLLM graph capture.
NEXT STEPS (pick up here in next session)
Priority 1: Decode degeneration (still unresolved)
The model generates a repetition loop (psych ↔ istically) regardless of whether CUDA graphs are used. This is the SAME issue as the eager path — not caused by graph capture. Root cause UNKNOWN. Components exonerated: mHC, FMHA, compression. This is the highest-priority correctness issue.
Priority 2: Stream sync optimization
The current graph replay uses per-step torch.cuda.Event sync between graph streams and the default stream. This works but may add overhead. Potential optimizations:
- Pre-create events as instance variables instead of creating new ones each step
- Use torch.cuda.Stream.wait_stream() instead of event-based sync where possible
- Profile the sync overhead vs compute time
Priority 3: Long-run stability
Test with --max-tokens 512+ to verify stability over many decode steps. Check for:
- Memory leaks (growing GPU memory usage)
- Numerical drift (logit range changes over time)
- Graph replay failures after many steps
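A long-run check along these lines could be a few lines of host-side bookkeeping. In this sketch the per-step readings are assumed to be collected elsewhere (for example, memory via `torch.cuda.memory_allocated()` and logit min/max from the final logits). The function name and both tolerances are hypothetical and would need tuning.

```python
def check_stability(mem_bytes, logit_ranges, mem_tol_bytes=0, range_tol=5.0):
    """Flag memory growth and logit-range drift across decode steps.

    mem_bytes:    per-step allocated-memory readings (ints)
    logit_ranges: per-step (min, max) logit pairs
    """
    problems = []
    if mem_bytes:
        growth = max(mem_bytes) - mem_bytes[0]
        if growth > mem_tol_bytes:
            problems.append(f"memory grew by {growth} bytes")
    if logit_ranges:
        lo0, hi0 = logit_ranges[0]
        for step, (lo, hi) in enumerate(logit_ranges):
            if abs(lo - lo0) > range_tol or abs(hi - hi0) > range_tol:
                problems.append(f"logit range drifted at step {step}: [{lo}, {hi}]")
                break  # report only the first drifting step
    return problems

stable = check_stability([100, 100, 100], [(-26.5, 15.0), (-25.0, 14.0)])
drifting = check_stability([100, 100, 228], [(-26.5, 15.0), (-60.0, 15.0)])
print(stable, drifting)
```

Collect the readings outside the captured region so the check itself does not add syncs to the decode step.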
Priority 4: Phase 2 — Full vLLM integration
- Paged KV cache (fixed blocks + block table)
- Device-side compressor boundary detection + fixed-shape output
- Full graph capture including FMHA
- Bucket-by-shape for variable sequence lengths
Guardrails
- Keep the stop-check, detokenize, and load-time BF16 dequant on the host — they're outside the captured region by design; don't contort them to be "graph-safe."
- Phase 1 uses eager-break-at-attention. Phase 2 adds paged KV. Don't retrofit paged KV into Phase 1 — it's a separate integration.
- Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.
- ALWAYS use explicit torch.cuda.Stream(device=device) for graph capture and replay on multi-GPU setups. This is non-negotiable on B200.
Violation Fix Log
| Commit | Description |
|---|---|
| a9ea303 | mhc.py .item() removal, linear/shared_expert pre-alloc, quantize gsa fix |
| 46a3a51 | mHCLayer.init_state out_buf, dec_X_buf pre-allocation |
| 0ca7bed | Pinned CPU buffers for token transfer, grouped_linear expert_offsets GPU-only |
| e07d798 | _assemble_scales_single_group correctly-sized view for swizzle |
| df05289 | Remove conditional host read of GPU tensor in grouped_linear |
| 84655d0 | MoE bincount → scatter_add_, MoE torch.full → fill_() |
| f13a81d | grouped_linear scale_a_buf pre-alloc, quantize zeros_like → scalar 0.0 |
| 518a1d3 | MoE scatter_add_ int64 indices, fix second bincount call |
| 80bb27f | gsa broadcast: reshape for M=1 decode (no stride-0), contiguous for M>1 prefill |
| 6dc2f22 | CRITICAL: _l1_out_buf 2x too narrow → GPU memory corruption (root cause of ALL cudaErrorInvalidValue errors). Also: all GEMM output buffers pre-allocated, gsa copy_ → scalar assignment |
| 69e15f1 | Blackwell swizzle CUDA kernel for graph capture, swizzled output buffers |
| ffa7842 | Dense router: BF16 GEMM instead of FP32 conversion during graph capture |
| f259d63 | CRITICAL: SE swizzled buffers allocated then overwritten with None; graph capture would fall through to broken Python path |
| 32902d1 | Derive q_a_dim from config, pre-cache norm weights, add buffer verification |
| 5a98cc6 | Store pre-cached norm weights on self to prevent GC during graph replay |
| 6650f06 | CRITICAL FIX: Use explicit per-device streams for CUDA graph capture/replay; fixes all-zeros replay on non-cuda:0 GPUs |