DSV4 → vLLM: CUDA-Graph Safety / GPU-Native Requirements (PART 2 companion)

Goal: the per-step decode forward must be fully GPU-native so vLLM can capture and replay it. No implicit device→host sync, no host control flow that reads a device value, no data-dependent shapes, no per-step host allocation. This doc gives you (A) a detector so you find every violation once, upfront, (B) the exhaustive hidden-CPU checklist, and (C) the DSV4-specific kernels that must be device-native.

The one rule that decides everything

Branching on a host-known integer (step number, position, batch size, dtype, static shape) is graph-compatible — you capture one graph per bucket and the scheduler picks by that integer. Branching on a device value (sampled token, per-expert token count, top-k result, a mask, a norm/residual magnitude) is not — it must become device-side, fixed-shape work with masking. Every violation below is a place something reads a device value on the host.

You do not need one monolithic graph. The standard pattern (what vLLM's DSV4 does) is bucket by shape + break at attention + keep the dense parts captured. Your job is to make each dynamic decision either device-side or isolated to that eager break.
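To make the host-known-integer rule concrete, here is a minimal sketch of bucket selection. The bucket sizes and the names `BUCKETS`, `pick_bucket`, and `select_graph` are assumptions for illustration, not taken from this repo or from vLLM. The scheduler picks a pre-captured graph from the batch size alone, which the host already knows, so no device value is read.

```python
import bisect

# Hypothetical batch-size buckets; one graph would be captured per entry.
BUCKETS = [1, 2, 4, 8, 16, 32]


def pick_bucket(batch_size: int) -> int:
    """Return the smallest captured bucket that fits batch_size."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    i = bisect.bisect_left(BUCKETS, batch_size)
    if i == len(BUCKETS):
        raise ValueError(f"batch_size {batch_size} exceeds the largest bucket")
    return BUCKETS[i]


# Stand-in for {bucket: captured graph}; strings keep the sketch GPU-free.
graphs = {b: f"graph_for_batch_{b}" for b in BUCKETS}


def select_graph(batch_size: int) -> str:
    # Host-side branch on a host-known int: this is the allowed kind of branch.
    return graphs[pick_bucket(batch_size)]
```

A batch of 5 would be padded up to the bucket of 8 and replayed with that graph; the padding rows are then masked on the device.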


⚠️ CRITICAL MULTI-GPU REQUIREMENT (learned 2026-06-06)

PyTorch CUDA graphs on non-default GPUs REQUIRE explicit torch.cuda.Stream(device=device) for capture AND replay. Using torch.cuda.set_device() alone causes:

  • GPU 0: Empty graph (warning: "The CUDA Graph is empty")
  • GPU 1+: Graph replays with stale capture-time data, ignoring updated input buffers

The fix:

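```python
import torch

device = torch.device("cuda:1")  # example value: any non-default GPU
input_buf = torch.ones(4, device=device)
output_buf = torch.empty_like(input_buf)

# CAPTURE: hand the per-device stream to the graph context explicitly.
s = torch.cuda.Stream(device=device)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, stream=s):
    output_buf.copy_(input_buf * 2.0)

# REPLAY: replay on the same stream the graph was captured on.
with torch.cuda.stream(s):
    g.replay()
```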

Stream synchronization between graph and eager paths:

  • Graph A/B run on per-device streams
  • Eager attention (between Graph A and Graph B) runs on the default stream
  • Use torch.cuda.Event + record() + wait_event() for sync
  • Do NOT use torch.cuda.synchronize(): it blocks the host until every stream on the current device has drained (too heavy)

This was the root cause of the "all-zeros replay" bug that took an entire session to diagnose. The minimal reproduction test is in tests/unit/test_cuda_graph_stream.py. Read this test if you ever see zero-output graph replay again.
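The event-based handoff described above can be sketched as follows. This is an illustrative sketch, not the code from this repo: the function name, buffer names, and the toy `x * 2.0` workload are assumptions. It shows an eager producer on the default stream, a graph replay on a per-device stream, and an eager consumer back on the default stream, ordered only by `torch.cuda.Event`.

```python
import torch


def replay_between_eager_sections(device: torch.device) -> torch.Tensor:
    """Eager write -> graph replay on a side stream -> eager read."""
    with torch.cuda.device(device):
        graph_stream = torch.cuda.Stream(device=device)
        default_stream = torch.cuda.current_stream(device)
        x = torch.ones(4, device=device)   # graph input buffer
        y = torch.zeros(4, device=device)  # graph output buffer

        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, stream=graph_stream):
            y.copy_(x * 2.0)

        inputs_ready = torch.cuda.Event()
        graph_done = torch.cuda.Event()

        x.fill_(3.0)                           # eager producer, default stream
        inputs_ready.record(default_stream)
        graph_stream.wait_event(inputs_ready)  # graph waits for the new input
        with torch.cuda.stream(graph_stream):
            g.replay()
        graph_done.record(graph_stream)
        default_stream.wait_event(graph_done)  # eager side waits for the graph
        return y + 1.0                         # eager consumer, default stream


result = None
if torch.cuda.is_available():
    result = replay_between_eager_sections(torch.device("cuda:0"))
```

Both waits are enqueued on the GPU, so the host never blocks. If the replay saw the updated input, every element of `result` is 7.0; a stale replay would give 3.0.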


SECTION A — The detector (build this FIRST, before porting anything) DONE

Status: Built and verified on B200 (2026-06-03). See tests/unit/test_cuda_graph_readiness.py.

Results from detector runs on B200:

  • Method 1 (sync debug mode): 0 violations in forward compute path
    • dec_tid_buf.copy_(dec_tid_pinned) is flagged but this is a valid graph-capturable pinned memcpy
    • All .item() syncs eliminated from hot path
  • Method 2 (graph capture L0): PASS
    • torch.cuda.CUDAGraph() capture of layer 0 decode step succeeds
    • All per-call allocations eliminated
    • All host reads of GPU values eliminated

The detector:

  1. Grep for Section B sync patterns in hot path files
  2. Run one decode step with torch.cuda.set_sync_debug_mode("error")
  3. Attempt torch.cuda.graph capture of L0 decode step
  4. Report results to /tmp/cuda_graph_readiness_results.json
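Step 1 of the detector can be approximated with a small static scanner. The sketch below is an assumption about how such a scan could look, not the contents of tests/unit/test_cuda_graph_readiness.py; the pattern list is a hypothetical starting set that you would extend with the Section B items.

```python
import re

# Hypothetical pattern list; extend it with the Section B items.
SYNC_PATTERNS = {
    "item": re.compile(r"\.item\(\)"),
    "tolist": re.compile(r"\.tolist\(\)"),
    "cpu": re.compile(r"\.cpu\(\)"),
    "bincount": re.compile(r"torch\.bincount\("),
    "nonzero": re.compile(r"\.nonzero\(\)|torch\.nonzero\("),
}


def scan_source(source: str) -> list:
    """Return (line_number, pattern_name, stripped_line) for each hit."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        code = line.split("#", 1)[0]  # ignore trailing comments
        for name, pattern in SYNC_PATTERNS.items():
            if pattern.search(code):
                hits.append((lineno, name, line.strip()))
    return hits


sample = """\
x_max = X_next.abs().max().item()
counts = torch.bincount(expert_ids)
n = x.shape[0]  # .item() mentioned only in a comment
"""
hits = scan_source(sample)
```

A static scan only narrows the search. Calls guarded by a verbosity flag still show up here, which is why the sync debug mode run and the capture attempt remain the authoritative checks.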

Run via test harness:

fire_b200_test tests/unit/test_cuda_graph_readiness.py kernel-test /tmp/kernel-test.log 1800

SECTION B — The hidden-CPU checklist (grep the hot path for these) ADDRESSED

Explicit device→host transfers — All .item() calls on hot path eliminated:

  • mhc.py post_block: removed X_next.abs().max().item() (122 syncs/step across 61 layers × 2 mHC)
  • All other .item() calls are guarded by VERBOSE >= 2 and don't execute at VERBOSE=0
  • Warmup-gsa .item() calls run once at step 0, outside graph region

Data-dependent shapes — Eliminated torch.bincount from MoE:

  • Replaced with scatter_add_ into pre-allocated _tokens_per_expert_buf (fixed shape, GPU-only)
  • Pre-allocated _ones_buf to avoid per-call torch.ones()
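The `torch.bincount` replacement can be sketched as below. The buffer names mirror the ones above without the leading underscore, while the sizes and the helper `count_tokens_per_expert` are assumptions for illustration. The output always has `NUM_EXPERTS` elements, so its shape never depends on which experts were selected.

```python
import torch

NUM_EXPERTS = 8      # hypothetical sizes for illustration
MAX_ASSIGNMENTS = 6  # e.g. one decode token routed to its top-6 experts
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once, outside the hot path.
tokens_per_expert_buf = torch.zeros(NUM_EXPERTS, dtype=torch.int64, device=device)
ones_buf = torch.ones(MAX_ASSIGNMENTS, dtype=torch.int64, device=device)


def count_tokens_per_expert(expert_ids: torch.Tensor) -> torch.Tensor:
    """Fixed-shape histogram of expert ids, written in place."""
    tokens_per_expert_buf.zero_()
    # scatter_add_ needs int64 indices; the output shape is fixed up front.
    tokens_per_expert_buf.scatter_add_(0, expert_ids.reshape(-1), ones_buf)
    return tokens_per_expert_buf


expert_ids = torch.tensor([3, 0, 3, 7, 1, 3], device=device)
counts = count_tokens_per_expert(expert_ids)
```

The result matches `torch.bincount(expert_ids, minlength=NUM_EXPERTS)`, but nothing is allocated per call and no value is read back to the host.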

Per-step host allocation — All eliminated:

  • torch.zeros() in _assemble_scales_single_group → pre-allocated _scale_a_buf (linear.py, grouped_linear.py, shared_expert.py)
  • torch.full() for MoE l1_gsa → self._l1_gsa_buf.fill_(l1_gs)
  • torch.empty() for grouped_linear output → pre-allocated _output_buf
  • mHCLayer.init_state .clone() → out_buf parameter for in-place write
  • torch.zeros_like in quantize.py → scalar 0.0 in torch.where
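The common shape of these fixes is "allocate once, then write in place". The sketch below uses a hypothetical `ScaleBuffers` class and a toy multiply rather than the repo's kernels. A captured graph records buffer addresses, so the property to check is that the output pointer does not change between steps.

```python
import torch


class ScaleBuffers:
    """Hypothetical holder for buffers that must keep a stable address."""

    def __init__(self, n: int, device: str = "cpu"):
        self.gsa_buf = torch.empty(n, device=device)
        self.out_buf = torch.empty(n, device=device)

    def step(self, x: torch.Tensor, scale: float) -> torch.Tensor:
        # fill_ reuses storage; torch.full would allocate a new tensor.
        self.gsa_buf.fill_(scale)
        # out= writes the product into the existing output buffer.
        torch.mul(x, self.gsa_buf, out=self.out_buf)
        return self.out_buf


bufs = ScaleBuffers(4)
ptr_before = bufs.out_buf.data_ptr()
out1 = bufs.step(torch.ones(4), 2.0).clone()  # clone only to keep the value
out2 = bufs.step(torch.ones(4), 0.5)
ptr_after = bufs.out_buf.data_ptr()
```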

Host control flow on device values — Eliminated:

  • dec_tid_buf[0] = python_int → pinned CPU buffer + copy_ (async, graph-capturable)
  • expert_offsets[g] = python_int → element-wise GPU multiply with pre-allocated range tensor
  • if group_offsets[0] != 0 → unconditional GPU-only update (no host read of GPU tensor)
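The pinned-buffer pattern for the decode token id can be sketched like this. The buffer names follow the ones used above; the helper `stage_token` and the CPU fallback are assumptions added so the sketch also runs without a GPU.

```python
import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Allocated once. Pinning needs CUDA, so the CPU fallback skips it.
dec_tid_pinned = torch.zeros(1, dtype=torch.int64, pin_memory=use_cuda)
dec_tid_buf = torch.zeros(1, dtype=torch.int64, device=device)


def stage_token(token_id: int) -> None:
    """Write a host int into the pinned buffer, then copy it to the device."""
    dec_tid_pinned[0] = token_id  # host-only write, no device access
    dec_tid_buf.copy_(dec_tid_pinned, non_blocking=True)  # async H2D copy


stage_token(42)
staged = dec_tid_buf.cpu().tolist()  # read-back only for this standalone check
```

One caveat with this design: the copy is asynchronous, so the pinned buffer should not be overwritten for the next step until the previous copy is known to have run.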

What is FINE (no sync, don't waste time on these)

  • .shape / .size() / .numel() / .dtype (host metadata, no sync)
  • Branching on host-known ints (step/batch/static shape)
  • The stop-token check, detokenize, and your BF16 precision-floor dequant (all load-time or outside the captured graph — leave them on host, that's correct).
  • dec_tid_buf.copy_(dec_tid_pinned) — pinned CPU→GPU async memcpy, graph-capturable

SECTION C — DSV4-specific kernels that must be GPU-native

| # | Hazard | Status | Fix Applied |
|---|--------|--------|-------------|
| 1 | Compressor returns None for 3/4 (CSA) or 127/128 (HCA) decode steps | Phase 2 (eager-break) | Compressor runs in eager section. Phase 2: device-side boundary detection + fixed-shape output |
| 2 | KV grows each step → attention shape changes | Phase 2 (eager-break) | Attention is the eager break. Phase 2: paged KV with fixed blocks + block table |
| 3 | Indexer top-k → host reads selected count to size gather | DONE | Already fixed-shape gather (topk_indices is always top_k elements). No host read of count. |
| 4 | MoE top-6 → per-expert token counts drive per-expert launches | DONE | torch.bincount → scatter_add_ into pre-allocated buffer. Expert offsets are GPU tensors. |
| 5 | Next token / positions managed on host, fresh tensors per step | DONE | Pre-allocated pinned CPU buffers + copy_ to GPU. No per-step allocation. |
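For hazard 1, the Phase 2 idea of device-side boundary detection with a fixed-shape output could look roughly like the sketch below. Everything here is an assumption for illustration: the boundary rule, the `mean` used as a stand-in for the real compression op, and the names `RATIO` and `compress_step`. The point is only that the function always returns a tensor of the same shape and selects with a mask instead of returning None.

```python
import torch

RATIO = 4  # CSA-style: one compressed entry every 4 decode steps


def compress_step(pos: torch.Tensor, window: torch.Tensor,
                  prev: torch.Tensor) -> torch.Tensor:
    """Always return a tensor shaped like prev; never None.

    pos:    0-dim int64 tensor holding the decode position (on device)
    window: [RATIO, D] most recent uncompressed entries
    prev:   [D] last compressed entry
    """
    candidate = window.mean(dim=0)        # placeholder compression op
    is_boundary = (pos + 1) % RATIO == 0  # computed and kept on the device
    # Select without a host branch: the mask picks candidate or prev.
    return torch.where(is_boundary, candidate, prev)


window = torch.arange(8, dtype=torch.float32).reshape(4, 2)
prev = torch.zeros(2)
off_boundary = compress_step(torch.tensor(2), window, prev)
on_boundary = compress_step(torch.tensor(3), window, prev)
```

The cost of this design is that the compression op runs on every step, including the 3 out of 4 steps whose result is discarded.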

Also confirmed:

  • Sinkhorn runs a fixed 20 iterations with no host convergence check
  • Sampler is device-side; the EOS/stop decision is a host step outside the graph
  • Router is graph-safe: pre-allocated output buffers, GPU-only operations
  • mHC is graph-safe: fixed-iteration Sinkhorn, no .item() on hot path
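A fixed-iteration Sinkhorn loop is graph-safe because the loop bound is a host-known integer and no convergence test reads a device value. The sketch below is a generic row/column normalization, not the mHC implementation; `sinkhorn_fixed` and the test matrix are assumptions.

```python
import torch


def sinkhorn_fixed(m: torch.Tensor, iters: int = 20) -> torch.Tensor:
    """Alternate row and column normalization for a fixed iteration count."""
    p = m.clone()
    for _ in range(iters):  # host-known int: the loop unrolls at capture time
        p = p / p.sum(dim=1, keepdim=True)  # rows sum to 1
        p = p / p.sum(dim=0, keepdim=True)  # columns sum to 1
    return p


torch.manual_seed(0)
m = torch.rand(4, 4) + 0.5  # strictly positive entries
p = sinkhorn_fixed(m)
# .item() is fine here: this check runs outside any captured region.
row_err = (p.sum(dim=1) - 1).abs().max().item()
col_err = (p.sum(dim=0) - 1).abs().max().item()
```

A convergence check such as `if err.item() < tol: break` would reintroduce exactly the host read of a device value that the rule forbids.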

Architectural Decision: Eager-Break-at-Attention (Phase 1) — UPDATED 2026-06-06

The per-layer compute is split into two graph-captured regions with eager attention in between:

  • Graph A (captured): mHC pre_block(attn) + fused RMSNorm + quantize + q_a + q_a_norm + q_b + kv projections
    • Outputs written to pre-allocated buffers: x_normed, q_heads, kv_3d, ctx_a_B, ctx_a_C, X_mid
  • Eager (NOT captured): Compressor → Indexer → KV gather → FMHA → inverse RoPE → o_a + o_b → F_attn
    • Dynamic shapes (FMHA seq_len, compressor returns None) → cannot be captured
    • forward_attention() accepts optional q_heads/kv_3d to skip projections when called from graph replay
  • Graph B (captured): mHC post_block(attn) + FFN mHC + RMSNorm + quantize + Router + MoE + SE + mHC post_block(ffn)
    • Reads F_attn from pre-allocated buffer (written by eager attention)
    • Writes X_next to pre-allocated output buffer

Rationale: FMHA has dynamic sequence length; compressor/KV are data-dependent. Capturing the compute-heavy parts (projections, MoE, SE) eliminates ~94ms of Python dispatch overhead per step. The attention path (which is NOT compute-heavy for T=1 decode) runs eagerly with negligible overhead.

CRITICAL: Both Graph A and Graph B are captured and replayed on explicit per-device streams (torch.cuda.Stream(device=device)). The eager attention path runs on the default stream. Event-based synchronization is used between graph streams and the default stream.

Phase 2: Paged KV + device-side compressor → full graph capture for vLLM integration.


SECTION D — Integration order

  1. Build Section A's detector and run it on the current forward — DONE. tests/unit/test_cuda_graph_readiness.py on B200.
  2. Fix Section C's five device-native kernels — 3/5 done, 2 deferred to Phase 2 with architectural decision.
  3. Re-run capture-under-test until it captures clean — WORKING on all 8 GPUs! Root cause: multi-GPU requires explicit torch.cuda.Stream(device=device).
  4. Replay verification — Graph replay matches eager forward on all 8 GPUs. Logit range [-26.5, 15.0] matches.
  5. Benchmark — 0.28-0.30s/token with CUDA graphs (vs 0.55s/token eager = ~2x speedup).
  6. Gate every commit on the capture test — Not yet implemented.
  7. Optimize stream sync — Current implementation uses torch.cuda.Event + wait_event()/synchronize(). Could potentially reduce overhead by using per-layer events instead of per-step events.
  8. Phase 2: Paged KV + device-side compressor for full vLLM graph capture.

NEXT STEPS (pick up here in next session)

Priority 1: Decode degeneration (still unresolved)

The model degenerates into a repetition loop (it repeats "psychistically") regardless of whether CUDA graphs are used. The eager path shows the SAME behavior, so graph capture is not the cause. Root cause UNKNOWN. Components exonerated: mHC, FMHA, compression. This is the highest-priority correctness issue.

Priority 2: Stream sync optimization

The current graph replay uses per-step torch.cuda.Event sync between graph streams and the default stream. This works but may add overhead. Potential optimizations:

  • Pre-create events as instance variables instead of creating new ones each step
  • Use torch.cuda.Stream.wait_stream() instead of event-based sync where possible
  • Profile the sync overhead vs compute time

Priority 3: Long-run stability

Test with --max-tokens 512+ to verify stability over many decode steps. Check for:

  • Memory leaks (growing GPU memory usage)
  • Numerical drift (logit range changes over time)
  • Graph replay failures after many steps
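The first two checks can be automated with a small tracker like the one sketched below. `StabilityMonitor` and its thresholds are hypothetical; in a real run the inputs would come from something like `torch.cuda.memory_allocated()` and the min/max of the step's logits, read outside the captured region.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class StabilityMonitor:
    """Hypothetical tracker for a long decode run."""
    mem_slack_bytes: int = 1 << 20  # tolerated growth over the baseline
    logit_slack: float = 5.0        # tolerated widening of the logit range
    baseline: Optional[tuple] = None
    warnings: list = field(default_factory=list)

    def record(self, step, mem_bytes, logit_min, logit_max):
        if self.baseline is None:
            # The first recorded step defines the baseline.
            self.baseline = (mem_bytes, logit_min, logit_max)
            return
        mem0, lo0, hi0 = self.baseline
        if mem_bytes - mem0 > self.mem_slack_bytes:
            self.warnings.append((step, "memory_growth"))
        if (logit_min < lo0 - self.logit_slack
                or logit_max > hi0 + self.logit_slack):
            self.warnings.append((step, "logit_drift"))


mon = StabilityMonitor()
mon.record(0, 1000, -26.5, 15.0)               # baseline
mon.record(1, 1000, -26.0, 14.8)               # within tolerance
mon.record(2, 1000 + (2 << 20), -26.5, 15.0)   # 2 MiB growth
mon.record(3, 1000, -40.0, 15.0)               # logit floor dropped
```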

Priority 4: Phase 2 — Full vLLM integration

  • Paged KV cache (fixed blocks + block table)
  • Device-side compressor boundary detection + fixed-shape output
  • Full graph capture including FMHA
  • Bucket-by-shape for variable sequence lengths
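The index arithmetic behind "fixed blocks + block table" is small enough to sketch in plain Python. `BLOCK_SIZE`, `locate`, and `blocks_needed` are assumed names and values, not vLLM's API. The sequence grows by appending entries to the block table, while every block keeps the same fixed shape.

```python
BLOCK_SIZE = 16  # hypothetical number of tokens per KV block


def locate(block_table: list, position: int) -> tuple:
    """Map a logical token position to (physical_block, offset_in_block)."""
    logical_block, offset = divmod(position, BLOCK_SIZE)
    return block_table[logical_block], offset


def blocks_needed(seq_len: int) -> int:
    """Number of fixed-size blocks that cover seq_len tokens."""
    return -(-seq_len // BLOCK_SIZE)  # ceiling division


# Logical block i of this sequence lives in physical block block_table[i].
block_table = [7, 2, 9]
first = locate(block_table, 0)
middle = locate(block_table, 17)
last = locate(block_table, 47)
```

In a captured graph the block table would be a fixed-shape device tensor padded to the bucket's maximum block count, so attention reads it without any shape change between steps.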

Guardrails

  • Keep the stop-check, detokenize, and load-time BF16 dequant on the host — they're outside the captured region by design; don't contort them to be "graph-safe."
  • Phase 1 uses eager-break-at-attention. Phase 2 adds paged KV. Don't retrofit paged KV into Phase 1 — it's a separate integration.
  • Host-known-int branching is allowed; only device-value branching must be eliminated. Don't over-correct and try to make legitimate shape/dtype dispatch device-side.
  • ALWAYS use explicit torch.cuda.Stream(device=device) for graph capture and replay on multi-GPU setups. This is non-negotiable on B200.

Violation Fix Log

| Commit | Description |
|--------|-------------|
| a9ea303 | mhc.py .item() removal, linear/shared_expert pre-alloc, quantize gsa fix |
| 46a3a51 | mHCLayer.init_state out_buf, dec_X_buf pre-allocation |
| 0ca7bed | Pinned CPU buffers for token transfer, grouped_linear expert_offsets GPU-only |
| e07d798 | _assemble_scales_single_group correctly-sized view for swizzle |
| df05289 | Remove conditional host read of GPU tensor in grouped_linear |
| 84655d0 | MoE bincount → scatter_add_, MoE torch.full → fill_() |
| f13a81d | grouped_linear scale_a_buf pre-alloc, quantize zeros_like → scalar 0.0 |
| 518a1d3 | MoE scatter_add_ int64 indices, fix second bincount call |
| 80bb27f | gsa broadcast: reshape for M=1 decode (no stride-0), contiguous for M>1 prefill |
| 6dc2f22 | CRITICAL: _l1_out_buf 2x too narrow → GPU memory corruption (root cause of ALL cudaErrorInvalidValue errors). Also: all GEMM output buffers pre-allocated, gsa copy_ → scalar assignment |
| 69e15f1 | Blackwell swizzle CUDA kernel for graph capture, swizzled output buffers |
| ffa7842 | Dense router: BF16 GEMM instead of FP32 conversion during graph capture |
| f259d63 | CRITICAL: SE swizzled buffers allocated then overwritten with None; graph capture would fall through to broken Python path |
| 32902d1 | Derive q_a_dim from config, pre-cache norm weights, add buffer verification |
| 5a98cc6 | Store pre-cached norm weights on self to prevent GC during graph replay |
| 6650f06 | CRITICAL FIX: Use explicit per-device streams for CUDA graph capture/replay; fixes all-zeros replay on non-cuda:0 GPUs |