Add top-20 logging and thinking token detection in decode loop

This commit is contained in:
2026-05-31 10:49:28 +00:00
parent d891ae7e96
commit a6d56d10ca

View File

@@ -884,8 +884,12 @@ def main():
logits = torch.nn.functional.linear(x_out, lm_w)
# Top-5 predictions for debugging
top5_vals, top5_ids = torch.topk(logits[0], 5)
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids, top5_vals)])
# Top-20 predictions for debugging (includes thinking tokens)
top20_vals, top20_ids = torch.topk(logits[0], 20)
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids[:5], top20_vals[:5])])
# Check if thinking tokens are in top-20
thinking_in_top20 = any(tid.item() in [128821, 128822] for tid in top20_ids)
top20_ids_set = set(top20_ids.tolist())
next_id = torch.argmax(logits, dim=-1).item()
generated.append(next_id)
all_tokens.append(next_id)
@@ -899,6 +903,12 @@ def main():
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
f"|X|={x_max:.3f} top5: {top5_str}", flush=True)
if thinking_in_top20:
for tid_t, val_t in zip(top20_ids, top20_vals):
if tid_t.item() in [128821, 128822]:
print(f" THINK TOKEN: {tid_t.item()} logit={val_t.item():.3f}", flush=True)
if step % 5 == 0:
print(f" Top-20: {[(tokenizer.decode([t.item()]), f'{v.item():.2f}') for t, v in zip(top20_ids, top20_vals)]}", flush=True)
if has_nan or has_inf:
print(" Numerical issue — stopping")