From ae79bd8fce580d76f6236844cc48e51120f7a6ae Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 31 May 2026 04:25:01 +0000 Subject: [PATCH] debug: add top-5 logit predictions --- single_shot_inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/single_shot_inference.py b/single_shot_inference.py index 5d61e74c..a07cbbb3 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -769,6 +769,9 @@ def main(): x_out = (xf * rms * final_norm_w.float()).bfloat16() logits = torch.nn.functional.linear(x_out, lm_w) + # Top-5 predictions for debugging + top5_vals, top5_ids = torch.topk(logits[0], 5) + top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids, top5_vals)]) next_id = torch.argmax(logits, dim=-1).item() generated.append(next_id) all_tokens.append(next_id) @@ -781,7 +784,7 @@ def main(): x_max = X.abs().max().item() print(f" Step {step}: {next_id} '{tok_str}' ({dt:.2f}s) " f"logits=[{lmin:.1f},{lmax:.1f}] nan={has_nan} inf={has_inf} " - f"|X|={x_max:.3f}") + f"|X|={x_max:.3f} top5: {top5_str}") if has_nan or has_inf: print(" Numerical issue — stopping")