Add --profile flag: per-component GPU timing with CUDA sync (embed+layers, lm_head, sampling)
This commit is contained in:
@@ -27,6 +27,7 @@ def parse_args():
|
||||
p.add_argument('--seed', type=int, default=42)
|
||||
p.add_argument('--verbose', type=int, default=1)
|
||||
p.add_argument('--prefill-only', action='store_true')
|
||||
p.add_argument('--profile', action='store_true', help='Profile per-component GPU time using CUDA events')
|
||||
p.add_argument('--num-gpus', type=int, default=8)
|
||||
p.add_argument('--checkpoint', type=str, default="/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4")
|
||||
p.add_argument('--prefill-tokens', type=str, default=None,
|
||||
@@ -986,12 +987,17 @@ def main():
|
||||
# Decode
|
||||
print(f"\nDecoding (max {MAX_NEW_TOKENS} tokens)...")
|
||||
in_thinking = False
|
||||
profile = _args.profile
|
||||
prof_embed_layers = 0.0
|
||||
prof_lm_head = 0.0
|
||||
prof_sample = 0.0
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t1 = time.time()
|
||||
dec_tid_buf[0] = all_tokens[-1]
|
||||
dec_tid32_buf[0] = all_tokens[-1]
|
||||
dec_pos_buf[0] = len(all_tokens) - 1
|
||||
|
||||
t_e = time.perf_counter()
|
||||
X = mHCLayer.init_state(embed(dec_tid_buf))
|
||||
for li in range(n_layers):
|
||||
gpu = li % NUM_GPUS
|
||||
@@ -1005,9 +1011,13 @@ def main():
|
||||
moe_runners.get(li), se_runners.get(li), routers.get(li),
|
||||
prod_lin=prod_lins.get(li))
|
||||
X = X.to('cuda:0'); torch.cuda.set_device(0)
|
||||
if profile: torch.cuda.synchronize()
|
||||
t_layers = time.perf_counter()
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
logits = lm_head_lin(x_out)
|
||||
if profile: torch.cuda.synchronize()
|
||||
t_lm = time.perf_counter()
|
||||
# Only sync + validate on the first 3 steps and every 20th step (reduces pipeline stalls; note that --profile already synchronizes on every step above)
|
||||
if step < 3 or (step + 1) % 20 == 0:
|
||||
torch.cuda.synchronize() # catch CUDA errors at source
|
||||
@@ -1042,10 +1052,16 @@ def main():
|
||||
all_tokens.append(next_id)
|
||||
dt = time.time() - t1
|
||||
|
||||
t_s = time.perf_counter()
|
||||
# Track thinking state
|
||||
if next_id == THINK_START: in_thinking = True
|
||||
elif next_id == THINK_END: in_thinking = False
|
||||
|
||||
if profile:
|
||||
prof_embed_layers += (t_layers - t_e)
|
||||
prof_lm_head += (t_lm - t_layers)
|
||||
prof_sample += (t_s - t_lm)
|
||||
|
||||
# Diagnostics — reduce CPU syncs: compute top-5 only on the first 5 steps and on every 5th step thereafter
|
||||
if step % 5 == 0 or step < 5:
|
||||
tv, ti = torch.topk(logits[0].float(), 5)
|
||||
@@ -1062,6 +1078,13 @@ def main():
|
||||
if next_id == tokenizer.eos_token_id:
|
||||
print(f" EOS at step {step}", flush=True); break
|
||||
|
||||
if profile and MAX_NEW_TOKENS > 0:
|
||||
n = MAX_NEW_TOKENS
|
||||
print(f"\n PROFILE (sync'd wall clock, {n} steps):")
|
||||
print(f" Embed + 61 layers: {prof_embed_layers:.3f}s total, {prof_embed_layers/n*1000:.1f}ms/token")
|
||||
print(f" hc_head + norm + lm_head: {prof_lm_head:.3f}s total, {prof_lm_head/n*1000:.1f}ms/token")
|
||||
print(f" Sampling: {prof_sample:.3f}s total, {prof_sample/n*1000:.1f}ms/token")
|
||||
|
||||
out = tokenizer.decode(all_tokens, skip_special_tokens=True)
|
||||
print(f"\n{'='*70}")
|
||||
print(f"Input: '{PROMPT}'")
|
||||
|
||||
Reference in New Issue
Block a user