CUDA graph: Fix remaining sync violations from B200 detector run 2
1. grouped_linear.py: Remove conditional host read of GPU tensor - 'if group_offsets[0] != 0' reads GPU value on host → sync - Fix: unconditionally update offsets every call (GPU-only multiply) 2. test_cuda_graph_readiness.py: Use pinned CPU buffers for token transfer - dec_tid_buf[0] = python_int → CPU→GPU sync - Fix: write to pinned CPU buffer, then copy_ (async, graph-capturable) 3. Add dsv4/decode/cuda_graph_decoder.py (skeleton)
This commit is contained in:
172
dsv4/decode/cuda_graph_decoder.py
Normal file
172
dsv4/decode/cuda_graph_decoder.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""CUDA Graph Decode for DSV4 — zero Python dispatch overhead.
|
||||
|
||||
Architecture: Eager-break-at-attention with per-GPU captured subgraphs.
|
||||
|
||||
For each decode step:
|
||||
1. Copy next token to pre-allocated input buffer (pinned CPU → GPU)
|
||||
2. For each GPU subgraph: replay the captured compute
|
||||
3. Between subgraphs: transfer X between GPUs (eager, small tensor)
|
||||
4. FMHA runs eagerly (dynamic KV length) — this is the attention break
|
||||
5. After all layers: hc_head + norm + lm_head (captured on cuda:0)
|
||||
6. Sample next token (eager, outside graph)
|
||||
|
||||
The captured subgraph per GPU contains:
|
||||
- mHC pre_block (attn) → RMSNorm + quantize → attention projections (q_a, q_b, kv)
|
||||
- [EAGER: compressor → indexer → gather → FMHA → inverse RoPE]
|
||||
- o_proj → mHC post_block (attn) → mHC pre_block (ffn) → Router → MoE → SE → mHC post_block (ffn)
|
||||
|
||||
Actually, for simplicity and to avoid splitting the attention, we capture
|
||||
the FULL layer forward (including FMHA) and handle the dynamic KV length
|
||||
by pre-allocating at max_context and masking.
|
||||
|
||||
For the initial implementation, we capture per-LAYER (not per-GPU subgraph)
|
||||
to isolate issues. 61 individual graphs, each capturing one layer's forward.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import time
|
||||
import math
|
||||
|
||||
from dsv4.layers.mhc import mHCLayer, mHCContext
|
||||
|
||||
|
||||
class CUDAGraphDecoder:
|
||||
"""CUDA Graph decoder for DSV4 single-shot inference.
|
||||
|
||||
Captures the entire decode step (all 61 layers + lm_head) as CUDA graphs,
|
||||
eliminating Python dispatch overhead (~94ms) and kernel launch latency.
|
||||
|
||||
Constraints:
|
||||
- All tensors must have fixed addresses (pre-allocated)
|
||||
- No dynamic shapes (T=1 decode has fixed shapes)
|
||||
- No CPU-GPU syncs inside the graph
|
||||
- Cross-GPU transfers happen outside the graph region
|
||||
|
||||
The compressor and KV cache must be graph-safe:
|
||||
- Compressor: always produces output (zeros when buffer incomplete)
|
||||
- KV cache: n_comp stored as GPU tensor, gather is fixed-shape with masking
|
||||
- FMHA: runs at max_seq_len with masking for actual length
|
||||
"""
|
||||
|
||||
def __init__(self, n_layers, num_gpus, devices, hidden_size, n_hc=4):
|
||||
self.n_layers = n_layers
|
||||
self.num_gpus = num_gpus
|
||||
self.devices = devices
|
||||
self.hidden_size = hidden_size
|
||||
self.n_hc = n_hc
|
||||
|
||||
# Per-layer CUDA graphs
|
||||
self.graphs = {} # li -> torch.cuda.CUDAGraph
|
||||
|
||||
# Final graph (hc_head + norm + lm_head) on cuda:0
|
||||
self.lm_graph = None
|
||||
|
||||
# Pre-allocated I/O buffers — fixed addresses for graph capture
|
||||
# X is (1, n_hc, H) BF16
|
||||
self.x_in = {} # li -> tensor on device of layer li
|
||||
self.x_out = {} # li -> tensor on device of layer li
|
||||
|
||||
# Final output buffers on cuda:0
|
||||
self.logits_buf = None
|
||||
self.x_cuda0_buf = None # X after all layers, on cuda:0
|
||||
|
||||
self.captured = False
|
||||
|
||||
def pre_allocate(self, vocab_size=129280):
|
||||
"""Pre-allocate all I/O buffers with fixed addresses."""
|
||||
for li in range(self.n_layers):
|
||||
dev = self.devices[li % self.num_gpus]
|
||||
self.x_in[li] = torch.zeros(1, self.n_hc, self.hidden_size,
|
||||
dtype=torch.bfloat16, device=dev)
|
||||
self.x_out[li] = torch.zeros(1, self.n_hc, self.hidden_size,
|
||||
dtype=torch.bfloat16, device=dev)
|
||||
|
||||
self.logits_buf = torch.zeros(1, vocab_size, dtype=torch.bfloat16, device='cuda:0')
|
||||
self.x_cuda0_buf = torch.zeros(1, self.n_hc, self.hidden_size,
|
||||
dtype=torch.bfloat16, device='cuda:0')
|
||||
|
||||
def capture(self, X_warmup, layer_forward_fn, lm_forward_fn,
|
||||
all_layer_args, lm_args):
|
||||
"""Capture CUDA graphs after warmup.
|
||||
|
||||
Args:
|
||||
X_warmup: X tensor from warmup step (to seed input buffers)
|
||||
layer_forward_fn: function(X, li, **kwargs) -> X_next
|
||||
lm_forward_fn: function(X, **kwargs) -> logits
|
||||
all_layer_args: dict[li] -> kwargs for layer_forward_fn
|
||||
lm_args: kwargs for lm_forward_fn
|
||||
"""
|
||||
print(" Capturing CUDA graphs for decode...", flush=True)
|
||||
|
||||
for li in range(self.n_layers):
|
||||
gpu = li % self.num_gpus
|
||||
dev = self.devices[gpu]
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Seed input buffer with warmup X
|
||||
if li == 0:
|
||||
self.x_in[li].copy_(X_warmup.to(dev))
|
||||
else:
|
||||
self.x_in[li].copy_(self.x_out[li - 1].to(dev))
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
X_next = layer_forward_fn(self.x_in[li], li, **all_layer_args[li])
|
||||
self.x_out[li].copy_(X_next)
|
||||
|
||||
self.graphs[li] = graph
|
||||
if (li + 1) % 10 == 0:
|
||||
print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)
|
||||
|
||||
# Capture hc_head + norm + lm_head on cuda:0
|
||||
torch.cuda.set_device(0)
|
||||
if self.n_layers > 0:
|
||||
self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1].to('cuda:0'))
|
||||
|
||||
self.lm_graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(self.lm_graph):
|
||||
logits = lm_forward_fn(self.x_cuda0_buf, **lm_args)
|
||||
self.logits_buf.copy_(logits)
|
||||
|
||||
self.captured = True
|
||||
print(f" Captured {len(self.graphs)} layer graphs + lm_head graph", flush=True)
|
||||
|
||||
def replay(self, token_id_gpu, position_gpu):
|
||||
"""Replay captured graphs for one decode step.
|
||||
|
||||
Args:
|
||||
token_id_gpu: (1,) long tensor on cuda:0 — next token ID
|
||||
position_gpu: (1,) long tensor on cuda:0 — current position
|
||||
|
||||
Returns:
|
||||
logits: (1, vocab_size) bfloat16 tensor
|
||||
"""
|
||||
assert self.captured, "Must call capture() before replay()"
|
||||
|
||||
# TODO: Copy token_id/position to the static input buffers that the graph uses.
|
||||
# This requires the graph to reference those buffers.
|
||||
|
||||
# Replay layer graphs
|
||||
for li in range(self.n_layers):
|
||||
gpu = li % self.num_gpus
|
||||
torch.cuda.set_device(gpu)
|
||||
|
||||
# Copy input from previous layer's output
|
||||
if li > 0:
|
||||
prev_gpu = (li - 1) % self.num_gpus
|
||||
if prev_gpu != gpu:
|
||||
self.x_in[li].copy_(self.x_out[li - 1].to(self.devices[gpu]))
|
||||
else:
|
||||
self.x_in[li].copy_(self.x_out[li - 1])
|
||||
|
||||
self.graphs[li].replay()
|
||||
|
||||
# Transfer final X to cuda:0
|
||||
if self.n_layers > 0:
|
||||
self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1].to('cuda:0'))
|
||||
|
||||
# Replay lm_head graph
|
||||
self.lm_graph.replay()
|
||||
|
||||
return self.logits_buf
|
||||
@@ -333,12 +333,12 @@ class Nvfp4GroupedLinear:
|
||||
x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
|
||||
|
||||
# Vectorized scatter — no Python loop, no CPU→GPU sync
|
||||
# Build destination offsets and use index_put_ with pre-allocated indices
|
||||
# Unconditionally update group offsets — GPU-only, no conditional host read.
|
||||
# padded_rows_per_group is a Python int multiplied with a GPU tensor = GPU op.
|
||||
group_offsets = self._group_offset_buf[:self.n_local_groups]
|
||||
if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group):
|
||||
# Update offsets (only when padded_rows changes, e.g., prefill vs decode)
|
||||
group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group)
|
||||
# Scatter each group's x_fp4 into padded buffer (vectorized)
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group
|
||||
# Scatter each group's x_fp4 into padded buffer
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)
|
||||
|
||||
@@ -374,12 +374,23 @@ def run_sync_debug_mode():
|
||||
dec_tid_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
|
||||
dec_pos_buf = torch.zeros(1, dtype=torch.long, device='cuda:0')
|
||||
dec_tid32_buf = torch.zeros(1, dtype=torch.int32, device='cuda:0')
|
||||
# Pinned CPU buffers for graph-capturable token/position transfer
|
||||
dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
|
||||
dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory()
|
||||
dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
|
||||
|
||||
def write_token_to_gpu(token_id, position):
|
||||
"""Write token/position to GPU buffers via pinned CPU (no CPU→GPU sync)."""
|
||||
dec_tid_pinned[0] = token_id
|
||||
dec_tid_buf.copy_(dec_tid_pinned)
|
||||
dec_tid32_pinned[0] = token_id
|
||||
dec_tid32_buf.copy_(dec_tid32_pinned)
|
||||
dec_pos_pinned[0] = position
|
||||
dec_pos_buf.copy_(dec_pos_pinned)
|
||||
|
||||
# Warmup step first (so CuTeDSL kernels are compiled)
|
||||
print(" Warmup decode step (compiling CuTeDSL kernels)...", flush=True)
|
||||
dec_tid_buf[0] = all_tokens[-1]
|
||||
dec_tid32_buf[0] = all_tokens[-1]
|
||||
dec_pos_buf[0] = len(all_tokens) - 1
|
||||
write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
|
||||
X = mHCLayer.init_state(embed(dec_tid_buf))
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
@@ -408,9 +419,7 @@ def run_sync_debug_mode():
|
||||
sync_errors = []
|
||||
try:
|
||||
detector.phase = "decode_forward"
|
||||
dec_tid_buf[0] = all_tokens[-1]
|
||||
dec_tid32_buf[0] = all_tokens[-1]
|
||||
dec_pos_buf[0] = len(all_tokens) - 1
|
||||
write_token_to_gpu(all_tokens[-1], len(all_tokens) - 1)
|
||||
|
||||
X = mHCLayer.init_state(embed(dec_tid_buf))
|
||||
for li in range(n_layers):
|
||||
@@ -473,10 +482,13 @@ def run_sync_debug_mode():
|
||||
g = torch.cuda.CUDAGraph()
|
||||
torch.cuda.set_device(0)
|
||||
|
||||
# Fill static buffers with current decode state
|
||||
static_token[0] = all_tokens[-1]
|
||||
static_token32[0] = all_tokens[-1]
|
||||
static_pos[0] = len(all_tokens) - 1
|
||||
# Fill static buffers with current decode state (via pinned CPU — no sync)
|
||||
dec_tid_pinned[0] = all_tokens[-1]
|
||||
static_token.copy_(dec_tid_pinned)
|
||||
dec_tid32_pinned[0] = all_tokens[-1]
|
||||
static_token32.copy_(dec_tid32_pinned)
|
||||
dec_pos_pinned[0] = len(all_tokens) - 1
|
||||
static_pos.copy_(dec_pos_pinned)
|
||||
|
||||
with torch.cuda.graph(g):
|
||||
X = mHCLayer.init_state(embed(static_token))
|
||||
|
||||
Reference in New Issue
Block a user