1. grouped_linear.py: remove a conditional host read of a GPU tensor.
   - `if group_offsets[0] != 0` reads a GPU value on the host, which forces a device sync.
   - Fix: update the offsets unconditionally on every call (a GPU-only multiply).
2. test_cuda_graph_readiness.py: use pinned CPU buffers for the token transfer.
   - `dec_tid_buf[0] = python_int` performs a synchronous CPU→GPU write.
   - Fix: write into a pinned CPU buffer, then `copy_` it to the GPU (asynchronous and graph-capturable).
3. Add dsv4/decode/cuda_graph_decoder.py (skeleton).
173 lines
6.6 KiB
Python
173 lines
6.6 KiB
Python
"""CUDA Graph Decode for DSV4 — zero Python dispatch overhead.

Target architecture: eager break at attention, with one captured subgraph per GPU.

For each decode step:
1. Copy the next token into a pre-allocated input buffer (pinned CPU → GPU).
2. For each GPU subgraph: replay the captured compute.
3. Between subgraphs: transfer X between GPUs (eager; X is a small tensor).
4. Run FMHA eagerly (the KV length is dynamic) — this is the attention break.
5. After all layers: run hc_head + norm + lm_head (captured on cuda:0).
6. Sample the next token (eager, outside the graph).

The captured subgraph per GPU contains:
- mHC pre_block (attn) → RMSNorm + quantize → attention projections (q_a, q_b, kv)
- [EAGER: compressor → indexer → gather → FMHA → inverse RoPE]
- o_proj → mHC post_block (attn) → mHC pre_block (ffn) → Router → MoE → SE → mHC post_block (ffn)

Simplification: to avoid splitting each layer at attention, the FULL layer
forward (including FMHA) is captured instead. The dynamic KV length is handled
by pre-allocating at max_context and masking.

Current implementation: graphs are captured per LAYER rather than per GPU
subgraph, to isolate issues — 61 individual graphs, each capturing one
layer's forward.
"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import time
|
|
import math
|
|
|
|
from dsv4.layers.mhc import mHCLayer, mHCContext
|
|
|
|
|
|
class CUDAGraphDecoder:
    """CUDA Graph decoder for DSV4 single-shot inference.

    Captures the decode step (one graph per layer plus one graph for
    hc_head + norm + lm_head) as CUDA graphs. This eliminates Python dispatch
    overhead (~94ms) and kernel launch latency.

    Layer ``li`` lives on ``devices[li % num_gpus]``. The final lm_head graph
    and its buffers live on ``cuda:0``.

    Constraints:
        - All tensors must have fixed addresses (pre-allocated)
        - No dynamic shapes (T=1 decode has fixed shapes)
        - No CPU-GPU syncs inside the graph
        - Cross-GPU transfers happen outside the graph region

    The compressor and KV cache must be graph-safe:
        - Compressor: always produces output (zeros when buffer incomplete)
        - KV cache: n_comp stored as GPU tensor, gather is fixed-shape with masking
        - FMHA: runs at max_seq_len with masking for actual length
    """

    def __init__(self, n_layers, num_gpus, devices, hidden_size, n_hc=4):
        """Record the model layout. No GPU memory is allocated here.

        Args:
            n_layers: number of decoder layers; one graph is captured per layer.
            num_gpus: number of GPU slots the layers are spread over
                round-robin (layer ``li`` -> ``devices[li % num_gpus]``).
            devices: sequence of CUDA devices, indexed by GPU slot.
            hidden_size: size H of the last axis of X.
            n_hc: size of the middle axis of X, which has shape (1, n_hc, H).
        """
        self.n_layers = n_layers
        self.num_gpus = num_gpus
        self.devices = devices
        self.hidden_size = hidden_size
        self.n_hc = n_hc

        # Per-layer CUDA graphs: li -> torch.cuda.CUDAGraph
        self.graphs = {}

        # Final graph (hc_head + norm + lm_head) on cuda:0
        self.lm_graph = None

        # Pre-allocated I/O buffers with fixed addresses for graph capture.
        # X is (1, n_hc, H) BF16 and lives on the device of layer li.
        self.x_in = {}   # li -> tensor
        self.x_out = {}  # li -> tensor

        # Final output buffers on cuda:0
        self.logits_buf = None
        self.x_cuda0_buf = None  # X after all layers, on cuda:0

        # torch.device -> side stream used for capture on that device.
        self._capture_streams = {}

        self.captured = False

    def _device(self, li):
        """Return the device that layer ``li`` is assigned to."""
        return self.devices[li % self.num_gpus]

    def _capture_stream(self, dev):
        """Return a cached capture side stream that lives on ``dev``.

        When no ``stream=`` is given, ``torch.cuda.graph`` falls back to one
        process-wide default capture stream. That stream is created on
        whichever device was current the first time it was used. Capturing a
        layer that lives on a different GPU on that stream would mix devices,
        so each device gets its own stream.
        """
        key = torch.device(dev)
        stream = self._capture_streams.get(key)
        if stream is None:
            stream = torch.cuda.Stream(device=key)
            self._capture_streams[key] = stream
        return stream

    def pre_allocate(self, vocab_size=129280):
        """Pre-allocate all I/O buffers with fixed addresses.

        Must be called before ``capture()``.

        Args:
            vocab_size: number of columns of ``logits_buf``.

        Raises:
            RuntimeError: if graphs have already been captured. The captured
                graphs keep the old buffer addresses, and the old tensors are
                released once they are replaced here.
        """
        if self.captured:
            raise RuntimeError(
                "pre_allocate() after capture() would invalidate the buffer "
                "addresses referenced by the captured CUDA graphs")

        x_shape = (1, self.n_hc, self.hidden_size)
        for li in range(self.n_layers):
            dev = self._device(li)
            self.x_in[li] = torch.zeros(x_shape, dtype=torch.bfloat16, device=dev)
            self.x_out[li] = torch.zeros(x_shape, dtype=torch.bfloat16, device=dev)

        self.logits_buf = torch.zeros(1, vocab_size, dtype=torch.bfloat16, device='cuda:0')
        self.x_cuda0_buf = torch.zeros(x_shape, dtype=torch.bfloat16, device='cuda:0')

    def capture(self, X_warmup, layer_forward_fn, lm_forward_fn,
                all_layer_args, lm_args):
        """Capture CUDA graphs after warmup.

        Each layer is captured on its own device, using a side stream that
        lives on that device. The lm_head graph is captured on cuda:0.

        Args:
            X_warmup: X tensor from warmup step (to seed input buffers)
            layer_forward_fn: function(X, li, **kwargs) -> X_next
            lm_forward_fn: function(X, **kwargs) -> logits
            all_layer_args: dict[li] -> kwargs for layer_forward_fn
            lm_args: kwargs for lm_forward_fn
        """
        print(" Capturing CUDA graphs for decode...", flush=True)

        for li in range(self.n_layers):
            dev = self._device(li)
            # Select the layer's actual device, not its slot index, so
            # capture agrees with where pre_allocate() put the buffers.
            torch.cuda.set_device(dev)

            # Seed the input buffer. copy_ accepts a source on any device.
            # NOTE: work recorded inside a graph capture is not executed, so
            # x_out[li - 1] still holds whatever it held before capture
            # (zeros right after pre_allocate), not layer li-1's output.
            # This only matters if layer_forward_fn branches on X's values,
            # which a graph-safe forward should not do.
            src = X_warmup if li == 0 else self.x_out[li - 1]
            self.x_in[li].copy_(src)

            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, stream=self._capture_stream(dev)):
                X_next = layer_forward_fn(self.x_in[li], li, **all_layer_args[li])
                self.x_out[li].copy_(X_next)

            self.graphs[li] = graph
            if (li + 1) % 10 == 0:
                print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)

        # Capture hc_head + norm + lm_head on cuda:0
        torch.cuda.set_device(0)
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])

        self.lm_graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.lm_graph, stream=self._capture_stream('cuda:0')):
            logits = lm_forward_fn(self.x_cuda0_buf, **lm_args)
            self.logits_buf.copy_(logits)

        self.captured = True
        print(f" Captured {len(self.graphs)} layer graphs + lm_head graph", flush=True)

    def replay(self, token_id_gpu, position_gpu, X_in=None):
        """Replay captured graphs for one decode step.

        Args:
            token_id_gpu: (1,) long tensor on cuda:0 — next token ID.
                Currently unused; the captured graphs do not read it yet.
            position_gpu: (1,) long tensor on cuda:0 — current position.
                Currently unused.
            X_in: optional input X for the first captured stage, shape
                (1, n_hc, H). If given, it is copied into the static layer-0
                input buffer (or into the cuda:0 lm input buffer when
                ``n_layers == 0``) before replay. If omitted, that buffer is
                re-read as-is.

        Returns:
            logits: (1, vocab_size) bfloat16 tensor. This is the static
            ``logits_buf`` and is overwritten by the next replay, so clone it
            to keep the values.
        """
        assert self.captured, "Must call capture() before replay()"

        # TODO: Copy token_id/position to the static input buffers that the graph uses.
        # This requires the graph to reference those buffers.

        if X_in is not None:
            first_input = self.x_in[0] if self.n_layers > 0 else self.x_cuda0_buf
            first_input.copy_(X_in)

        # Replay layer graphs
        for li in range(self.n_layers):
            torch.cuda.set_device(self._device(li))

            # Feed the previous layer's output forward. copy_ performs both
            # same-device and peer (cross-GPU) copies directly, so no
            # intermediate .to() temporary is needed.
            if li > 0:
                self.x_in[li].copy_(self.x_out[li - 1])

            self.graphs[li].replay()

        # Transfer final X to cuda:0
        if self.n_layers > 0:
            self.x_cuda0_buf.copy_(self.x_out[self.n_layers - 1])

        # Replay lm_head graph
        self.lm_graph.replay()

        return self.logits_buf
|