Restore A/B split + gsa scalar fix (error is pre-existing, not regression)

This commit is contained in:
2026-06-04 01:03:36 +00:00
parent 59eccd04ab
commit 55def5eef9
5 changed files with 250 additions and 87 deletions

View File

@@ -338,10 +338,12 @@ class Nvfp4GroupedLinear:
# gsa_gpu is (G*T,) — all rows share same amax (from max over full tensor)
# For the GEMM's global_scale_a, fill all group slots with the same gsa value
# Use GPU-only copy: no .item(), no CPU sync
self._gsa_buf[:1].copy_(gsa_gpu[:1]) # GPU→GPU scalar copy, no sync
self._gsa_buf[0] = gsa_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
# Broadcast to all groups (all get same gsa)
# Use scalar broadcast assignment instead of copy_ from expanded view
# (expanded views can cause cudaErrorInvalidValue in copy_)
if self.n_local_groups > 1:
self._gsa_buf[1:].copy_(self._gsa_buf[:1].expand(self.n_local_groups - 1))
self._gsa_buf[1:] = self._gsa_buf[0] # scalar broadcast, graph-capturable
else:
self._gsa_buf.fill_(self._activation_global_scale)
x_fp4_flat, x_sf_flat = quantize_activation_nvfp4(

View File

@@ -206,7 +206,7 @@ class Nvfp4Linear:
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states)
self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync
self._gsa_buf[0] = gsa_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
else:
# P2 FIX: No per-call fill_(). The _gsa_buf already has the correct
# value — set either during initialization (via _ensure_buffer_size)
@@ -284,13 +284,10 @@ class Nvfp4Linear:
# For M=1 decode: per-row gsa is already scalar, no reduction needed.
# For M>1 prefill: reduce per-row gsa to a single scalar (max).
if quant.gsa.shape[0] == 1:
gsa = quant.gsa[:1].reshape(1) # Already scalar
self._gsa_buf[0] = quant.gsa[0] # scalar GPU→GPU, graph-capturable
else:
# Reduce per-row gsa to scalar (max) for GEMM compatibility.
# Per-row gsa is mathematically more precise, but the GEMM only
# supports a single global scale per expert.
gsa = quant.gsa.max().reshape(1)
self._gsa_buf.copy_(gsa)
self._gsa_buf[0] = quant.gsa.max() # GPU max, scalar assign, graph-capturable
# Run GEMM
out = run_nvfp4_grouped_gemm(

View File

@@ -630,7 +630,7 @@ class Nvfp4MoE:
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
else:
slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu(
slot_hidden, self._l1_activation_global_scale
@@ -666,7 +666,7 @@ class Nvfp4MoE:
from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused(
l1_out_real, self.intermediate_size)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
else:
slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda(
l1_out_real, self.intermediate_size, self._l2_activation_global_scale
@@ -694,7 +694,7 @@ class Nvfp4MoE:
if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
elif not self._fused_swiglu:
slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu(
activated, self._l2_activation_global_scale

View File

@@ -268,7 +268,7 @@ class Nvfp4SharedExpert:
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(x_bf16)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU
self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
else:
from dsv4.ops.quantize import quantize_activation_nvfp4
x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, self._l1_activation_global_scale)
@@ -316,7 +316,7 @@ class Nvfp4SharedExpert:
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states)
self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync
self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable
else:
x_fp4, x_sf = quantize_activation_nvfp4(
hidden_states, self._l1_activation_global_scale
@@ -363,8 +363,21 @@ class Nvfp4SharedExpert:
# Fused amax + quantize: zero CPU syncs.
if getattr(self, '_use_runtime_gsa', False):
from dsv4.ops.quantize import quantize_nvfp4_gpu_fused
x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate)
self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync
if not intermediate.is_contiguous():
intermediate = intermediate.contiguous()
# DEBUG: sync before quantize to isolate which kernel fails
torch.cuda.synchronize()
try:
x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate)
except RuntimeError as e:
print(f" SE L2 quantize FAILED: {e}", flush=True)
print(f" intermediate: shape={tuple(intermediate.shape)} dtype={intermediate.dtype} dev={intermediate.device} contiguous={intermediate.is_contiguous()}", flush=True)
raise
torch.cuda.synchronize() # DEBUG: catch async errors from quantize
# Copy first element of gsa (scalar for single-expert) to pre-allocated buffer.
# Using scalar assignment avoids copy_() from view which caused cudaErrorInvalidValue
# on non-contiguous gsa_gpu slices (gsa_gpu[:1].reshape(1) — view of expanded tensor).
self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU → GPU, no sync, graph-capturable
else:
x_fp4, x_sf = quantize_activation_nvfp4(
intermediate, self._l2_activation_global_scale

View File

@@ -134,39 +134,69 @@ def unweighted_rmsnorm(x, eps=1e-6):
class CUDAGraphDecoder:
"""Captures and replays CUDA graphs for the decode loop.
Architecture: One graph per layer, capturing the entire forward_layer.
After one warmup step (which also fixes gsa values), each layer's
forward is captured as a single CUDA graph. Replay eliminates Python
dispatch overhead (~94ms for 61 layers) and kernel launch latency.
Architecture (Phase 1: eager-break-at-attention):
Each layer is split into two graph-captured sub-regions with eager attention
in between:
Graph A (pre-attention): mHC pre_block(attn) + fused RMSNorm + quantize
+ q_a + q_a_norm + q_b + kv projections
→ writes x_normed, q_heads, kv_3d, ctx_a to
pre-allocated buffers for eager attention
Eager (attention): Compressor → Indexer → KV gather → FMHA
→ inverse RoPE → o_a + o_b → F_attn
→ writes F_attn to pre-allocated buffer
Graph B (post-attention): mHC post_block(attn) + mHC pre_block(ffn)
+ fused RMSNorm + quantize + Router + MoE + SE
+ mHC post_block(ffn)
→ writes X_next to pre-allocated output buffer
The attention path (compressor, FMHA, inverse RoPE) has dynamic shapes
and data-dependent control flow — it MUST run eagerly.
The compute path has fixed shapes for T=1 decode — it CAN be captured.
The hc_head + norm + lm_head are captured as a separate graph on cuda:0.
Cross-GPU transfers (X.to(cuda:N)) happen OUTSIDE graphs between layers.
Constraints:
- All tensors must have fixed addresses (pre-allocated)
- No CPU-GPU syncs inside the graph
- All tensors in captured regions must have fixed addresses (pre-allocated)
- No CPU-GPU syncs inside captured regions
- The only per-step sync is argmax for sampling (outside graph)
- FMHA pads KV to 128 → fixed shape for graph capture
- Compressor returns None on non-boundary steps → graph captures the
path taken during warmup (typically the None path for HCA r=128)
- Attention runs eagerly — dynamic shapes are OK there
"""
def __init__(self, n_layers, num_gpus, hidden_size, devices):
def __init__(self, n_layers, num_gpus, hidden_size, devices, cfg):
self.n_layers = n_layers
self.num_gpus = num_gpus
self.hidden_size = hidden_size
self.devices = devices
self.captured = False
# One graph per layer + lm_head
self.graphs = {} # li -> torch.cuda.CUDAGraph
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
# Model dimensions for buffer pre-allocation
self.n_h = cfg.get("num_attention_heads", 128)
self.hd = cfg.get("head_dim", 512)
self.rd = cfg.get("qk_rope_head_dim", 64)
# Two graphs per layer (A: pre-attn, B: post-attn+FFN) + lm_head
self.graphs_a = {} # li -> torch.cuda.CUDAGraph
self.graphs_b = {} # li -> torch.cuda.CUDAGraph
self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0
# Pre-allocated I/O buffers — fixed addresses for graph capture
self.x_in_bufs = {} # li -> (1, 4, H) BF16 on layer's device
self.x_out_bufs = {} # li -> (1, 4, H) BF16 on layer's device
# Graph A output buffers (read by eager attention, written by graph A)
# These survive across the graph A → eager → graph B boundary.
self.x_normed_bufs = {} # li -> (1, H) BF16 — for compressor/indexer
self.q_heads_bufs = {} # li -> (1, n_h, hd) BF16 — for FMHA
self.kv_3d_bufs = {} # li -> (1, 1, hd) BF16 — for FMHA (pre-RoPE)
self.ctx_a_B_bufs = {} # li -> (1, 4, 4) FP32 — B_l for post_block
self.ctx_a_C_bufs = {} # li -> (1, 4) BF16 — C_l for post_block
self.X_mid_bufs = {} # li -> (1, 4, H) BF16 — X_l for post_block
# Graph B input buffer (written by eager attention, read by graph B)
self.F_attn_bufs = {} # li -> (1, H) BF16 — attention output for post_block
# lm_head graph buffers (on cuda:0)
self.x_lm_in = None # (1, 4, H) BF16 on cuda:0
self.logits_buf = None # (1, vocab_size) BF16 on cuda:0
@@ -175,11 +205,22 @@ class CUDAGraphDecoder:
"""Pre-allocate all I/O buffers with fixed addresses."""
H = self.hidden_size
V = cfg.get("vocab_size", 129280)
n_h = self.n_h
hd = self.hd
for li in range(self.n_layers):
dev = self.devices[li % self.num_gpus]
self.x_in_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
self.x_out_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
# Graph A intermediates
self.x_normed_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev)
self.q_heads_bufs[li] = torch.zeros(1, n_h, hd, dtype=torch.bfloat16, device=dev)
self.kv_3d_bufs[li] = torch.zeros(1, 1, hd, dtype=torch.bfloat16, device=dev)
self.ctx_a_B_bufs[li] = torch.zeros(1, 4, 4, dtype=torch.float32, device=dev)
self.ctx_a_C_bufs[li] = torch.zeros(1, 4, dtype=torch.bfloat16, device=dev)
self.X_mid_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev)
# Graph B input
self.F_attn_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev)
# lm_head graph I/O (cuda:0 only)
self.x_lm_in = torch.zeros(1, 4, H, dtype=torch.bfloat16, device='cuda:0')
@@ -189,41 +230,123 @@ class CUDAGraphDecoder:
kv_caches, compressors, indexers, moe_runners, se_runners,
routers, prod_lins, layer_w, rope_caches, hc_head,
final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None):
"""Capture CUDA graphs for all layers + lm_head.
"""Capture CUDA graphs for all layers (A/B split) + lm_head.
Phase 1: eager-break-at-attention. Graphs A/B capture the compute-heavy
path; the attention path runs eagerly between A and B replays.
Must be called after one warmup step so that:
1. All CuTeDSL kernels are compiled and cached
2. gsa values are fixed (from warmup_gsa)
3. CUDA kernels are warmed up (first launch is often slower)
"""
H = self.hidden_size
from dsv4.ops.quantize import (
mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4,
rmsnorm_quantize_nvfp4 as _rmsnorm_quantize,
)
from dsv4.layers.mhc import mHCContext
print(" Capturing CUDA graphs for decode (1 graph per layer)...", flush=True)
H = self.hidden_size
n_h = self.n_h
hd = self.hd
rd = self.rd
print(" Capturing CUDA graphs (A/B split: compute captured, attention eager)...", flush=True)
for li in range(self.n_layers):
gpu = li % self.num_gpus
dev = self.devices[gpu]
torch.cuda.set_device(gpu)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
X_out = forward_layer(
self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu],
attn_mhcs.get(li), ffn_mhcs.get(li),
attn_norms.get(li), ffn_norms.get(li),
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li),
_use_fused_rmsnorm_quantize=True,
comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None,
comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None,
)
self.x_out_bufs[li].copy_(X_out)
attn_mhc = attn_mhcs.get(li)
ffn_mhc = ffn_mhcs.get(li)
attn_norm_w = attn_norms.get(li)
ffn_norm_w = ffn_norms.get(li)
pl = prod_lins.get(li, {})
pfx = f"model.layers.{li}.self_attn"
# ======== Graph A: pre-attention compute ========
# Input: X_l = self.x_in_bufs[li] (1, 4, H)
# Output: x_normed, q_heads, kv_3d, ctx_a, X_l → pre-allocated buffers
graph_a = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph_a):
X_l = self.x_in_bufs[li]
# 1. mHC pre_block (attn) — fused P5
A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_l)
x_quant_attn = mhc_rmsnorm_quantize_nvfp4(
X_l, A_l_a, attn_norm_w.to(dev, torch.float32))
x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa)
# 2. Attention projections
q_a = pl['q_a'].run_from_quantized(x_quant_attn)
q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight")
if q_norm_w is not None:
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
q = pl['q_b'].run_from_quantized(q_a_quant)
else:
q = pl['q_b'](q_a)
q = unweighted_rmsnorm(q).bfloat16()
# NOTE: RoPE is applied in the eager attention path (dynamic positions)
q_heads = q.reshape(1, n_h, hd)
kv = pl['kv'].run_from_quantized(x_quant_attn)
kv_norm_w_k = layer_w[li].get(f"{pfx}.kv_norm.weight")
if kv_norm_w_k is not None:
kv = rmsnorm(kv, kv_norm_w_k.to(dev, torch.float32))
kv_3d = kv.reshape(1, 1, hd)
# NOTE: RoPE is applied in the eager attention path
# Write to pre-allocated buffers for eager attention path
self.x_normed_bufs[li].copy_(x_normed)
self.q_heads_bufs[li].copy_(q_heads)
self.kv_3d_bufs[li].copy_(kv_3d)
self.ctx_a_B_bufs[li].copy_(B_l_a)
self.ctx_a_C_bufs[li].copy_(C_l_a)
self.X_mid_bufs[li].copy_(X_l)
self.graphs_a[li] = graph_a
# ======== Graph B: post-attention + FFN compute ========
# Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li]
# Output: X_next → self.x_out_bufs[li]
graph_b = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph_b):
X_mid = self.X_mid_bufs[li]
F_attn = self.F_attn_bufs[li]
# 1. mHC post_block (attn)
B_l_a = self.ctx_a_B_bufs[li]
C_l_a = self.ctx_a_C_bufs[li]
BX_a = torch.bmm(B_l_a.transpose(-1, -2), X_mid.float())
CF_a = C_l_a.unsqueeze(-1) * F_attn.unsqueeze(1)
X_mid_out = (CF_a.float() + BX_a).to(X_mid.dtype)
# 2. FFN mHC pre_block — fused P5
A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid_out)
x_quant_ffn = mhc_rmsnorm_quantize_nvfp4(
X_mid_out, A_l_f, ffn_norm_w.to(dev, torch.float32))
x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa)
# 3. Router + MoE + SE
token_id_dev = dec_tid32_per_gpu[gpu]
topk_w, topk_ids = routers.get(li)(x_ffn, token_ids=token_id_dev)
routed_out = moe_runners.get(li).run(x_ffn, topk_w, topk_ids)
shared_out = se_runners.get(li).run(x_ffn)
F_ffn = routed_out + shared_out
# 4. mHC post_block (ffn)
BX_f = torch.bmm(B_l_f.transpose(-1, -2), X_mid_out.float())
CF_f = C_l_f.unsqueeze(-1) * F_ffn.unsqueeze(1)
X_next = (CF_f.float() + BX_f).to(X_mid.dtype)
self.x_out_bufs[li].copy_(X_next)
self.graphs_b[li] = graph_b
self.graphs[li] = graph
if (li + 1) % 10 == 0:
print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True)
print(f" Captured {li+1}/{self.n_layers} layer A/B graphs", flush=True)
# ---- Capture hc_head + norm + lm_head on cuda:0 ----
torch.cuda.set_device(0)
@@ -236,7 +359,8 @@ class CUDAGraphDecoder:
self.logits_buf.copy_(logits)
self.captured = True
print(f" Captured {len(self.graphs)} layer graphs + lm_head", flush=True)
print(f" Captured {len(self.graphs_a)} layer A/B graph pairs + lm_head", flush=True)
# =====================================================================
def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None):
O, I2 = weight.shape; I = I2 * 2
@@ -878,7 +1002,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer, prod_lin,
x_quant=None,
_profile_detail=False, _profile_times=None,
comp_rope_cos=None, comp_rope_sin=None):
comp_rope_cos=None, comp_rope_sin=None,
q_heads=None, kv_3d=None):
dev = x_normed.device; T = x_normed.shape[0]
n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64)
o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024)
@@ -895,40 +1020,43 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
_pt('q_a_start')
# 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
_pt('q_a_end')
if VERBOSE >= 2 and li < 3:
# Compare q_a with PyTorch reference
q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
if q_a_ref is not None:
cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
# B3: Fused rmsnorm+quant for q_a_norm → q_b path
# Replaces: rmsnorm(q_a, w) → BF16 → q_b quantizes internally
# With: fused rmsnorm+NVFP4 quantize → QuantizedActivation → q_b.run_from_quantized
# Saves: ~6 kernel launches per layer (rmsnorm 4+ + quantize 2 vs fused 2)
if q_norm_w is not None:
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
_pt('q_b_start')
if q_norm_w is not None:
q = prod_lin['q_b'].run_from_quantized(q_a_quant)
else:
q = prod_lin['q_b'](q_a)
q = unweighted_rmsnorm(q).bfloat16()
_pt('q_b_end')
q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
# When q_heads is provided (from CUDA graph A), skip projections — only apply RoPE
if q_heads is None:
q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed)
_pt('q_a_end')
if VERBOSE >= 2 and li < 3:
# Compare q_a with PyTorch reference
q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj')
if q_a_ref is not None:
cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item()
print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True)
q_norm_w = w.get(f"{pfx}.q_a_norm.weight")
# B3: Fused rmsnorm+quant for q_a_norm → q_b path
if q_norm_w is not None:
from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4
q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32))
q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa)
_pt('q_b_start')
if q_norm_w is not None:
q = prod_lin['q_b'].run_from_quantized(q_a_quant)
else:
q = prod_lin['q_b'](q_a)
q = unweighted_rmsnorm(q).bfloat16()
_pt('q_b_end')
q_heads = q.reshape(T, n_h, hd)
q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd)
_pt('rope_q_end')
# 2. KV (NVFP4 GEMM, MQA, single KV head)
# When kv_3d is provided (from CUDA graph A), skip projections — only apply RoPE
_pt('kv_start')
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
_pt('kv_end')
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
kv_3d = kv.reshape(T, 1, hd); kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd)
if kv_3d is None:
kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed)
_pt('kv_end')
kv_norm_w = w.get(f"{pfx}.kv_norm.weight")
if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32))
kv_3d = kv.reshape(T, 1, hd)
kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd)
_pt('rope_kv_end')
kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions)
@@ -1670,7 +1798,7 @@ def main():
graph_decoder = None
if _args.cuda_graph:
print(" CUDA graph capture requested — will capture after warmup step")
graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)])
graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)], cfg)
graph_decoder.pre_allocate(cfg)
for step in range(MAX_NEW_TOKENS):
@@ -1692,18 +1820,41 @@ def main():
# ---- Forward: graph replay or eager ----
if graph_decoder is not None and graph_decoder.captured:
# CUDA graph replay path — one graph per layer
# CUDA graph replay path — A/B split with eager attention
for li in range(n_layers):
gpu = li % NUM_GPUS
torch.cuda.set_device(gpu)
dev = f'cuda:{gpu}'
# Copy X into graph input buffer (copy_ handles cross-GPU transfer)
# Copy X into graph A input buffer (copy_ handles cross-GPU transfer)
graph_decoder.x_in_bufs[li].copy_(X)
# Replay layer graph
graph_decoder.graphs[li].replay()
# Replay graph A: mHC pre_block + RMSNorm + q_a/q_b/kv projections
graph_decoder.graphs_a[li].replay()
# Read output from graph
# ---- Eager attention (NOT captured) ----
# Read graph A outputs from pre-allocated buffers
x_normed = graph_decoder.x_normed_bufs[li]
q_heads = graph_decoder.q_heads_bufs[li]
kv_3d = graph_decoder.kv_3d_bufs[li]
# Run full attention eagerly (compressor + indexer + FMHA + o_proj)
F_attn, _ = forward_attention(
x_normed, layer_w[li], li, cfg, *rope_caches[gpu],
kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu],
compressors.get(li), indexers.get(li), prod_lins.get(li),
q_heads=q_heads, kv_3d=kv_3d, # pass pre-computed q/kv from graph A
comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None,
comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None,
)
# Write F_attn to graph B input buffer
graph_decoder.F_attn_bufs[li].copy_(F_attn)
# Replay graph B: mHC post_block + FFN + MoE + SE
graph_decoder.graphs_b[li].replay()
# Read output from graph B
X = graph_decoder.x_out_bufs[li]
# Transfer last layer output to cuda:0 for lm_head graph