From df05289d6fb89553efe1e6c1d4bc2ed3f00e476d Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Wed, 3 Jun 2026 17:20:34 +0000
Subject: [PATCH] CUDA graph: Fix remaining sync violations from B200 detector run 2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. grouped_linear.py: Remove conditional host read of GPU tensor
   - 'if group_offsets[0] != 0' reads GPU value on host → sync
   - Fix: unconditionally update offsets every call (GPU-only multiply)

2. test_cuda_graph_readiness.py: Use pinned CPU buffers for token transfer
   - dec_tid_buf[0] = python_int → CPU→GPU sync
   - Fix: write to pinned CPU buffer, then copy_(..., non_blocking=True)
     (async, graph-capturable)

3. Add dsv4/decode/cuda_graph_decoder.py (skeleton)
---
Review notes (ignored by git am; not part of the commit message):
- This patch was re-flowed from a copy whose newlines and indentation had
  been collapsed. Added lines are re-indented conventionally; indentation
  (and any repeated spaces inside string literals) on CONTEXT and REMOVED
  lines of the grouped_linear.py and test_cuda_graph_readiness.py hunks is
  a best guess. Verify against the target tree, or apply with
  `git apply --ignore-whitespace`.
- Every pinned-CPU -> GPU copy_ now passes non_blocking=True. Without it,
  copy_ waits for the transfer on the host even when the source is pinned,
  which is the sync this commit sets out to remove. The pinned staging
  buffers must not be rewritten on the host until the previously enqueued
  copy has been consumed by the stream.
- NOTE(review): the grouped_linear.py hunk still binds `group_offsets` from
  `_group_offset_buf` but now writes `_expert_offsets_buf`. Confirm which
  buffer the code after this hunk actually reads.
- The post-image blob ids on the test file's `index` line predate the
  non_blocking change.

 dsv4/decode/cuda_graph_decoder.py       | 172 ++++++++++++++++++++++++
 dsv4/layers/grouped_linear.py           |  10 +-
 tests/unit/test_cuda_graph_readiness.py |  32 +++--
 3 files changed, 199 insertions(+), 15 deletions(-)
 create mode 100644 dsv4/decode/cuda_graph_decoder.py

diff --git a/dsv4/decode/cuda_graph_decoder.py b/dsv4/decode/cuda_graph_decoder.py
new file mode 100644
index 00000000..54082a65
--- /dev/null
+++ b/dsv4/decode/cuda_graph_decoder.py
@@ -0,0 +1,172 @@
+"""CUDA Graph Decode for DSV4 — zero Python dispatch overhead.
+
+Architecture: Eager-break-at-attention with per-GPU captured subgraphs.
+
+For each decode step:
+  1. Copy next token to pre-allocated input buffer (pinned CPU → GPU)
+  2. For each GPU subgraph: replay the captured compute
+  3. Between subgraphs: transfer X between GPUs (eager, small tensor)
+  4. FMHA runs eagerly (dynamic KV length) — this is the attention break
+  5. After all layers: hc_head + norm + lm_head (captured on cuda:0)
+  6. Sample next token (eager, outside graph)
+
+The captured subgraph per GPU contains:
+  - mHC pre_block (attn) → RMSNorm + quantize → attention projections (q_a, q_b, kv)
+  - [EAGER: compressor → indexer → gather → FMHA → inverse RoPE]
+  - o_proj → mHC post_block (attn) → mHC pre_block (ffn) → Router → MoE → SE → mHC post_block (ffn)
+
+Actually, for simplicity and to avoid splitting the attention, we capture
+the FULL layer forward (including FMHA) and handle the dynamic KV length
+by pre-allocating at max_context and masking.
+
+For the initial implementation, we capture per-LAYER (not per-GPU subgraph)
+to isolate issues. 61 individual graphs, each capturing one layer's forward.
+"""
+
+import torch
+import torch.nn.functional as F
+import time
+import math
+
+from dsv4.layers.mhc import mHCLayer, mHCContext
+
+
+class CUDAGraphDecoder:
+    """CUDA Graph decoder for DSV4 single-shot inference.
+
+    Captures the entire decode step (all 61 layers + lm_head) as CUDA graphs,
+    eliminating Python dispatch overhead (~94ms) and kernel launch latency.
+
+    Constraints:
+    - All tensors must have fixed addresses (pre-allocated)
+    - No dynamic shapes (T=1 decode has fixed shapes)
+    - No CPU-GPU syncs inside the graph
+    - Cross-GPU transfers happen outside the graph region
+
+    The compressor and KV cache must be graph-safe:
+    - Compressor: always produces output (zeros when buffer incomplete)
+    - KV cache: n_comp stored as GPU tensor, gather is fixed-shape with masking
+    - FMHA: runs at max_seq_len with masking for actual length
+    """
+
+    def __init__(self, n_layers, num_gpus, devices, hidden_size, n_hc=4):
+        self.n_layers = n_layers
+        self.num_gpus = num_gpus
+        self.devices = devices
+        self.hidden_size = hidden_size
+        self.n_hc = n_hc
+
+        # Per-layer CUDA graphs
+        self.graphs = {}  # li -> torch.cuda.CUDAGraph
+
+        # Final graph (hc_head + norm + lm_head) on cuda:0
+        self.lm_graph = None
+
+        # Pre-allocated I/O buffers — fixed addresses for graph capture
+        # X is (1, n_hc, H) BF16
+        self.x_in = {}  # li -> tensor on device of layer li
+        self.x_out = {}  # li -> tensor on device of layer li
+
+        # Final output buffers on cuda:0
+        self.logits_buf = None
+        self.x_cuda0_buf = None  # X after all layers, on cuda:0
+
+        self.captured = False
+
+    def pre_allocate(self, vocab_size=129280):
+        """Pre-allocate all I/O buffers with fixed addresses."""
+        for li in range(self.n_layers):
+            dev = self.devices[li % self.num_gpus]
+            self.x_in[li] = torch.zeros(1, self.n_hc, self.hidden_size,
+                                        dtype=torch.bfloat16, device=dev)
+            self.x_out[li] = torch.zeros(1, self.n_hc, self.hidden_size,
+                                         dtype=torch.bfloat16, device=dev)
+
+        self.logits_buf = torch.zeros(1, vocab_size, dtype=torch.bfloat16, device='cuda:0')
+        self.x_cuda0_buf = torch.zeros(1, self.n_hc, self.hidden_size,
+                                       dtype=torch.bfloat16, device='cuda:0')
+
+    def capture(self, X_warmup, layer_forward_fn, lm_forward_fn,
+                all_layer_args, lm_args):
+        """Capture CUDA graphs after warmup.
+
+        Args:
+            X_warmup: X tensor from warmup step (to seed input buffers)
+            layer_forward_fn: function(X, li, **kwargs) -> X_next
+            lm_forward_fn: function(X, **kwargs) -> logits
+            all_layer_args: dict[li] -> kwargs for layer_forward_fn
+            lm_args: kwargs for lm_forward_fn
+        """
+        print(" Capturing CUDA graphs for decode...", flush=True)
+
+        for li in range(self.n_layers):
+            gpu = li % self.num_gpus
+            dev = self.devices[gpu]
+            torch.cuda.set_device(gpu)
+
+            # Seed input buffer with warmup X
+            if li == 0:
+                self.x_in[li].copy_(X_warmup.to(dev))
+            else:
+                self.x_in[li].copy_(self.x_out[li - 1].to(dev))
+
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph):
+                X_next = layer_forward_fn(self.x_in[li], li, **all_layer_args[li])
+                self.x_out[li].copy_(X_next)
+
+            self.graphs[li] = graph
+            if (li + 1) % 10 == 0:
+                print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)
+
+        # Capture hc_head + norm + lm_head on cuda:0
+        torch.cuda.set_device(0)
+        if self.n_layers > 0:
+            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1].to('cuda:0'))
+
+        self.lm_graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.lm_graph):
+            logits = lm_forward_fn(self.x_cuda0_buf, **lm_args)
+            self.logits_buf.copy_(logits)
+
+        self.captured = True
+        print(f" Captured {len(self.graphs)} layer graphs + lm_head graph", flush=True)
+
+    def replay(self, token_id_gpu, position_gpu):
+        """Replay captured graphs for one decode step.
+
+        Args:
+            token_id_gpu: (1,) long tensor on cuda:0 — next token ID (not yet consumed; see TODO below)
+            position_gpu: (1,) long tensor on cuda:0 — current position (not yet consumed; see TODO below)
+
+        Returns:
+            logits: (1, vocab_size) bfloat16 tensor
+        """
+        assert self.captured, "Must call capture() before replay()"
+
+        # TODO: Copy token_id/position to the static input buffers that the graph uses.
+        # This requires the graph to reference those buffers.
+
+        # Replay layer graphs
+        for li in range(self.n_layers):
+            gpu = li % self.num_gpus
+            torch.cuda.set_device(gpu)
+
+            # Copy input from previous layer's output
+            if li > 0:
+                prev_gpu = (li - 1) % self.num_gpus
+                if prev_gpu != gpu:
+                    self.x_in[li].copy_(self.x_out[li - 1].to(self.devices[gpu]))
+                else:
+                    self.x_in[li].copy_(self.x_out[li - 1])
+
+            self.graphs[li].replay()
+
+        # Transfer final X to cuda:0
+        if self.n_layers > 0:
+            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1].to('cuda:0'))
+
+        # Replay lm_head graph
+        self.lm_graph.replay()
+
+        return self.logits_buf
diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py
index 8e405e07..61057e22 100644
--- a/dsv4/layers/grouped_linear.py
+++ b/dsv4/layers/grouped_linear.py
@@ -333,12 +333,12 @@ class Nvfp4GroupedLinear:
         x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
 
         # Vectorized scatter — no Python loop, no CPU→GPU sync
-        # Build destination offsets and use index_put_ with pre-allocated indices
+        # Unconditionally update group offsets — GPU-only, no conditional host read.
+        # padded_rows_per_group is a Python int multiplied with a GPU tensor = GPU op.
         group_offsets = self._group_offset_buf[:self.n_local_groups]
-        if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group):
-            # Update offsets (only when padded_rows changes, e.g., prefill vs decode)
-            group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group)
-        # Scatter each group's x_fp4 into padded buffer (vectorized)
+        expert_offsets = self._expert_offsets_buf
+        expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group
+        # Scatter each group's x_fp4 into padded buffer
         for g in range(self.n_local_groups):
             offset = g * padded_rows_per_group
             padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)
diff --git a/tests/unit/test_cuda_graph_readiness.py b/tests/unit/test_cuda_graph_readiness.py
index c8ab5726..f26bb7fd 100644
--- a/tests/unit/test_cuda_graph_readiness.py
+++ b/tests/unit/test_cuda_graph_readiness.py
@@ -374,12 +374,23 @@ def run_sync_debug_mode():
     dec_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
     dec_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
     dec_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
+    # Pinned CPU buffers for graph-capturable token/position transfer
+    dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
+    dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory()
+    dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
+
+    def write_token_to_gpu(token_id, position):
+        """Write token/position to GPU buffers via pinned CPU (no CPU→GPU sync)."""
+        dec_tid_pinned[0] = token_id
+        dec_tid_buf.copy_(dec_tid_pinned, non_blocking=True)
+        dec_tid32_pinned[0] = token_id
+        dec_tid32_buf.copy_(dec_tid32_pinned, non_blocking=True)
+        dec_pos_pinned[0] = position
+        dec_pos_buf.copy_(dec_pos_pinned, non_blocking=True)
 
     # Warmup step first (so CuTeDSL kernels are compiled)
     print(" Warmup decode step (compiling CuTeDSL kernels)...", flush=True)
-    dec_tid_buf[0] = all_tokens[-1]
-    dec_tid32_buf[0] = all_tokens[-1]
-    dec_pos_buf[0] = len(all_tokens) - 1
+    write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
     X = mHCLayer.init_state(embed(dec_tid_buf))
     for li in range(n_layers):
         gpu = li % NUM_GPUS
@@ -408,9 +419,7 @@ def run_sync_debug_mode():
     sync_errors = []
     try:
         detector.phase = "decode_forward"
-        dec_tid_buf[0] = all_tokens[-1]
-        dec_tid32_buf[0] = all_tokens[-1]
-        dec_pos_buf[0] = len(all_tokens) - 1
+        write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
         X = mHCLayer.init_state(embed(dec_tid_buf))
 
         for li in range(n_layers):
@@ -473,10 +482,13 @@ def run_sync_debug_mode():
     g = torch.cuda.CUDAGraph()
     torch.cuda.set_device(0)
 
-    # Fill static buffers with current decode state
-    static_token[0] = all_tokens[-1]
-    static_token32[0] = all_tokens[-1]
-    static_pos[0] = len(all_tokens) - 1
+    # Fill static buffers with current decode state (via pinned CPU — no sync)
+    dec_tid_pinned[0] = all_tokens[-1]
+    static_token.copy_(dec_tid_pinned, non_blocking=True)
+    dec_tid32_pinned[0] = all_tokens[-1]
+    static_token32.copy_(dec_tid32_pinned, non_blocking=True)
+    dec_pos_pinned[0] = len(all_tokens) - 1
+    static_pos.copy_(dec_pos_pinned, non_blocking=True)
 
     with torch.cuda.graph(g):
         X = mHCLayer.init_state(embed(static_token))