From 55def5eef9ebf016fc7dbe274fa5eb504a411148 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 4 Jun 2026 01:03:36 +0000 Subject: [PATCH] Restore A/B split + gsa scalar fix (error is pre-existing, not regression) --- dsv4/layers/grouped_linear.py | 6 +- dsv4/layers/linear.py | 9 +- dsv4/layers/moe.py | 6 +- dsv4/layers/shared_expert.py | 21 ++- single_shot_inference.py | 295 +++++++++++++++++++++++++--------- 5 files changed, 250 insertions(+), 87 deletions(-) diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py index 30c8a555..182308a5 100644 --- a/dsv4/layers/grouped_linear.py +++ b/dsv4/layers/grouped_linear.py @@ -338,10 +338,12 @@ class Nvfp4GroupedLinear: # gsa_gpu is (G*T,) — all rows share same amax (from max over full tensor) # For the GEMM's global_scale_a, fill all group slots with the same gsa value # Use GPU-only copy: no .item(), no CPU sync - self._gsa_buf[:1].copy_(gsa_gpu[:1]) # GPU→GPU scalar copy, no sync + self._gsa_buf[0] = gsa_gpu[0] # scalar GPU→GPU, no sync, graph-capturable # Broadcast to all groups (all get same gsa) + # Use scalar broadcast assignment instead of copy_ from expanded view + # (expanded views can cause cudaErrorInvalidValue in copy_) if self.n_local_groups > 1: - self._gsa_buf[1:].copy_(self._gsa_buf[:1].expand(self.n_local_groups - 1)) + self._gsa_buf[1:] = self._gsa_buf[0] # scalar broadcast, graph-capturable else: self._gsa_buf.fill_(self._activation_global_scale) x_fp4_flat, x_sf_flat = quantize_activation_nvfp4( diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index 32b53473..82f12aaa 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -206,7 +206,7 @@ class Nvfp4Linear: if getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused x_fp4, x_sf, gsa_gpu = quantize_nvfp4_gpu_fused(hidden_states) - self._gsa_buf.copy_(gsa_gpu[:1].reshape(1)) # GPU → GPU, no sync + self._gsa_buf[0] = gsa_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: # P2 FIX: No per-call fill_(). The _gsa_buf already has the correct # value — set either during initialization (via _ensure_buffer_size) @@ -284,13 +284,10 @@ class Nvfp4Linear: # For M=1 decode: per-row gsa is already scalar, no reduction needed. # For M>1 prefill: reduce per-row gsa to a single scalar (max). if quant.gsa.shape[0] == 1: - gsa = quant.gsa[:1].reshape(1) # Already scalar + self._gsa_buf[0] = quant.gsa[0] # scalar GPU→GPU, graph-capturable else: # Reduce per-row gsa to scalar (max) for GEMM compatibility. - # Per-row gsa is mathematically more precise, but the GEMM only - # supports a single global scale per expert. - gsa = quant.gsa.max().reshape(1) - self._gsa_buf.copy_(gsa) + self._gsa_buf[0] = quant.gsa.max() # GPU max, scalar assign, graph-capturable # Run GEMM out = run_nvfp4_grouped_gemm( diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 1744e767..0dc0e89e 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -630,7 +630,7 @@ class Nvfp4MoE: if getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused slot_x_fp4, slot_x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(slot_hidden) - self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: slot_x_fp4, slot_x_sf = quantize_nvfp4_gpu( slot_hidden, self._l1_activation_global_scale @@ -666,7 +666,7 @@ class Nvfp4MoE: from dsv4.ops.quantize import deinterleave_amax_quantize_nvfp4_fused slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = deinterleave_amax_quantize_nvfp4_fused( l1_out_real, self.intermediate_size) - self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: slot_l2_x_fp4, slot_l2_x_sf = deinterleave_quantize_nvfp4_cuda( l1_out_real, self.intermediate_size, self._l2_activation_global_scale @@ -694,7 +694,7 @@ class Nvfp4MoE: if not self._fused_swiglu and getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused slot_l2_x_fp4, slot_l2_x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(activated) - self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU→GPU, no sync, graph-capturable elif not self._fused_swiglu: slot_l2_x_fp4, slot_l2_x_sf = quantize_nvfp4_gpu( activated, self._l2_activation_global_scale diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index 4d32a548..9a9e2ae2 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -268,7 +268,7 @@ class Nvfp4SharedExpert: if getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(x_bf16) - self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU + self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: from dsv4.ops.quantize import quantize_activation_nvfp4 x_fp4, x_sf = quantize_activation_nvfp4(x_bf16, self._l1_activation_global_scale) @@ -316,7 +316,7 @@ class Nvfp4SharedExpert: if getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused x_fp4, x_sf, gsa_l1_gpu = quantize_nvfp4_gpu_fused(hidden_states) - self._l1_gsa_buf.copy_(gsa_l1_gpu[:1].reshape(1)) # GPU → GPU, no sync + self._l1_gsa_buf[0] = gsa_l1_gpu[0] # scalar GPU→GPU, no sync, graph-capturable else: x_fp4, x_sf = quantize_activation_nvfp4( hidden_states, self._l1_activation_global_scale @@ -363,8 +363,21 @@ class Nvfp4SharedExpert: # Fused amax + quantize: zero CPU syncs. if getattr(self, '_use_runtime_gsa', False): from dsv4.ops.quantize import quantize_nvfp4_gpu_fused - x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate) - self._l2_gsa_buf.copy_(gsa_l2_gpu[:1].reshape(1)) # GPU → GPU, no sync + if not intermediate.is_contiguous(): + intermediate = intermediate.contiguous() + # DEBUG: sync before quantize to isolate which kernel fails + torch.cuda.synchronize() + try: + x_fp4, x_sf, gsa_l2_gpu = quantize_nvfp4_gpu_fused(intermediate) + except RuntimeError as e: + print(f" SE L2 quantize FAILED: {e}", flush=True) + print(f" intermediate: shape={tuple(intermediate.shape)} dtype={intermediate.dtype} dev={intermediate.device} contiguous={intermediate.is_contiguous()}", flush=True) + raise + torch.cuda.synchronize() # DEBUG: catch async errors from quantize + # Copy first element of gsa (scalar for single-expert) to pre-allocated buffer. + # Using scalar assignment avoids copy_() from view which caused cudaErrorInvalidValue + # on non-contiguous gsa_gpu slices (gsa_gpu[:1].reshape(1) — view of expanded tensor). + self._l2_gsa_buf[0] = gsa_l2_gpu[0] # scalar GPU → GPU, no sync, graph-capturable else: x_fp4, x_sf = quantize_activation_nvfp4( intermediate, self._l2_activation_global_scale diff --git a/single_shot_inference.py b/single_shot_inference.py index 7d7dee61..4a37cae7 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -134,39 +134,69 @@ def unweighted_rmsnorm(x, eps=1e-6): class CUDAGraphDecoder: """Captures and replays CUDA graphs for the decode loop. - Architecture: One graph per layer, capturing the entire forward_layer. - After one warmup step (which also fixes gsa values), each layer's - forward is captured as a single CUDA graph. Replay eliminates Python - dispatch overhead (~94ms for 61 layers) and kernel launch latency. + Architecture (Phase 1: eager-break-at-attention): + Each layer is split into two graph-captured sub-regions with eager attention + in between: + + Graph A (pre-attention): mHC pre_block(attn) + fused RMSNorm + quantize + + q_a + q_a_norm + q_b + kv projections + → writes x_normed, q_heads, kv_3d, ctx_a to + pre-allocated buffers for eager attention + Eager (attention): Compressor → Indexer → KV gather → FMHA + → inverse RoPE → o_a + o_b → F_attn + → writes F_attn to pre-allocated buffer + Graph B (post-attention): mHC post_block(attn) + mHC pre_block(ffn) + + fused RMSNorm + quantize + Router + MoE + SE + + mHC post_block(ffn) + → writes X_next to pre-allocated output buffer + + The attention path (compressor, FMHA, inverse RoPE) has dynamic shapes + and data-dependent control flow — it MUST run eagerly. + The compute path has fixed shapes for T=1 decode — it CAN be captured. The hc_head + norm + lm_head are captured as a separate graph on cuda:0. - Cross-GPU transfers (X.to(cuda:N)) happen OUTSIDE graphs between layers. Constraints: - - All tensors must have fixed addresses (pre-allocated) - - No CPU-GPU syncs inside the graph + - All tensors in captured regions must have fixed addresses (pre-allocated) + - No CPU-GPU syncs inside captured regions - The only per-step sync is argmax for sampling (outside graph) - - FMHA pads KV to 128 → fixed shape for graph capture - - Compressor returns None on non-boundary steps → graph captures the - path taken during warmup (typically the None path for HCA r=128) + - Attention runs eagerly — dynamic shapes are OK there """ - def __init__(self, n_layers, num_gpus, hidden_size, devices): + def __init__(self, n_layers, num_gpus, hidden_size, devices, cfg): self.n_layers = n_layers self.num_gpus = num_gpus self.hidden_size = hidden_size self.devices = devices self.captured = False - # One graph per layer + lm_head - self.graphs = {} # li -> torch.cuda.CUDAGraph - self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0 + # Model dimensions for buffer pre-allocation + self.n_h = cfg.get("num_attention_heads", 128) + self.hd = cfg.get("head_dim", 512) + self.rd = cfg.get("qk_rope_head_dim", 64) + + # Two graphs per layer (A: pre-attn, B: post-attn+FFN) + lm_head + self.graphs_a = {} # li -> torch.cuda.CUDAGraph + self.graphs_b = {} # li -> torch.cuda.CUDAGraph + self.lm_graph = None # single graph for hc_head + norm + lm_head on cuda:0 # Pre-allocated I/O buffers — fixed addresses for graph capture self.x_in_bufs = {} # li -> (1, 4, H) BF16 on layer's device self.x_out_bufs = {} # li -> (1, 4, H) BF16 on layer's device + # Graph A output buffers (read by eager attention, written by graph A) + # These survive across the graph A → eager → graph B boundary. + self.x_normed_bufs = {} # li -> (1, H) BF16 — for compressor/indexer + self.q_heads_bufs = {} # li -> (1, n_h, hd) BF16 — for FMHA + self.kv_3d_bufs = {} # li -> (1, 1, hd) BF16 — for FMHA (pre-RoPE) + self.ctx_a_B_bufs = {} # li -> (1, 4, 4) FP32 — B_l for post_block + self.ctx_a_C_bufs = {} # li -> (1, 4) BF16 — C_l for post_block + self.X_mid_bufs = {} # li -> (1, 4, H) BF16 — X_l for post_block + + # Graph B input buffer (written by eager attention, read by graph B) + self.F_attn_bufs = {} # li -> (1, H) BF16 — attention output for post_block + # lm_head graph buffers (on cuda:0) self.x_lm_in = None # (1, 4, H) BF16 on cuda:0 self.logits_buf = None # (1, vocab_size) BF16 on cuda:0 @@ -175,11 +205,22 @@ class CUDAGraphDecoder: """Pre-allocate all I/O buffers with fixed addresses.""" H = self.hidden_size V = cfg.get("vocab_size", 129280) + n_h = self.n_h + hd = self.hd for li in range(self.n_layers): dev = self.devices[li % self.num_gpus] self.x_in_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev) self.x_out_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev) + # Graph A intermediates + self.x_normed_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev) + self.q_heads_bufs[li] = torch.zeros(1, n_h, hd, dtype=torch.bfloat16, device=dev) + self.kv_3d_bufs[li] = torch.zeros(1, 1, hd, dtype=torch.bfloat16, device=dev) + self.ctx_a_B_bufs[li] = torch.zeros(1, 4, 4, dtype=torch.float32, device=dev) + self.ctx_a_C_bufs[li] = torch.zeros(1, 4, dtype=torch.bfloat16, device=dev) + self.X_mid_bufs[li] = torch.zeros(1, 4, H, dtype=torch.bfloat16, device=dev) + # Graph B input + self.F_attn_bufs[li] = torch.zeros(1, H, dtype=torch.bfloat16, device=dev) # lm_head graph I/O (cuda:0 only) self.x_lm_in = torch.zeros(1, 4, H, dtype=torch.bfloat16, device='cuda:0') @@ -189,41 +230,123 @@ class CUDAGraphDecoder: kv_caches, compressors, indexers, moe_runners, se_runners, routers, prod_lins, layer_w, rope_caches, hc_head, final_norm_w, lm_w, dec_pos_per_gpu, dec_tid32_per_gpu, comp_rope_caches=None): - """Capture CUDA graphs for all layers + lm_head. + """Capture CUDA graphs for all layers (A/B split) + lm_head. + + Phase 1: eager-break-at-attention. Graphs A/B capture the compute-heavy + path; the attention path runs eagerly between A and B replays. Must be called after one warmup step so that: 1. All CuTeDSL kernels are compiled and cached 2. gsa values are fixed (from warmup_gsa) 3. CUDA kernels are warmed up (first launch is often slower) """ - H = self.hidden_size + from dsv4.ops.quantize import ( + mhc_rmsnorm_quantize_nvfp4, dequantize_nvfp4, + rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, + ) + from dsv4.layers.mhc import mHCContext - print(" Capturing CUDA graphs for decode (1 graph per layer)...", flush=True) + H = self.hidden_size + n_h = self.n_h + hd = self.hd + rd = self.rd + + print(" Capturing CUDA graphs (A/B split: compute captured, attention eager)...", flush=True) for li in range(self.n_layers): gpu = li % self.num_gpus dev = self.devices[gpu] torch.cuda.set_device(gpu) - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - X_out = forward_layer( - self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu], - attn_mhcs.get(li), ffn_mhcs.get(li), - attn_norms.get(li), ffn_norms.get(li), - kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu], - compressors.get(li), indexers.get(li), - moe_runners.get(li), se_runners.get(li), routers.get(li), - prod_lin=prod_lins.get(li), - _use_fused_rmsnorm_quantize=True, - comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None, - comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None, - ) - self.x_out_bufs[li].copy_(X_out) + attn_mhc = attn_mhcs.get(li) + ffn_mhc = ffn_mhcs.get(li) + attn_norm_w = attn_norms.get(li) + ffn_norm_w = ffn_norms.get(li) + pl = prod_lins.get(li, {}) + pfx = f"model.layers.{li}.self_attn" + + # ======== Graph A: pre-attention compute ======== + # Input: X_l = self.x_in_bufs[li] (1, 4, H) + # Output: x_normed, q_heads, kv_3d, ctx_a, X_l → pre-allocated buffers + graph_a = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_a): + X_l = self.x_in_bufs[li] + + # 1. mHC pre_block (attn) — fused P5 + A_l_a, B_l_a, C_l_a = attn_mhc._dynamic_params(X_l) + x_quant_attn = mhc_rmsnorm_quantize_nvfp4( + X_l, A_l_a, attn_norm_w.to(dev, torch.float32)) + x_normed = dequantize_nvfp4(x_quant_attn.x_fp4, x_quant_attn.x_sf, x_quant_attn.gsa) + + # 2. Attention projections + q_a = pl['q_a'].run_from_quantized(x_quant_attn) + q_norm_w = layer_w[li].get(f"{pfx}.q_a_norm.weight") + if q_norm_w is not None: + q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32)) + q_a = dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa) + q = pl['q_b'].run_from_quantized(q_a_quant) + else: + q = pl['q_b'](q_a) + q = unweighted_rmsnorm(q).bfloat16() + # NOTE: RoPE is applied in the eager attention path (dynamic positions) + q_heads = q.reshape(1, n_h, hd) + + kv = pl['kv'].run_from_quantized(x_quant_attn) + kv_norm_w_k = layer_w[li].get(f"{pfx}.kv_norm.weight") + if kv_norm_w_k is not None: + kv = rmsnorm(kv, kv_norm_w_k.to(dev, torch.float32)) + kv_3d = kv.reshape(1, 1, hd) + # NOTE: RoPE is applied in the eager attention path + + # Write to pre-allocated buffers for eager attention path + self.x_normed_bufs[li].copy_(x_normed) + self.q_heads_bufs[li].copy_(q_heads) + self.kv_3d_bufs[li].copy_(kv_3d) + self.ctx_a_B_bufs[li].copy_(B_l_a) + self.ctx_a_C_bufs[li].copy_(C_l_a) + self.X_mid_bufs[li].copy_(X_l) + + self.graphs_a[li] = graph_a + + # ======== Graph B: post-attention + FFN compute ======== + # Input: X_mid = self.X_mid_bufs[li], F_attn = self.F_attn_bufs[li] + # Output: X_next → self.x_out_bufs[li] + graph_b = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph_b): + X_mid = self.X_mid_bufs[li] + F_attn = self.F_attn_bufs[li] + + # 1. mHC post_block (attn) + B_l_a = self.ctx_a_B_bufs[li] + C_l_a = self.ctx_a_C_bufs[li] + BX_a = torch.bmm(B_l_a.transpose(-1, -2), X_mid.float()) + CF_a = C_l_a.unsqueeze(-1) * F_attn.unsqueeze(1) + X_mid_out = (CF_a.float() + BX_a).to(X_mid.dtype) + + # 2. FFN mHC pre_block — fused P5 + A_l_f, B_l_f, C_l_f = ffn_mhc._dynamic_params(X_mid_out) + x_quant_ffn = mhc_rmsnorm_quantize_nvfp4( + X_mid_out, A_l_f, ffn_norm_w.to(dev, torch.float32)) + x_ffn = dequantize_nvfp4(x_quant_ffn.x_fp4, x_quant_ffn.x_sf, x_quant_ffn.gsa) + + # 3. Router + MoE + SE + token_id_dev = dec_tid32_per_gpu[gpu] + topk_w, topk_ids = routers.get(li)(x_ffn, token_ids=token_id_dev) + routed_out = moe_runners.get(li).run(x_ffn, topk_w, topk_ids) + shared_out = se_runners.get(li).run(x_ffn) + F_ffn = routed_out + shared_out + + # 4. mHC post_block (ffn) + BX_f = torch.bmm(B_l_f.transpose(-1, -2), X_mid_out.float()) + CF_f = C_l_f.unsqueeze(-1) * F_ffn.unsqueeze(1) + X_next = (CF_f.float() + BX_f).to(X_mid.dtype) + + self.x_out_bufs[li].copy_(X_next) + + self.graphs_b[li] = graph_b - self.graphs[li] = graph if (li + 1) % 10 == 0: - print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True) + print(f" Captured {li+1}/{self.n_layers} layer A/B graphs", flush=True) # ---- Capture hc_head + norm + lm_head on cuda:0 ---- torch.cuda.set_device(0) @@ -236,7 +359,8 @@ class CUDAGraphDecoder: self.logits_buf.copy_(logits) self.captured = True - print(f" Captured {len(self.graphs)} layer graphs + lm_head", flush=True) + print(f" Captured {len(self.graphs_a)} layer A/B graph pairs + lm_head", flush=True) + # ===================================================================== def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None): O, I2 = weight.shape; I = I2 * 2 @@ -878,7 +1002,8 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer, prod_lin, x_quant=None, _profile_detail=False, _profile_times=None, - comp_rope_cos=None, comp_rope_sin=None): + comp_rope_cos=None, comp_rope_sin=None, + q_heads=None, kv_3d=None): dev = x_normed.device; T = x_normed.shape[0] n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64) o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024) @@ -895,40 +1020,43 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, _pt('q_a_start') # 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm - q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed) - _pt('q_a_end') - if VERBOSE >= 2 and li < 3: - # Compare q_a with PyTorch reference - q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj') - if q_a_ref is not None: - cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item() - print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True) - q_norm_w = w.get(f"{pfx}.q_a_norm.weight") - # B3: Fused rmsnorm+quant for q_a_norm → q_b path - # Replaces: rmsnorm(q_a, w) → BF16 → q_b quantizes internally - # With: fused rmsnorm+NVFP4 quantize → QuantizedActivation → q_b.run_from_quantized - # Saves: ~6 kernel launches per layer (rmsnorm 4+ + quantize 2 vs fused 2) - if q_norm_w is not None: - from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4 - q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32)) - q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa) - _pt('q_b_start') - if q_norm_w is not None: - q = prod_lin['q_b'].run_from_quantized(q_a_quant) - else: - q = prod_lin['q_b'](q_a) - q = unweighted_rmsnorm(q).bfloat16() - _pt('q_b_end') - q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) + # When q_heads is provided (from CUDA graph A), skip projections — only apply RoPE + if q_heads is None: + q_a = prod_lin['q_a'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['q_a'](x_normed) + _pt('q_a_end') + if VERBOSE >= 2 and li < 3: + # Compare q_a with PyTorch reference + q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj') + if q_a_ref is not None: + cos_qa = torch.nn.functional.cosine_similarity(q_a.flatten().float(), q_a_ref.flatten().float(), dim=0).item() + print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True) + q_norm_w = w.get(f"{pfx}.q_a_norm.weight") + # B3: Fused rmsnorm+quant for q_a_norm → q_b path + if q_norm_w is not None: + from dsv4.ops.quantize import rmsnorm_quantize_nvfp4 as _rmsnorm_quantize, dequantize_nvfp4 as _dequantize_nvfp4 + q_a_quant = _rmsnorm_quantize(q_a, q_norm_w.to(dev, torch.float32)) + q_a = _dequantize_nvfp4(q_a_quant.x_fp4, q_a_quant.x_sf, q_a_quant.gsa) + _pt('q_b_start') + if q_norm_w is not None: + q = prod_lin['q_b'].run_from_quantized(q_a_quant) + else: + q = prod_lin['q_b'](q_a) + q = unweighted_rmsnorm(q).bfloat16() + _pt('q_b_end') + q_heads = q.reshape(T, n_h, hd) + q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) _pt('rope_q_end') # 2. KV (NVFP4 GEMM, MQA, single KV head) + # When kv_3d is provided (from CUDA graph A), skip projections — only apply RoPE _pt('kv_start') - kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed) - _pt('kv_end') - kv_norm_w = w.get(f"{pfx}.kv_norm.weight") - if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) - kv_3d = kv.reshape(T, 1, hd); kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd) + if kv_3d is None: + kv = prod_lin['kv'].run_from_quantized(x_quant) if x_quant is not None else prod_lin['kv'](x_normed) + _pt('kv_end') + kv_norm_w = w.get(f"{pfx}.kv_norm.weight") + if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) + kv_3d = kv.reshape(T, 1, hd) + kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd) _pt('rope_kv_end') kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions) @@ -1670,7 +1798,7 @@ def main(): graph_decoder = None if _args.cuda_graph: print(" CUDA graph capture requested — will capture after warmup step") - graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)]) + graph_decoder = CUDAGraphDecoder(n_layers, NUM_GPUS, H, [f'cuda:{g}' for g in range(NUM_GPUS)], cfg) graph_decoder.pre_allocate(cfg) for step in range(MAX_NEW_TOKENS): @@ -1692,18 +1820,41 @@ def main(): # ---- Forward: graph replay or eager ---- if graph_decoder is not None and graph_decoder.captured: - # CUDA graph replay path — one graph per layer + # CUDA graph replay path — A/B split with eager attention for li in range(n_layers): gpu = li % NUM_GPUS torch.cuda.set_device(gpu) + dev = f'cuda:{gpu}' - # Copy X into graph input buffer (copy_ handles cross-GPU transfer) + # Copy X into graph A input buffer (copy_ handles cross-GPU transfer) graph_decoder.x_in_bufs[li].copy_(X) - # Replay layer graph - graph_decoder.graphs[li].replay() + # Replay graph A: mHC pre_block + RMSNorm + q_a/q_b/kv projections + graph_decoder.graphs_a[li].replay() - # Read output from graph + # ---- Eager attention (NOT captured) ---- + # Read graph A outputs from pre-allocated buffers + x_normed = graph_decoder.x_normed_bufs[li] + q_heads = graph_decoder.q_heads_bufs[li] + kv_3d = graph_decoder.kv_3d_bufs[li] + + # Run full attention eagerly (compressor + indexer + FMHA + o_proj) + F_attn, _ = forward_attention( + x_normed, layer_w[li], li, cfg, *rope_caches[gpu], + kv_caches[li], dec_pos_per_gpu[gpu], dec_tid32_per_gpu[gpu], + compressors.get(li), indexers.get(li), prod_lins.get(li), + q_heads=q_heads, kv_3d=kv_3d, # pass pre-computed q/kv from graph A + comp_rope_cos=comp_rope_caches[gpu][0] if comp_rope_caches else None, + comp_rope_sin=comp_rope_caches[gpu][1] if comp_rope_caches else None, + ) + + # Write F_attn to graph B input buffer + graph_decoder.F_attn_bufs[li].copy_(F_attn) + + # Replay graph B: mHC post_block + FFN + MoE + SE + graph_decoder.graphs_b[li].replay() + + # Read output from graph B X = graph_decoder.x_out_bufs[li] # Transfer last layer output to cuda:0 for lm_head graph