profile: add per-layer attn vs ffn timing with CUDA sync

This commit is contained in:
2026-06-02 02:46:35 +00:00
parent 4017ef2f16
commit 4734e894c7

View File

@@ -535,13 +535,17 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, token_id,
compressor=None, indexer=None,
moe_runner=None, se_runner=None, router=None,
prod_lin=None):
prod_lin=None, _profile_detail=False):
x_in, ctx_a = attn_mhc.pre_block(X_l); x_normed = rmsnorm(x_in, attn_norm_w)
if _profile_detail: torch.cuda.synchronize(); t_attn0 = time.perf_counter()
F_attn, _ = forward_attention(x_normed, w, li, cfg, rope_cos, rope_sin,
kv_cache, positions, compressor, indexer, prod_lin)
if _profile_detail: torch.cuda.synchronize(); t_attn1 = time.perf_counter()
X_mid = attn_mhc.post_block(X_l, F_attn, ctx_a)
x_in_f, ctx_f = ffn_mhc.pre_block(X_mid); x_ffn = rmsnorm(x_in_f, ffn_norm_w)
if _profile_detail: torch.cuda.synchronize(); t_ffn0 = time.perf_counter()
F_ffn = moe_forward(x_ffn, li, moe_runner, se_runner, router, token_id)
if _profile_detail: torch.cuda.synchronize(); t_ffn1 = time.perf_counter()
X_next = ffn_mhc.post_block(X_mid, F_ffn, ctx_f)
if VERBOSE >= 2 and (li < 3 or li >= 58):
print(f" L{li}: |X|={X_l.abs().max().item():.1f}->{X_next.abs().max().item():.1f} "
@@ -563,6 +567,11 @@ def forward_layer(X_l, w, li, cfg, rope_cos, rope_sin,
f"|X_l|={X_l.abs().max().item():.1f} "
f"|X_mid|={X_mid.abs().max().item():.1f} "
f"|X_next|={X_next.abs().max().item():.1f}", flush=True)
if _profile_detail and li < 3 or li == 30 or li >= 58:
torch.cuda.synchronize()
attn_ms = (t_attn1 - t_attn0) * 1000
ffn_ms = (t_ffn1 - t_ffn0) * 1000
print(f" L{li}: attn={attn_ms:.2f}ms ffn={ffn_ms:.2f}ms", flush=True)
return X_next
# =====================================================================
@@ -994,6 +1003,22 @@ def main():
prof_lm_head = 0.0
prof_sample = 0.0
prof_sample_start = 0.0
# CUDA event profiling — measures ACTUAL GPU time, not wall clock
# Only profile steps 1-3 (after warmup) to get stable results
cuda_events = {}
if profile:
for tag in ['embed', 'layers', 'hc_norm_lm', 'sample', 'diagnostics']:
cuda_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
# Per-layer category events (sampled on step 1 only)
layer_event_tags = ['mhc_pre', 'attn_proj', 'rope_kv', 'compress_idx', 'fmha', 'inv_rope', 'o_proj',
'mhc_post', 'mhc_pre_ffn', 'router', 'moe', 'shared_expert', 'mhc_post_ffn']
cuda_layer_events = {}
for tag in layer_event_tags:
cuda_layer_events[tag] = (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
layer_event_accum = {tag: 0.0 for tag in layer_event_tags}
layer_event_count = 0
for step in range(MAX_NEW_TOKENS):
t1 = time.time()
dec_tid_buf[0] = all_tokens[-1]
@@ -1012,9 +1037,9 @@ def main():
kv_caches[li], dec_pos_buf, dec_tid32_buf,
compressors.get(li), indexers.get(li),
moe_runners.get(li), se_runners.get(li), routers.get(li),
prod_lin=prod_lins.get(li))
prod_lin=prod_lins.get(li),
_profile_detail=(profile and step == 1))
X = X.to('cuda:0'); torch.cuda.set_device(0)
if profile: torch.cuda.synchronize()
t_layers = time.perf_counter()
# After first decode step: fix gsa values from runtime amax