|
|
|
|
@@ -134,39 +134,69 @@ def unweighted_rmsnorm(x, eps=1e-6):
|
|
|
|
|
class CUDAGraphDecoder:
|
|
|
|
|
"""Captures and replays CUDA graphs for the decode loop.
|
|
|
|
|
|
|
|
|
|
Architecture: One graph per layer, capturing the entire forward_layer.
|
|
|
|
|
After one warmup step (which also fixes gsa values), each layer's
|
|
|
|
|
forward is captured as a single CUDA graph. Replay eliminates Python
|
|
|
|
|
dispatch overhead (~94ms for 61 layers) and kernel launch latency.
|
|
|
|
|
Architecture (Phase 1: eager-break-at-attention):
|
|
|
|
|
Each layer is split into two graph-captured sub-regions with eager attention
|
|
|
|
|
in between:
|
|
|
|
|
|
|
|
|
|
Graph A (pre-attention): mHC pre_block(attn) + fused RMSNorm + quantize
|
|
|
|
|
+ q_a + q_a_norm + q_b + kv projections
|
|
|
|
|
→ writes x_normed, q_heads, kv_3d, ctx_a to
|
|
|
|
|
pre-allocated buffers (read by eager attention and by graph B)
|
|
|
|
|
Eager (attention): RoPE(q, kv) → Compressor → Indexer → KV gather → FMHA
|
|
|
|
|
→ inverse RoPE → o_a + o_b → F_attn
|
|
|
|
|
→ writes F_attn to pre-allocated buffer
|
|
|
|
|
Graph B (post-attention): mHC post_block(attn) + mHC pre_block(ffn)
|
|
|
|
|
+ fused RMSNorm + quantize + Router + MoE + SE
|
|
|
|
|
+ mHC post_block(ffn)
|
|
|
|
|
→ writes X_next to pre-allocated output buffer
|
|
|
|
|
|
|
|
|
|
The attention path (compressor, FMHA, inverse RoPE) has dynamic shapes
|
|
|
|
|
and data-dependent control flow — it MUST run eagerly.
|
|
|
|
|
The compute path has fixed shapes for T=1 decode — it CAN be captured.
|
|
|
|
|
|
|
|
|
|
The hc_head + norm + lm_head are captured as a separate graph on cuda:0.
|
|
|
|
|
|
|
|
|
|
Cross-GPU transfers (copy_ into each layer's pre-allocated input buffer) happen OUTSIDE graphs between layers.
|
|
|
|
|
|
|
|
|
|
Constraints:
|
|
|
|
|
- All tensors must have fixed addresses (pre-allocated)
|
|
|
|
|
- No CPU-GPU syncs inside the graph
|
|
|
|
|
- All tensors in captured regions must have fixed addresses (pre-allocated)
|
|
|
|
|
- No CPU-GPU syncs inside captured regions
|
|
|
|
|
- The only per-step sync is argmax for sampling (outside graph)
|
|
|
|
|
- FMHA pads KV to 128 → fixed shape for graph capture
|
|
|
|
|
- Compressor returns None on non-boundary steps → graph captures the
|
|
|
|
|
path taken during warmup (typically the None path for HCA r=128)
|
|
|
|
|
- Attention runs eagerly — dynamic shapes are OK there
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, n_layers, num_gpus, hidden_size, devices):
|
|
|
|
|
def __init__(self, n_layers, num_gpus, hidden_size, devices, cfg):
|
|
|
|
|
self.n_layers = n_layers
|
|
|
|
|
self.num_gpus = num_gpus
|
|
|
|
|
self.hidden_size = hidden_size
|
|
|
|
|
self.devices = devices
|
|
|
|
|
self.captured = False
|
|
|
|
|
|
|
|
|
|
# One graph per layer + lm_head
|
|
|
|
|
self.graphs = {} # li -> torch.cuda.CUDAGraph
|
|
|
|
|
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
|
|
|
|
|
# Model dimensions for buffer pre-allocation
|
|
|
|
|
self.n_h = cfg.get("num_attention_heads", 128)
|
|
|
|
|
self.hd = cfg.get("head_dim", 512)
|
|
|
|
|
self.rd = cfg.get("qk_rope_head_dim", 64)
|
|
|
|
|
|
|
|
|
|
# Two graphs per layer (A: pre-attn, B: post-attn+FFN) + lm_head
|
|
|
|
|
self.graphs_a = {} # li -> torch.cuda.CUDAGraph
|
|
|
|
|
self.graphs_b = {} # li -> torch.cuda.CUDAGraph
|
|
|
|
|
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
|
|
|
|
|
|
|
|
|
|
# Pre-allocated I/O buffers — fixed addresses for graph capture
|
|
|
|
|
self.x_in_bufs = {} # li -> (1, 4, H) BF16 on layer's device
|
|
|
|
|
self.x_out_bufs = {} # li -> (1, 4, H) BF16 on layer's device
|
|
|
|
|
|
|
|
|
|
# Graph A output buffers (read by eager attention, written by graph A)
|
|
|
|
|
# These survive across the graph A → eager → graph B boundary.
|
|
|
|
|
self.x_normed_bufs = {} # li -> (1, H) BF16 — for compressor/indexer
|
|
|
|
|
self.q_heads_bufs = {} # li -> (1, n_h, hd) BF16 — for FMHA
|
|
|
|
|
self.kv_3d_bufs = {} # li -> (1, 1, hd) BF16 — for FMHA (pre-RoPE)
|
|
|
|
|
self.ctx_a_B_bufs = {} # li -> (1, 4, 4) FP32 — B_l for post_block
|
|
|
|
|
self.ctx_a_C_bufs = {} # li -> (1, 4) BF16 — C_l for post_block
|
|
|
|
|
self.X_mid_bufs = {} # li -> (1, 4, H) BF16 — X_l for post_block
|
|
|
|
|
|
|
|
|
|
# Graph B input buffer (written by eager attention, read by graph B)
|
|
|
|
|
self.F_attn_bufs = {} # li -> (1, H) BF16 — attention output for post_block
|
|
|
|
|
|
|
|
|
|
# lm_head graph buffers (on cuda:0)
|
|
|
|
|
self.x_lm_in = None # (1, 4, H) BF16 on cuda:0
|
|
|
|
|
self.logits_buf = None # (1, vocab_size) BF16 on cuda:0
|
|
|
|
|
@@ -175,11 +205,22 @@ class CUDAGraphDecoder:
|
|
|
|
|
"""Pre-allocate all I/O buffers with fixed addresses."""
|
|
|
|
|
H = self.hidden_size
|
|
|
|
|
V = cfg.get("vocab_size", 129280)
|
|
|
|
|
n_h = self.n_h
|
|
|
|
|
hd = self.hd
|
|
|
|
|
|
|
|
|
|
for li in range(self.n_layers):
|
|
|
|
|
dev = self.devices[li % self.num_gpus]
|
|
|
|
|
self.x_in_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
self.x_out_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
# Graph A intermediates
|
|
|
|
|
self.x_normed_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
self.q_heads_bufs[li] = torch.zeros(1, n_h, hd, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
self.kv_3d_bufs[li] = torch.zeros(1, 1, hd, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
self.ctx_a_B_bufs[li] = torch.zeros(1, 4, 4, dtype=torch.float32, device=dev)
|
|
|
|
|
self.ctx_a_C_bufs[li] = torch.zeros(1, 4, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
self.X_mid_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
# Graph B input
|
|
|
|
|
self.F_attn_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev)
|
|
|
|
|
|
|
|
|
|
# lm_head graph I/O (cuda:0 only)
|
|
|
|
|
self.x_lm_in = torch.zeros(1, 4, H, dtype=torch.bfloat16, device='cuda:0')
|
|
|
|
|
@@ -189,41 +230,123 @@ class CUDAGraphDecoder:
|
|
|
|
|
kv_caches, compressors, indexers, moe_runners, se_runners,
|
|
|
|
|
routers, prod_lins, layer_w, rope_caches, hc_head,
|
|
|
|
|
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None):
|
|
|
|
|
"""Capture CUDA graphs for all layers + lm_head.
|
|
|
|
|
"""Capture CUDA graphs for all layers (A/B split) + lm_head.
|
|
|
|
|
|
|
|
|
|
Phase 1: eager-break-at-attention. Graphs A/B capture the compute-heavy
|
|
|
|
|
path; the attention path runs eagerly between A and B replays.
|
|
|
|
|
|
|
|
|
|
Must be called after one warmup step so that:
|
|
|
|
|
1. All CuTeDSL kernels are compiled and cached
|
|
|
|
|
2. gsa values are fixed (from warmup_gsa)
|
|
|
|
|
3. CUDA kernels are warmed up (first launch is often slower)
|
|
|
|
|
"""
|
|
|
|
|
H = self.hidden_size
|
|
|
|
|
from dsv4.ops.quantize import (
|
|
|
|
|
mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
|
|
|
|
|
rmsnorm_quantize_nvfp4 as _rmsnorm_quantize,
|
|
|
|
|
)
|
|
|
|
|
from dsv4.layers.mhc import mHCContext
|
|
|
|
|
|
|
|
|
|
print(" Capturing CUDA graphs for decode (1 graph per layer)...", flush=True)
|
|
|
|
|
H = self.hidden_size
|
|
|
|
|
n_h = self.n_h
|
|
|
|
|
hd = self.hd
|
|
|
|
|
rd = self.rd
|
|
|
|
|
|
|
|
|
|
print(" Capturing CUDA graphs (A/B split: compute captured, attention eager)...", flush=True)
|
|
|
|
|
|
|
|
|
|
for li in range(self.n_layers):
|
|
|
|
|
gpu = li % self.num_gpus
|
|
|
|
|
dev = self.devices[gpu]
|
|
|
|
|
torch.cuda.set_device(gpu)
|
|
|
|
|
|
|
|
|
|
graph = torch.cuda.CUDAGraph()
|
|
|
|
|
with torch.cuda.graph(graph):
|
|
|
|
|
X_out = forward_layer(
|
|
|
|
|
self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
|
|
|
|
|
attn_mhcs.get(li), ffn_mhcs.get(li),
|
|
|
|
|
attn_norms.get(li), ffn_norms.get(li),
|
|
|
|
|
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
|
|
|
|
|
compressors.get(li), indexers.get(li),
|
|
|
|
|
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
|
|
|
|
prod_lin=prod_lins.get(li),
|
|
|
|
|
_use_fused_rmsnorm_quantize=True,
|
|
|
|
|
comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None,
|
|
|
|
|
comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None,
|
|
|
|
|
)
|
|
|
|
|
self.x_out_bufs[li].copy_(X_out)
|
|
|
|
|
attn_mhc = attn_mhcs.get(li)
|
|
|
|
|
ffn_mhc = ffn_mhcs.get(li)
|
|
|
|
|
attn_norm_w = attn_norms.get(li)
|
|
|
|
|
ffn_norm_w = ffn_norms.get(li)
|
|
|
|
|
pl = prod_lins.get(li, {})
|
|
|
|
|
pfx = f"model.layers.{li}.self_attn"
|
|
|
|
|
|
|
|
|
|
# ======== Graph A: pre-attention compute ========
|
|
|
|
|
# Input: X_l = self.x_in_bufs[li] (1, 4, H)
|
|
|
|
|
# Output: x_normed, q_heads, kv_3d, ctx_a, X_l → pre-allocated buffers
|
|
|
|
|
graph_a = torch.cuda.CUDAGraph()
|
|
|
|
|
with torch.cuda.graph(graph_a):
|
|
|
|
|
X_l = self.x_in_bufs[li]
|
|
|
|
|
|
|
|
|
|
# 1. mHC pre_block (attn) — fused P5
|
|
|
|
|
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_l)
|
|
|
|
|
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
|
|
|
|
|
X_l, A_l_a, attn_norm_w.to(dev, torch.float32))
|
|
|
|
|
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
|
|
|
|
|
|
|
|
|
|
# 2. Attention projections
|
|
|
|
|
q_a = pl['q_a'].run_from_quantized(x_quant_attn)
|
|
|
|
|
q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight")
|
|
|
|
|
if q_norm_w is not None:
|
|
|
|
|
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
|
|
|
|
|
q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
|
|
|
|
q = pl['q_b'].run_from_quantized(q_a_quant)
|
|
|
|
|
else:
|
|
|
|
|
q = pl['q_b'](q_a)
|
|
|
|
|
q = unweighted_rmsnorm(q).bfloat16()
|
|
|
|
|
# NOTE: RoPE is applied in the eager attention path (dynamic positions)
|
|
|
|
|
q_heads = q.reshape(1, n_h, hd)
|
|
|
|
|
|
|
|
|
|
kv = pl['kv'].run_from_quantized(x_quant_attn)
|
|
|
|
|
kv_norm_w_k = layer_w[li].get(f"{pfx}.kv_norm.weight")
|
|
|
|
|
if kv_norm_w_k is not None:
|
|
|
|
|
kv = rmsnorm(kv, kv_norm_w_k.to(dev, torch.float32))
|
|
|
|
|
kv_3d = kv.reshape(1, 1, hd)
|
|
|
|
|
# NOTE: RoPE is applied in the eager attention path
|
|
|
|
|
|
|
|
|
|
# Write to pre-allocated buffers for eager attention path
|
|
|
|
|
self.x_normed_bufs[li].copy_(x_normed)
|
|
|
|
|
self.q_heads_bufs[li].copy_(q_heads)
|
|
|
|
|
self.kv_3d_bufs[li].copy_(kv_3d)
|
|
|
|
|
self.ctx_a_B_bufs[li].copy_(B_l_a)
|
|
|
|
|
self.ctx_a_C_bufs[li].copy_(C_l_a)
|
|
|
|
|
self.X_mid_bufs[li].copy_(X_l)
|
|
|
|
|
|
|
|
|
|
self.graphs_a[li] = graph_a
|
|
|
|
|
|
|
|
|
|
# ======== Graph B: post-attention + FFN compute ========
|
|
|
|
|
# Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li]
|
|
|
|
|
# Output: X_next → self.x_out_bufs[li]
|
|
|
|
|
graph_b = torch.cuda.CUDAGraph()
|
|
|
|
|
with torch.cuda.graph(graph_b):
|
|
|
|
|
X_mid = self.X_mid_bufs[li]
|
|
|
|
|
F_attn = self.F_attn_bufs[li]
|
|
|
|
|
|
|
|
|
|
# 1. mHC post_block (attn)
|
|
|
|
|
B_l_a = self.ctx_a_B_bufs[li]
|
|
|
|
|
C_l_a = self.ctx_a_C_bufs[li]
|
|
|
|
|
BX_a = torch.bmm(B_l_a.transpose(-1, -2), X_mid.float())
|
|
|
|
|
CF_a = C_l_a.unsqueeze(-1) * F_attn.unsqueeze(1)
|
|
|
|
|
X_mid_out = (CF_a.float() + BX_a).to(X_mid.dtype)
|
|
|
|
|
|
|
|
|
|
# 2. FFN mHC pre_block — fused P5
|
|
|
|
|
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid_out)
|
|
|
|
|
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
|
|
|
|
|
X_mid_out, A_l_f, ffn_norm_w.to(dev, torch.float32))
|
|
|
|
|
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
|
|
|
|
|
|
|
|
|
|
# 3. Router + MoE + SE
|
|
|
|
|
token_id_dev = dec_tid32_per_gpu[gpu]
|
|
|
|
|
topk_w, topk_ids = routers.get(li)(x_ffn, token_ids=token_id_dev)
|
|
|
|
|
routed_out = moe_runners.get(li).run(x_ffn, topk_w, topk_ids)
|
|
|
|
|
shared_out = se_runners.get(li).run(x_ffn)
|
|
|
|
|
F_ffn = routed_out + shared_out
|
|
|
|
|
|
|
|
|
|
# 4. mHC post_block (ffn)
|
|
|
|
|
BX_f = torch.bmm(B_l_f.transpose(-1, -2), X_mid_out.float())
|
|
|
|
|
CF_f = C_l_f.unsqueeze(-1) * F_ffn.unsqueeze(1)
|
|
|
|
|
X_next = (CF_f.float() + BX_f).to(X_mid.dtype)
|
|
|
|
|
|
|
|
|
|
self.x_out_bufs[li].copy_(X_next)
|
|
|
|
|
|
|
|
|
|
self.graphs_b[li] = graph_b
|
|
|
|
|
|
|
|
|
|
self.graphs[li] = graph
|
|
|
|
|
if (li + 1) % 10 == 0:
|
|
|
|
|
print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)
|
|
|
|
|
print(f" Captured {li+1}/{self.n_layers} layer A/B graphs", flush=True)
|
|
|
|
|
|
|
|
|
|
# ---- Capture hc_head + norm + lm_head on cuda:0 ----
|
|
|
|
|
torch.cuda.set_device(0)
|
|
|
|
|
@@ -236,7 +359,8 @@ class CUDAGraphDecoder:
|
|
|
|
|
self.logits_buf.copy_(logits)
|
|
|
|
|
|
|
|
|
|
self.captured = True
|
|
|
|
|
print(f" Captured {len(self.graphs)} layer graphs + lm_head", flush=True)
|
|
|
|
|
print(f" Captured {len(self.graphs_a)} layer A/B graph pairs + lm_head", flush=True)
|
|
|
|
|
|
|
|
|
|
# =====================================================================
|
|
|
|
|
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
|
|
|
|
|
O, I2 = weight.shape; I = I2 * 2
|
|
|
|
|
@@ -878,7 +1002,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
|
|
|
|
kv_cache, positions, compressor, indexer, prod_lin,
|
|
|
|
|
x_quant=None,
|
|
|
|
|
_profile_detail=False, _profile_times=None,
|
|
|
|
|
comp_rope_cos=None, comp_rope_sin=None):
|
|
|
|
|
comp_rope_cos=None, comp_rope_sin=None,
|
|
|
|
|
q_heads=None, kv_3d=None):
|
|
|
|
|
dev = x_normed.device; T = x_normed.shape[0]
|
|
|
|
|
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
|
|
|
|
|
o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024)
|
|
|
|
|
@@ -895,40 +1020,43 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
|
|
|
|
|
|
|
|
|
|
_pt('q_a_start')
|
|
|
|
|
# 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm
|
|
|
|
|
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
|
|
|
|
|
_pt('q_a_end')
|
|
|
|
|
if VERBOSE >= 2 and li < 3:
|
|
|
|
|
# Compare q_a with PyTorch reference
|
|
|
|
|
q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
|
|
|
|
|
if q_a_ref is not None:
|
|
|
|
|
cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
|
|
|
|
|
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
|
|
|
|
|
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
|
|
|
|
# B3: Fused rmsnorm+quant for q_a_norm → q_b path
|
|
|
|
|
# Replaces: rmsnorm(q_a, w) → BF16 → q_b quantizes internally
|
|
|
|
|
# With: fused rmsnorm+NVFP4 quantize → QuantizedActivation → q_b.run_from_quantized
|
|
|
|
|
# Saves: ~6 kernel launches per layer (rmsnorm 4+ + quantize 2 vs fused 2)
|
|
|
|
|
if q_norm_w is not None:
|
|
|
|
|
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4
|
|
|
|
|
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
|
|
|
|
|
q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
|
|
|
|
_pt('q_b_start')
|
|
|
|
|
if q_norm_w is not None:
|
|
|
|
|
q = prod_lin['q_b'].run_from_quantized(q_a_quant)
|
|
|
|
|
else:
|
|
|
|
|
q = prod_lin['q_b'](q_a)
|
|
|
|
|
q = unweighted_rmsnorm(q).bfloat16()
|
|
|
|
|
_pt('q_b_end')
|
|
|
|
|
q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
|
|
|
|
|
# When q_heads is provided (from CUDA graph A), skip projections — only apply RoPE
|
|
|
|
|
if q_heads is None:
|
|
|
|
|
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
|
|
|
|
|
_pt('q_a_end')
|
|
|
|
|
if VERBOSE >= 2 and li < 3:
|
|
|
|
|
# Compare q_a with PyTorch reference
|
|
|
|
|
q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
|
|
|
|
|
if q_a_ref is not None:
|
|
|
|
|
cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
|
|
|
|
|
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
|
|
|
|
|
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
|
|
|
|
|
# B3: Fused rmsnorm+quant for q_a_norm → q_b path
|
|
|
|
|
if q_norm_w is not None:
|
|
|
|
|
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4
|
|
|
|
|
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
|
|
|
|
|
q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
|
|
|
|
|
_pt('q_b_start')
|
|
|
|
|
if q_norm_w is not None:
|
|
|
|
|
q = prod_lin['q_b'].run_from_quantized(q_a_quant)
|
|
|
|
|
else:
|
|
|
|
|
q = prod_lin['q_b'](q_a)
|
|
|
|
|
q = unweighted_rmsnorm(q).bfloat16()
|
|
|
|
|
_pt('q_b_end')
|
|
|
|
|
q_heads = q.reshape(T, n_h, hd)
|
|
|
|
|
q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
|
|
|
|
|
_pt('rope_q_end')
|
|
|
|
|
|
|
|
|
|
# 2. KV (NVFP4 GEMM, MQA, single KV head)
|
|
|
|
|
# When kv_3d is provided (from CUDA graph A), skip projections — only apply RoPE
|
|
|
|
|
_pt('kv_start')
|
|
|
|
|
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
|
|
|
|
|
_pt('kv_end')
|
|
|
|
|
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
|
|
|
|
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
|
|
|
|
kv_3d = kv.reshape(T, 1, hd); kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd)
|
|
|
|
|
if kv_3d is None:
|
|
|
|
|
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
|
|
|
|
|
_pt('kv_end')
|
|
|
|
|
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
|
|
|
|
|
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
|
|
|
|
|
kv_3d = kv.reshape(T, 1, hd)
|
|
|
|
|
kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd)
|
|
|
|
|
_pt('rope_kv_end')
|
|
|
|
|
kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions)
|
|
|
|
|
|
|
|
|
|
@@ -1670,7 +1798,7 @@ def main():
|
|
|
|
|
graph_decoder = None
|
|
|
|
|
if _args.cuda_graph:
|
|
|
|
|
print(" CUDA graph capture requested — will capture after warmup step")
|
|
|
|
|
graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)])
|
|
|
|
|
graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)], cfg)
|
|
|
|
|
graph_decoder.pre_allocate(cfg)
|
|
|
|
|
|
|
|
|
|
for step in range(MAX_NEW_TOKENS):
|
|
|
|
|
@@ -1692,18 +1820,41 @@ def main():
|
|
|
|
|
|
|
|
|
|
# ---- Forward: graph replay or eager ----
|
|
|
|
|
if graph_decoder is not None and graph_decoder.captured:
|
|
|
|
|
# CUDA graph replay path — one graph per layer
|
|
|
|
|
# CUDA graph replay path — A/B split with eager attention
|
|
|
|
|
for li in range(n_layers):
|
|
|
|
|
gpu = li % NUM_GPUS
|
|
|
|
|
torch.cuda.set_device(gpu)
|
|
|
|
|
dev = f'cuda:{gpu}'
|
|
|
|
|
|
|
|
|
|
# Copy X into graph input buffer (copy_ handles cross-GPU transfer)
|
|
|
|
|
# Copy X into graph A input buffer (copy_ handles cross-GPU transfer)
|
|
|
|
|
graph_decoder.x_in_bufs[li].copy_(X)
|
|
|
|
|
|
|
|
|
|
# Replay layer graph
|
|
|
|
|
graph_decoder.graphs[li].replay()
|
|
|
|
|
# Replay graph A: mHC pre_block + RMSNorm + q_a/q_b/kv projections
|
|
|
|
|
graph_decoder.graphs_a[li].replay()
|
|
|
|
|
|
|
|
|
|
# Read output from graph
|
|
|
|
|
# ---- Eager attention (NOT captured) ----
|
|
|
|
|
# Read graph A outputs from pre-allocated buffers
|
|
|
|
|
x_normed = graph_decoder.x_normed_bufs[li]
|
|
|
|
|
q_heads = graph_decoder.q_heads_bufs[li]
|
|
|
|
|
kv_3d = graph_decoder.kv_3d_bufs[li]
|
|
|
|
|
|
|
|
|
|
# Run full attention eagerly (compressor + indexer + FMHA + o_proj)
|
|
|
|
|
F_attn, _ = forward_attention(
|
|
|
|
|
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
|
|
|
|
|
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
|
|
|
|
|
compressors.get(li), indexers.get(li), prod_lins.get(li),
|
|
|
|
|
q_heads=q_heads, kv_3d=kv_3d, # pass pre-computed q/kv from graph A
|
|
|
|
|
comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None,
|
|
|
|
|
comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Write F_attn to graph B input buffer
|
|
|
|
|
graph_decoder.F_attn_bufs[li].copy_(F_attn)
|
|
|
|
|
|
|
|
|
|
# Replay graph B: mHC post_block + FFN + MoE + SE
|
|
|
|
|
graph_decoder.graphs_b[li].replay()
|
|
|
|
|
|
|
|
|
|
# Read output from graph B
|
|
|
|
|
X = graph_decoder.x_out_bufs[li]
|
|
|
|
|
|
|
|
|
|
# Transfer last layer output to cuda:0 for lm_head graph
|
|
|
|
|
|