debug: add top-5 logit predictions

This commit is contained in:
2026-05-31 04:25:01 +00:00
parent aafe2eee12
commit ae79bd8fce

View File

@@ -769,6 +769,9 @@ def main():
x_out = (xf * rms * final_norm_w.float()).bfloat16()
logits = torch.nn.functional.linear(x_out, lm_w)
# Top-5 predictions for debugging
top5_vals, top5_ids = torch.topk(logits[0], 5)
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids, top5_vals)])
next_id = torch.argmax(logits, dim=-1).item()
generated.append(next_id)
all_tokens.append(next_id)
@@ -781,7 +784,7 @@ def main():
x_max = X.abs().max().item()
print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) "
f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} "
f"|X|={x_max:.3f}")
f"|X|={x_max:.3f} top5: {top5_str}")
if has_nan or has_inf:
print(" Numerical issue — stopping")