CUDA graph: Fix remaining sync violations from B200 detector run 2

1. grouped_linear.py: Remove conditional host read of GPU tensor
   - 'if group_offsets[0] != 0' reads GPU value on host → sync
   - Fix: unconditionally update offsets every call (GPU-only multiply)

2. test_cuda_graph_readiness.py: Use pinned CPU buffers for token transfer
   - dec_tid_buf[0] = python_int → CPU→GPU sync
   - Fix: write to pinned CPU buffer, then copy_ (async, graph-capturable)

3. Add dsv4/decode/cuda_graph_decoder.py (skeleton)
This commit is contained in:
2026-06-03 17:20:34 +00:00
parent e07d79868f
commit df05289d6f
3 changed files with 199 additions and 15 deletions

View File

@@ -0,0 +1,172 @@
"""CUDA Graph Decode for DSV4 — zero Python dispatch overhead.
Architecture: Eager-break-at-attention with per-GPU captured subgraphs.
For each decode step:
1. Copy next token to pre-allocated input buffer (pinned CPU → GPU)
2. For each GPU subgraph: replay the captured compute
3. Between subgraphs: transfer X between GPUs (eager, small tensor)
4. FMHA runs eagerly (dynamic KV length) — this is the attention break
5. After all layers: hc_head + norm + lm_head (captured on cuda:0)
6. Sample next token (eager, outside graph)
The captured subgraph per GPU contains:
- mHC pre_block (attn) → RMSNorm + quantize → attention projections (q_a, q_b, kv)
- [EAGER: compressor → indexer → gather → FMHA → inverse RoPE]
- o_proj → mHC post_block (attn) → mHC pre_block (ffn) → Router → MoE → SE → mHC post_block (ffn)
NOTE: the eager attention break described above is NOT what is implemented here.
For simplicity, and to avoid splitting the attention path, we capture
the FULL layer forward (including FMHA) and handle the dynamic KV length
by pre-allocating at max_context and masking.
For the initial implementation, we capture per-LAYER (not per-GPU subgraph)
to isolate issues. 61 individual graphs, each capturing one layer's forward.
"""
import torch
import torch.nn.functional as F
import time
import math
from dsv4.layers.mhc import mHCLayer, mHCContext
class CUDAGraphDecoder:
    """CUDA Graph decoder for DSV4 single-shot inference.

    Captures the decode step as one CUDA graph per layer plus one graph for
    the final head (hc_head + norm + lm_head), with the goal of eliminating
    Python dispatch overhead (~94ms) and kernel launch latency.

    Constraints:
    - All tensors must have fixed addresses (pre-allocated)
    - No dynamic shapes (T=1 decode has fixed shapes)
    - No CPU-GPU syncs inside the graph
    - Cross-GPU transfers happen outside the graph region

    The compressor and KV cache must be graph-safe:
    - Compressor: always produces output (zeros when buffer incomplete)
    - KV cache: n_comp stored as GPU tensor, gather is fixed-shape with masking
    - FMHA: runs at max_seq_len with masking for actual length

    Expected call order: ``pre_allocate()`` once, ``capture(...)`` once, then
    ``replay(...)`` for every decode step.
    """

    def __init__(self, n_layers, num_gpus, devices, hidden_size, n_hc=4):
        """Record the model geometry and create empty graph/buffer tables.

        Args:
            n_layers: number of layers; one graph is captured per layer.
            num_gpus: number of GPUs; layer ``li`` is placed on index
                ``li % num_gpus``.
            devices: sequence indexed by GPU index, giving the device used
                to allocate that GPU's buffers.
            hidden_size: size of the last dimension of the X buffers.
            n_hc: size of the middle dimension of the X buffers.
        """
        self.n_layers = n_layers
        self.num_gpus = num_gpus
        self.devices = devices
        self.hidden_size = hidden_size
        self.n_hc = n_hc
        # Per-layer CUDA graphs
        self.graphs = {}  # li -> torch.cuda.CUDAGraph
        # Final graph (hc_head + norm + lm_head) on cuda:0
        self.lm_graph = None
        # Pre-allocated I/O buffers — fixed addresses for graph capture
        # X is (1, n_hc, H) BF16
        self.x_in = {}   # li -> tensor on device of layer li
        self.x_out = {}  # li -> tensor on device of layer li
        # Final output buffers on cuda:0
        self.logits_buf = None
        self.x_cuda0_buf = None  # X after all layers, on cuda:0
        self.captured = False

    def pre_allocate(self, vocab_size=129280):
        """Pre-allocate all I/O buffers with fixed addresses.

        Creates a zero-filled (1, n_hc, hidden_size) bfloat16 input and output
        buffer per layer on that layer's device, plus the (1, vocab_size)
        logits buffer and the final-X staging buffer on cuda:0.

        Args:
            vocab_size: width of the logits buffer.
        """
        x_shape = (1, self.n_hc, self.hidden_size)
        for li in range(self.n_layers):
            dev = self.devices[li % self.num_gpus]
            self.x_in[li] = torch.zeros(*x_shape, dtype=torch.bfloat16, device=dev)
            self.x_out[li] = torch.zeros(*x_shape, dtype=torch.bfloat16, device=dev)
        self.logits_buf = torch.zeros(1, vocab_size, dtype=torch.bfloat16, device='cuda:0')
        self.x_cuda0_buf = torch.zeros(*x_shape, dtype=torch.bfloat16, device='cuda:0')

    def capture(self, X_warmup, layer_forward_fn, lm_forward_fn,
                all_layer_args, lm_args):
        """Capture CUDA graphs after warmup.

        Args:
            X_warmup: X tensor from warmup step (to seed input buffers)
            layer_forward_fn: function(X, li, **kwargs) -> X_next
            lm_forward_fn: function(X, **kwargs) -> logits
            all_layer_args: dict[li] -> kwargs for layer_forward_fn
            lm_args: kwargs for lm_forward_fn

        Raises:
            RuntimeError: if pre_allocate() has not been called yet.
        """
        # Fail fast with a clear message instead of a KeyError / NoneType
        # error from deep inside the capture loop.
        if self.logits_buf is None or self.x_cuda0_buf is None:
            raise RuntimeError("Must call pre_allocate() before capture()")

        print(" Capturing CUDA graphs for decode...", flush=True)
        for li in range(self.n_layers):
            # NOTE(review): set_device takes the GPU *index*, while buffers
            # live on self.devices[index]; this assumes devices[i] is CUDA
            # device i — TODO confirm.
            torch.cuda.set_device(li % self.num_gpus)
            # Seed the input buffer. copy_ accepts a source on another device,
            # so no intermediate .to() allocation is needed.
            # NOTE(review): work issued inside torch.cuda.graph() is recorded
            # rather than executed, so x_out[li - 1] presumably still holds its
            # pre-capture contents here — confirm this is acceptable.
            seed = X_warmup if li == 0 else self.x_out[li - 1]
            self.x_in[li].copy_(seed)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                X_next = layer_forward_fn(self.x_in[li], li, **all_layer_args[li])
                # The write into the static output buffer is part of the graph
                # so that replay() can read the result from x_out[li].
                self.x_out[li].copy_(X_next)
            self.graphs[li] = graph
            if (li + 1) % 10 == 0:
                print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)

        # Capture hc_head + norm + lm_head on cuda:0
        torch.cuda.set_device(0)
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])
        self.lm_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.lm_graph):
            logits = lm_forward_fn(self.x_cuda0_buf, **lm_args)
            self.logits_buf.copy_(logits)
        self.captured = True
        print(f" Captured {len(self.graphs)} layer graphs + lm_head graph", flush=True)

    def replay(self, token_id_gpu, position_gpu):
        """Replay captured graphs for one decode step.

        Args:
            token_id_gpu: (1,) long tensor on cuda:0 — next token ID
            position_gpu: (1,) long tensor on cuda:0 — current position

        Returns:
            logits: (1, vocab_size) bfloat16 tensor. This is the static
            ``logits_buf``; its contents are overwritten by the next replay.

        Raises:
            AssertionError: if capture() has not completed.
        """
        # Explicit raise (rather than `assert`) so the guard survives
        # `python -O`; exception type and message are unchanged.
        if not self.captured:
            raise AssertionError("Must call capture() before replay()")
        # TODO: Copy token_id/position to the static input buffers that the graph uses.
        # This requires the graph to reference those buffers.
        # NOTE(review): until that TODO is done, token_id_gpu / position_gpu are
        # unused and nothing in this method writes x_in[0].

        # Replay layer graphs
        for li in range(self.n_layers):
            torch.cuda.set_device(li % self.num_gpus)
            # Feed the previous layer's output into this layer's static input.
            # copy_ handles same-device and cross-device sources alike, so no
            # temporary tensor is allocated on the decode path.
            if li > 0:
                self.x_in[li].copy_(self.x_out[li - 1])
            self.graphs[li].replay()

        # The head graph was captured with cuda:0 current; mirror that here.
        torch.cuda.set_device(0)
        # Transfer final X to cuda:0
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])
        # Replay lm_head graph
        self.lm_graph.replay()
        return self.logits_buf

View File

@@ -333,12 +333,12 @@ class Nvfp4GroupedLinear:
x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
# Vectorized scatter — no Python loop, no CPU→GPU sync
# Build destination offsets and use index_put_ with pre-allocated indices
# Unconditionally update group offsets — GPU-only, no conditional host read.
# padded_rows_per_group is a Python int multiplied with a GPU tensor = GPU op.
group_offsets = self._group_offset_buf[:self.n_local_groups]
if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group):
# Update offsets (only when padded_rows changes, e.g., prefill vs decode)
group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group)
# Scatter each group's x_fp4 into padded buffer (vectorized)
expert_offsets = self._expert_offsets_buf
expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group
# Scatter each group's x_fp4 into padded buffer
for g in range(self.n_local_groups):
offset = g * padded_rows_per_group
padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)

View File

@@ -374,12 +374,23 @@ def run_sync_debug_mode():
dec_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
dec_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
dec_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
# Pinned CPU buffers for graph-capturable token/position transfer
dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory()
dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
def write_token_to_gpu(token_id, position):
    """Stage token/position in pinned CPU buffers and copy them to the GPU buffers.

    Each value is written into its pinned host buffer and then copied into
    the matching device buffer with ``non_blocking=True``. Without that flag
    ``copy_`` waits for the transfer to finish, which would still block the
    host even though the source is pinned.

    Args:
        token_id: token ID written to both the int64 and int32 token buffers.
        position: position written to the position buffer.
    """
    # NOTE(review): the pinned buffers are reused across calls. With an
    # asynchronous copy the host must not overwrite them before the previous
    # transfer has executed — presumably guaranteed by the work enqueued
    # between calls; confirm against the caller.
    staged = (
        (dec_tid_pinned, dec_tid_buf, token_id),
        (dec_tid32_pinned, dec_tid32_buf, token_id),
        (dec_pos_pinned, dec_pos_buf, position),
    )
    for pinned, gpu_buf, value in staged:
        pinned[0] = value
        gpu_buf.copy_(pinned, non_blocking=True)
# Warmup step first (so CuTeDSL kernels are compiled)
print(" Warmup decode step (compiling CuTeDSL kernels)...", flush=True)
dec_tid_buf[0] = all_tokens[-1]
dec_tid32_buf[0] = all_tokens[-1]
dec_pos_buf[0] = len(all_tokens) - 1
write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
X = mHCLayer.init_state(embed(dec_tid_buf))
for li in range(n_layers):
gpu = li % NUM_GPUS
@@ -408,9 +419,7 @@ def run_sync_debug_mode():
sync_errors = []
try:
detector.phase = "decode_forward"
dec_tid_buf[0] = all_tokens[-1]
dec_tid32_buf[0] = all_tokens[-1]
dec_pos_buf[0] = len(all_tokens) - 1
write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
X = mHCLayer.init_state(embed(dec_tid_buf))
for li in range(n_layers):
@@ -473,10 +482,13 @@ def run_sync_debug_mode():
g = torch.cuda.CUDAGraph()
torch.cuda.set_device(0)
# Fill static buffers with current decode state
static_token[0] = all_tokens[-1]
static_token32[0] = all_tokens[-1]
static_pos[0] = len(all_tokens) - 1
# Fill static buffers with current decode state (via pinned CPU — no sync)
dec_tid_pinned[0] = all_tokens[-1]
static_token.copy_(dec_tid_pinned)
dec_tid32_pinned[0] = all_tokens[-1]
static_token32.copy_(dec_tid32_pinned)
dec_pos_pinned[0] = len(all_tokens) - 1
static_pos.copy_(dec_pos_pinned)
with torch.cuda.graph(g):
X = mHCLayer.init_state(embed(static_token))