From 4734e894c7e5f3244c1ea15bb91c470a2149465e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 2 Jun 2026 02:46:35 +0000 Subject: [PATCH] profile: add per-layer attn vs ffn timing with CUDA sync --- single_shot_inference.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 88912f8d..50abe2aa 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -535,13 +535,17 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, token_id, compressor=None, indexer=None, moe_runner=None, se_runner=None, router=None, - prod_lin=None): + prod_lin=None, _profile_detail=False): x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w) + if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter() F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin, kv_cache, positions, compressor, indexer, prod_lin) + if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter() X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a) x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w) + if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter() F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id) + if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter() X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f) if VERBOSE >= 2 and (li < 3 or li >= 58): print(f" L{li}: |X|={X_l.abs().max().item():.1f}->{X_next.abs().max().item():.1f} " @@ -563,6 +567,11 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin, f"|X_l|={X_l.abs().max().item():.1f} " f"|X_mid|={X_mid.abs().max().item():.1f} " f"|X_next|={X_next.abs().max().item():.1f}", flush=True) + if _profile_detail and li < 3 or li == 30 or li >= 58: + torch.cuda.synchronize() + attn_ms = (t_attn1 - t_attn0) * 1000 + ffn_ms = (t_ffn1 - t_ffn0) * 1000 + print(f" L{li}: attn={attn_ms:.2f}ms ffn={ffn_ms:.2f}ms", flush=True) return X_next # ===================================================================== @@ -994,6 +1003,22 @@ def main(): prof_lm_head = 0.0 prof_sample = 0.0 prof_sample_start = 0.0 + + # CUDA event profiling — measures ACTUAL GPU time, not wall clock + # Only profile steps 1-3 (after warmup) to get stable results + cuda_events = {} + if profile: + for tag in ['embed', 'layers', 'hc_norm_lm', 'sample', 'diagnostics']: + cuda_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) + # Per-layer category events (sampled on step 1 only) + layer_event_tags = ['mhc_pre', 'attn_proj', 'rope_kv', 'compress_idx', 'fmha', 'inv_rope', 'o_proj', + 'mhc_post', 'mhc_pre_ffn', 'router', 'moe', 'shared_expert', 'mhc_post_ffn'] + cuda_layer_events = {} + for tag in layer_event_tags: + cuda_layer_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)) + layer_event_accum = {tag: 0.0 for tag in layer_event_tags} + layer_event_count = 0 + for step in range(MAX_NEW_TOKENS): t1 = time.time() dec_tid_buf[0] = all_tokens[-1] @@ -1012,9 +1037,9 @@ def main(): kv_caches[li], dec_pos_buf, dec_tid32_buf, compressors.get(li), indexers.get(li), moe_runners.get(li), se_runners.get(li), routers.get(li), - prod_lin=prod_lins.get(li)) + prod_lin=prod_lins.get(li), + _profile_detail=(profile and step == 1)) X = X.to('cuda:0'); torch.cuda.set_device(0) - if profile: torch.cuda.synchronize() t_layers = time.perf_counter() # After first decode step: fix gsa values from runtime amax