diff --git a/single_shot_inference.py b/single_shot_inference.py index 82b7c49c..f8d93a61 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -27,6 +27,7 @@ def parse_args(): p.add_argument('--seed', type=int, default=42) p.add_argument('--verbose', type=int, default=1) p.add_argument('--prefill-only', action='store_true') + p.add_argument('--profile', action='store_true', help='Profile per-component decode time (wall clock measured around torch.cuda.synchronize(); adds pipeline stalls)') p.add_argument('--num-gpus', type=int, default=8) p.add_argument('--checkpoint', type=str, default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4") p.add_argument('--prefill-tokens', type=str, default=None, @@ -986,12 +987,17 @@ def main(): # Decode print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...") in_thinking = False + profile = _args.profile + prof_embed_layers = 0.0 + prof_lm_head = 0.0 + prof_sample = 0.0 for step in range(MAX_NEW_TOKENS): t1 = time.time() dec_tid_buf[0] = all_tokens[-1] dec_tid32_buf[0] = all_tokens[-1] dec_pos_buf[0] = len(all_tokens) - 1 + t_e = time.perf_counter() X = mHCLayer.init_state(embed(dec_tid_buf)) for li in range(n_layers): gpu = li % NUM_GPUS @@ -1005,9 +1011,13 @@ def main(): moe_runners.get(li), se_runners.get(li), routers.get(li), prod_lin=prod_lins.get(li)) X = X.to('cuda:0'); torch.cuda.set_device(0) + if profile: torch.cuda.synchronize() + t_layers = time.perf_counter() x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) logits = lm_head_lin(x_out) + if profile: torch.cuda.synchronize() + t_lm = time.perf_counter() # Only sync + validate on first 3 steps and every 20th step (reduces pipeline stalls) if step < 3 or (step + 1) % 20 == 0: torch.cuda.synchronize() # catch CUDA errors at source @@ -1042,10 +1052,16 @@ def main(): all_tokens.append(next_id) dt = time.time() - t1 + t_s = time.perf_counter() # Track thinking state if next_id == THINK_START: in_thinking = True elif next_id == THINK_END: 
in_thinking = False + if profile: + prof_embed_layers += (t_layers - t_e) + prof_lm_head += (t_lm - t_layers) + prof_sample += (t_s - t_lm) + # Diagnostics — reduce CPU syncs, only top-5 every 5 steps if step % 5 == 0 or step < 5: tv, ti = torch.topk(logits[0].float(), 5) @@ -1062,6 +1078,13 @@ def main(): if next_id == tokenizer.eos_token_id: print(f" EOS at step {step}", flush=True); break + if profile and MAX_NEW_TOKENS > 0: + n = step + 1 # decode steps actually run (the loop may break early at EOS) + print(f"\n PROFILE (sync'd wall clock, {n} steps):") + print(f" Embed + {n_layers} layers: {prof_embed_layers:.3f}s total, {prof_embed_layers/n*1000:.1f}ms/token") + print(f" hc_head + norm + lm_head: {prof_lm_head:.3f}s total, {prof_lm_head/n*1000:.1f}ms/token") + print(f" Sampling: {prof_sample:.3f}s total, {prof_sample/n*1000:.1f}ms/token") + out = tokenizer.decode(all_tokens, skip_special_tokens=True) print(f"\n{'='*70}") print(f"Input: '{PROMPT}'")