"""CUDA Graph Decode for DSV4 — zero Python dispatch overhead.

Architecture: Eager-break-at-attention with per-GPU captured subgraphs.

For each decode step:
  1. Copy next token to pre-allocated input buffer (pinned CPU → GPU)
  2. For each GPU subgraph: replay the captured compute
  3. Between subgraphs: transfer X between GPUs (eager, small tensor)
  4. FMHA runs eagerly (dynamic KV length) — this is the attention break
  5. After all layers: hc_head + norm + lm_head (captured on cuda:0)
  6. Sample next token (eager, outside graph)

The captured subgraph per GPU contains:
  - mHC pre_block (attn) → RMSNorm + quantize → attention projections
    (q_a, q_b, kv)
  - [EAGER: compressor → indexer → gather → FMHA → inverse RoPE]
  - o_proj → mHC post_block (attn) → mHC pre_block (ffn) → Router → MoE → SE
    → mHC post_block (ffn)

Actually, for simplicity and to avoid splitting the attention, we capture the
FULL layer forward (including FMHA) and handle the dynamic KV length by
pre-allocating at max_context and masking.

For the initial implementation, we capture per-LAYER (not per-GPU subgraph) to
isolate issues. 61 individual graphs, each capturing one layer's forward.
"""

import math
import time

import torch
import torch.nn.functional as F

from dsv4.layers.mhc import mHCLayer, mHCContext


class CUDAGraphDecoder:
    """CUDA Graph decoder for DSV4 single-shot inference.

    Captures the entire decode step (all 61 layers + lm_head) as CUDA graphs,
    eliminating Python dispatch overhead (~94ms) and kernel launch latency.

    Constraints:
    - All tensors must have fixed addresses (pre-allocated)
    - No dynamic shapes (T=1 decode has fixed shapes)
    - No CPU-GPU syncs inside the graph
    - Cross-GPU transfers happen outside the graph region

    The compressor and KV cache must be graph-safe:
    - Compressor: always produces output (zeros when buffer incomplete)
    - KV cache: n_comp stored as GPU tensor, gather is fixed-shape with masking
    - FMHA: runs at max_seq_len with masking for actual length

    Typical call order: ``pre_allocate()`` → ``capture(...)`` → ``replay(...)``
    once per decode step.
    """

    def __init__(self, n_layers, num_gpus, devices, hidden_size, n_hc=4):
        """Store the model geometry and create empty graph/buffer registries.

        Args:
            n_layers: number of layers; one CUDA graph is captured per layer.
            num_gpus: number of entries of ``devices`` that layers are spread
                over (layer ``li`` uses ``devices[li % num_gpus]``).
            devices: sequence of devices indexable by ``li % num_gpus``.
            hidden_size: last dimension of the X buffers.
            n_hc: middle dimension of the X buffers.
        """
        self.n_layers = n_layers
        self.num_gpus = num_gpus
        self.devices = devices
        self.hidden_size = hidden_size
        self.n_hc = n_hc

        # Per-layer CUDA graphs
        self.graphs = {}  # li -> torch.cuda.CUDAGraph

        # Final graph (hc_head + norm + lm_head) on cuda:0
        self.lm_graph = None

        # Pre-allocated I/O buffers — fixed addresses for graph capture
        # X is (1, n_hc, H) BF16
        self.x_in = {}   # li -> tensor on device of layer li
        self.x_out = {}  # li -> tensor on device of layer li

        # Final output buffers on cuda:0
        self.logits_buf = None
        self.x_cuda0_buf = None  # X after all layers, on cuda:0

        self.captured = False

    def _layer_device(self, li):
        """Return the entry of ``self.devices`` that hosts layer ``li``."""
        return self.devices[li % self.num_gpus]

    def _set_current_device(self, li):
        """Make the GPU hosting layer ``li`` the current CUDA device."""
        slot = li % self.num_gpus
        dev = self.devices[slot]
        ordinal = dev if isinstance(dev, int) else torch.device(dev).index
        # Use the real ordinal of the device entry so that a device list other
        # than [cuda:0, cuda:1, ...] still selects the GPU holding the buffers.
        # Entries without an explicit index fall back to the slot number.
        torch.cuda.set_device(slot if ordinal is None else ordinal)

    def pre_allocate(self, vocab_size=129280):
        """Pre-allocate all I/O buffers with fixed addresses.

        Must be called before ``capture()``: it creates the per-layer
        ``x_in``/``x_out`` buffers and the final ``x_cuda0_buf``/``logits_buf``
        that the captured graphs read and write.

        Args:
            vocab_size: width of the logits buffer.
        """
        x_shape = (1, self.n_hc, self.hidden_size)
        for li in range(self.n_layers):
            dev = self._layer_device(li)
            self.x_in[li] = torch.zeros(x_shape, dtype=torch.bfloat16, device=dev)
            self.x_out[li] = torch.zeros(x_shape, dtype=torch.bfloat16, device=dev)

        self.logits_buf = torch.zeros(1, vocab_size, dtype=torch.bfloat16, device='cuda:0')
        self.x_cuda0_buf = torch.zeros(x_shape, dtype=torch.bfloat16, device='cuda:0')

    def capture(self, X_warmup, layer_forward_fn, lm_forward_fn, all_layer_args, lm_args):
        """Capture CUDA graphs after warmup.

        One graph is captured per layer (reading ``x_in[li]`` and writing
        ``x_out[li]``), followed by one graph for the lm head (reading
        ``x_cuda0_buf`` and writing ``logits_buf``). ``pre_allocate()`` must
        have been called first. The current CUDA device is changed and not
        restored.

        Args:
            X_warmup: X tensor from warmup step (to seed input buffers)
            layer_forward_fn: function(X, li, **kwargs) -> X_next
            lm_forward_fn: function(X, **kwargs) -> logits
            all_layer_args: dict[li] -> kwargs for layer_forward_fn
            lm_args: kwargs for lm_forward_fn
        """
        print(" Capturing CUDA graphs for decode...", flush=True)

        for li in range(self.n_layers):
            self._set_current_device(li)

            # Seed input buffer with warmup X. copy_ moves data across devices
            # on its own, so no intermediate .to() tensor is needed.
            # NOTE(review): work issued under torch.cuda.graph is recorded, not
            # executed, so for li > 0 x_out[li - 1] still holds its pre-capture
            # contents here rather than a real layer output — confirm that the
            # seed values are not relied upon.
            seed = X_warmup if li == 0 else self.x_out[li - 1]
            self.x_in[li].copy_(seed)

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                X_next = layer_forward_fn(self.x_in[li], li, **all_layer_args[li])
                # Write into the static output buffer so replay has a fixed address.
                self.x_out[li].copy_(X_next)
            self.graphs[li] = graph

            if (li + 1) % 10 == 0:
                print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)

        # Capture hc_head + norm + lm_head on cuda:0
        torch.cuda.set_device(0)
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])

        self.lm_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.lm_graph):
            logits = lm_forward_fn(self.x_cuda0_buf, **lm_args)
            self.logits_buf.copy_(logits)

        self.captured = True
        print(f" Captured {len(self.graphs)} layer graphs + lm_head graph", flush=True)

    def replay(self, token_id_gpu, position_gpu):
        """Replay captured graphs for one decode step.

        Args:
            token_id_gpu: (1,) long tensor on cuda:0 — next token ID
                (currently unused, see TODO below)
            position_gpu: (1,) long tensor on cuda:0 — current position
                (currently unused, see TODO below)

        Returns:
            logits: (1, vocab_size) bfloat16 tensor. This is the static
            ``logits_buf`` itself, so the next ``replay()`` overwrites it.

        Raises:
            AssertionError: if ``capture()`` has not completed.
        """
        # Raised explicitly (not via `assert`) so the guard survives `python -O`.
        if not self.captured:
            raise AssertionError("Must call capture() before replay()")

        # TODO: Copy token_id/position to the static input buffers that the graph uses.
        # This requires the graph to reference those buffers.
        # NOTE(review): until then x_in[0] is never refreshed here, so layer 0
        # replays on whatever x_in[0] already contains.

        # Replay layer graphs
        for li in range(self.n_layers):
            self._set_current_device(li)

            # Feed the previous layer's output into this layer's static input.
            # copy_ handles both the same-GPU and the cross-GPU case without
            # allocating a temporary tensor per step.
            if li > 0:
                self.x_in[li].copy_(self.x_out[li - 1])

            self.graphs[li].replay()

        # Transfer final X to cuda:0
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])

        # Replay lm_head graph
        self.lm_graph.replay()

        return self.logits_buf