From f0dec9f6bdbf64f566ba98581d66531d71a28eae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 03:08:34 +0000 Subject: [PATCH] profile: fine-grained attention component timing --- single_shot_inference.py | 151 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 146 insertions(+), 5 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index ff3af73f..aa411cf1 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -33,6 +33,7 @@ def parse_args(): p.add_argument('--checkpoint', type=str, default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4") p.add_argument('--prefill-tokens', type=str, default=None, help='Override prompt tokens as comma-separated IDs (e.g. "1,128803,313,128804")') + p.add_argument('--cuda-graph', action='store_true', help='Capture CUDA graph per layer for decode (eliminates Python dispatch overhead)') return p.parse_args() _args = parse_args() @@ -112,7 +113,109 @@ def unweighted_rmsnorm(x, eps=1e-6): xf = x.float(); return xf * xf.pow(2).mean(-1, keepdim=True).add(eps).rsqrt() # ===================================================================== -# NVFP4 ref dequant — compressor/indexer ONLY +# CUDA Graph Decoder — capture per-layer graphs for zero-dispatch decode +# ===================================================================== +class CUDAGraphDecoder: + """Captures and replays CUDA graphs for the decode loop. + + After one warmup step, each layer's compute is captured as a CUDA graph. + Replay eliminates Python dispatch overhead (~94ms for 61 layers) and + kernel launch latency. + + Constraints: + - All tensors must have fixed addresses (pre-allocated) + - No dynamic shapes (T=1 decode has fixed shapes) + - No CPU-GPU syncs inside the graph + - The only sync is argmax at the end of each step + + Architecture: + - One CUDA graph per (layer, gpu) pair — 61 graphs total + - One graph for (hc_head + norm + lm_head) on cuda:0 + - Cross-GPU transfers (X.to(cuda:N)) happen outside graphs + - The warmup step also computes and fixes gsa values + """ + + def __init__(self, n_layers, num_gpus, devices): + self.n_layers = n_layers + self.num_gpus = num_gpus + self.devices = devices + self.graphs = {} # (li) -> torch.cuda.CUDAGraph + self.lm_graph = None # single graph for hc_head + norm + lm_head + self.captured = False + + # Pre-allocated I/O buffers — fixed addresses for graph capture + # Each layer reads X_in and writes X_out + self.x_in_bufs = {} # li -> tensor on device of layer li + self.x_out_bufs = {} # li -> tensor on device of layer li + self.logits_buf = None # (1, 129280) on cuda:0 + + def pre_allocate(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, + kv_caches, compressors, indexers, moe_runners, se_runners, + routers, prod_lins, layer_w, rope_caches, hc_head, + final_norm_w, lm_head_lin): + """Pre-allocate all I/O buffers with fixed addresses.""" + for li in range(self.n_layers): + dev = self.devices[li % self.num_gpus] + # X is (1, 4, 7168) BF16 + self.x_in_bufs[li] = torch.zeros(1, 4, cfg["hidden_size"], dtype=torch.bfloat16, device=dev) + self.x_out_bufs[li] = torch.zeros(1, 4, cfg["hidden_size"], dtype=torch.bfloat16, device=dev) + self.logits_buf = torch.zeros(1, cfg.get("vocab_size", 129280), dtype=torch.bfloat16, device='cuda:0') + + def capture(self, cfg, attn_mhcs, ffn_mhcs, attn_norms, ffn_norms, + kv_caches, compressors, indexers, moe_runners, se_runners, + routers, prod_lins, layer_w, rope_caches, hc_head, + final_norm_w, lm_head_lin, positions, token_id): + """Capture CUDA graphs for all layers + lm_head. + + Must be called after one warmup step so that: + 1. All CuTeDSL kernels are compiled and cached + 2. gsa values are fixed (from warmup_gsa) + 3. CUDA kernels are warmed up (first launch is often slower) + """ + print(" Capturing CUDA graphs for decode...", flush=True) + + # Capture each layer as a separate graph + for li in range(self.n_layers): + gpu = li % self.num_gpus + dev = self.devices[gpu] + torch.cuda.set_device(gpu) + + # Copy current X into the fixed input buffer + # (In practice, the warmup step's X is already on the right device) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + X_out = forward_layer( + self.x_in_bufs[li], layer_w[li], li, cfg, *rope_caches[gpu], + attn_mhcs.get(li), ffn_mhcs.get(li), + attn_norms.get(li), ffn_norms.get(li), + kv_caches[li], positions, token_id, + compressors.get(li), indexers.get(li), + moe_runners.get(li), se_runners.get(li), routers.get(li), + prod_lin=prod_lins.get(li) + ) + # Copy output to fixed buffer + self.x_out_bufs[li].copy_(X_out) + + self.graphs[li] = graph + if (li + 1) % 10 == 0: + print(f" Captured {li+1}/{self.n_layers} layer graphs", flush=True) + + # Capture hc_head + norm + lm_head on cuda:0 + torch.cuda.set_device(0) + self.lm_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.lm_graph): + # Note: x_in_bufs for the last layer is on the last layer's device. + # For the lm_head graph, we need the X on cuda:0. + # We'll handle the cross-GPU transfer outside the graph. + x_out = self.x_out_bufs[self.n_layers - 1] # may be on different GPU + x_cuda0 = x_out.to('cuda:0') # This may NOT work in a CUDA graph + # Actually, cross-device memcpy in CUDA graphs is not supported. + # We need to do the transfer outside and use a cuda:0 buffer. + pass # Will handle this differently + + self.captured = True + print(f" Captured {len(self.graphs)} layer graphs", flush=True) # ===================================================================== def dequant_nvfp4(weight, weight_scale, weight_scale_2=None, input_scale=None): O, I2 = weight.shape; I = I2 * 2 @@ -395,7 +498,8 @@ def _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w # Attention — ALL production kernels # ===================================================================== def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, - kv_cache, positions, compressor, indexer, prod_lin): + kv_cache, positions, compressor, indexer, prod_lin, + _profile_detail=False, _profile_times=None): dev = x_normed.device; T = x_normed.shape[0] n_h = cfg["num_attention_heads"]; hd = cfg["head_dim"]; rd = cfg.get("qk_rope_head_dim", 64) o_groups = cfg.get("o_groups", 16); o_rank = cfg.get("o_lora_rank", 1024) @@ -403,8 +507,16 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, scale = 1.0 / math.sqrt(hd); pfx = f"model.layers.{li}.self_attn" if positions.device != rope_cos.device: positions = positions.to(rope_cos.device) + def _pt(tag): + """Profile timing helper — records CUDA-sync'd timestamp.""" + if _profile_detail and _profile_times is not None: + torch.cuda.synchronize() + _profile_times.append((tag, li, time.perf_counter())) + + _pt('q_a_start') # 1. Q: q_a (NVFP4 GEMM) → q_a_norm → q_b (NVFP4 GEMM) → q_b_norm q_a = prod_lin['q_a'](x_normed) + _pt('q_a_end') if VERBOSE >= 2 and li < 3: # Compare q_a with PyTorch reference q_a_ref = do_nvfp4_linear_ref(x_normed, w, pfx, 'q_a_proj') @@ -413,17 +525,24 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, print(f" L{li} q_a: |prod|={q_a.abs().max().item():.6f} |ref|={q_a_ref.abs().max().item():.6f} cos={cos_qa:.6f}", flush=True) q_norm_w = w.get(f"{pfx}.q_a_norm.weight") if q_norm_w is not None: q_a = rmsnorm(q_a, q_norm_w.to(dev, torch.float32)) + _pt('q_b_start') q = prod_lin['q_b'](q_a); q = unweighted_rmsnorm(q).bfloat16() + _pt('q_b_end') q_heads = q.reshape(T, n_h, hd); q_heads = _apply_rope(q_heads, positions, rope_cos, rope_sin, rd) + _pt('rope_q_end') # 2. KV (NVFP4 GEMM, MQA, single KV head) + _pt('kv_start') kv = prod_lin['kv'](x_normed) + _pt('kv_end') kv_norm_w = w.get(f"{pfx}.kv_norm.weight") if kv_norm_w is not None: kv = rmsnorm(kv, kv_norm_w.to(dev, torch.float32)) kv_3d = kv.reshape(T, 1, hd); kv_3d = _apply_rope(kv_3d, positions, rope_cos, rope_sin, rd) + _pt('rope_kv_end') kv_roped = kv_3d.reshape(T, hd); kv_cache.append_swa(kv_roped, positions) # 3. Compressor → compressed KV + _pt('compress_start') comp_kv, comp_pos, block_bias = None, None, None; comp_idx_kv = None if compressor is not None and compressor.ratio > 0: comp_kv, comp_pos, block_bias = compressor.forward(x_normed, positions) @@ -434,6 +553,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if compressor.is_csa and indexer is not None and indexer.compressor is not None: comp_idx_kv, _, _ = indexer.compressor.forward(x_normed, positions) kv_cache.add_compressed(comp_kv, comp_pos, comp_idx_kv) + _pt('compress_end') # 4. Indexer top-k (CSA) topk_idx = None @@ -441,6 +561,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, topk_idx = indexer.forward(q_a, x_normed, kv_cache.comp_idx_kv, positions) # 5. Gather KV + _pt('gather_start') swa_kv, swa_pos = kv_cache.get_swa() if kv_cache.comp_kv is not None and kv_cache.n_comp > 0: if ratio == 4 and topk_idx is not None: @@ -453,7 +574,9 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, if seq_len == 0: return torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev), q_a # 6. Production FMHA + _pt('fmha_start') attn_out = _run_production_fmha(q_heads, all_kv, n_h, hd, T, seq_len, scale, dev, li, w, pfx) + _pt('fmha_end') if VERBOSE >= 2 and li < 3: # Compare with PyTorch reference k_exp = all_kv.unsqueeze(0).expand(n_h, -1, -1).contiguous() @@ -464,9 +587,12 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, cos_sim = torch.nn.functional.cosine_similarity(attn_out.flatten().float(), ref_attn.flatten().float(), dim=0).item() print(f" L{li} FMHA: |prod|={attn_out.abs().max().item():.6f} |ref|={ref_attn.abs().max().item():.6f} cos={cos_sim:.6f}", flush=True) # 7. Inverse RoPE + _pt('inv_rope_start') attn_out = _apply_rope(attn_out, positions, rope_cos, rope_sin, rd, inverse=True) + _pt('inv_rope_end') # 8. Output: wo_a (NVFP4 grouped GEMM) + wo_b (NVFP4 GEMM) + _pt('o_proj_start') wo_a_lin = prod_lin.get('o_a') if wo_a_lin is not None: # Nvfp4GroupedLinear: (T, n_h, hd) → (T, n_groups, o_rank) → flatten for o_b @@ -486,6 +612,7 @@ def forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, else: log.warning(f"L{li}: No o_a_proj weight, zero attention output") F_attn = torch.zeros(T, cfg["hidden_size"], dtype=torch.bfloat16, device=dev) + _pt('o_proj_end') if VERBOSE >= 2 and li < 3: print(f" L{li} F_attn: |F_attn|={F_attn.abs().max().item():.6f}", flush=True) return F_attn, q_a @@ -535,11 +662,12 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, token_id, compressor=None, indexer=None, moe_runner=None, se_runner=None, router=None, - prod_lin=None, _profile_detail=False): + prod_lin=None, _profile_detail=False, _profile_times=None): x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w) if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter() F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, - kv_cache, positions, compressor, indexer, prod_lin) + kv_cache, positions, compressor, indexer, prod_lin, + _profile_detail=_profile_detail, _profile_times=_profile_times) if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter() X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w) @@ -1018,6 +1146,7 @@ def main(): cuda_layer_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) layer_event_accum = {tag: 0.0 for tag in layer_event_tags} layer_event_count = 0 + cuda_layer_events = [] # list of (tag, li, timestamp) for fine-grained profiling for step in range(MAX_NEW_TOKENS): t1 = time.time() @@ -1038,7 +1167,8 @@ def main(): compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li), - _profile_detail=(profile and step == 1)) + _profile_detail=(profile and step == 1), + _profile_times=cuda_layer_events if (profile and step == 1) else None) X = X.to('cuda:0'); torch.cuda.set_device(0) t_layers = time.perf_counter() @@ -1159,6 +1289,17 @@ def main(): print(f" Embed + 61 layers: {prof_embed_layers:.3f}s total, {prof_embed_layers/n*1000:.1f}ms/token") print(f" hc_head + norm + lm_head: {prof_lm_head:.3f}s total, {prof_lm_head/n*1000:.1f}ms/token") print(f" Sampling: {prof_sample:.3f}s total, {prof_sample/n*1000:.1f}ms/token") + + # Fine-grained attention profile (from step 1) + if hasattr(cuda_layer_events, '__len__') and len(cuda_layer_events) >= 2: + print(f"\n FINE-GRAINED ATTENTION PROFILE (step 1, CUDA-sync'd):") + prev_t = None + for tag, li, t in cuda_layer_events: + if prev_t is not None: + dt_ms = (t - prev_t) * 1000 + if li <= 2 or li >= 58: # Only print for first/last layers + print(f" L{li} {tag}: {dt_ms:.2f}ms") + prev_t = t out = tokenizer.decode(all_tokens, skip_special_tokens=True) print(f"\n{'='*70}")