Fix top5_ids variable name in decode logging

This commit is contained in:
2026-05-31 10:54:40 +00:00
parent a6d56d10ca
commit 9584fcbc23

View File

@@ -886,7 +886,7 @@ def main():
# Top-5 predictions for debugging
# Top-20 predictions for debugging (includes thinking tokens)
top20_vals, top20_ids = torch.topk(logits[0], 20)
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids[:5], top20_vals[:5])])
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top20_ids[:5], top20_vals[:5])])
# Check if thinking tokens are in top-20
thinking_in_top20 = any(tid.item() in [128821, 128822] for tid in top20_ids)
top20_ids_set = set(top20_ids.tolist())