diff --git a/single_shot_inference.py b/single_shot_inference.py index f475875c..ea47af2a 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -886,7 +886,7 @@ def main(): # Top-5 predictions for debugging # Top-20 predictions for debugging (includes thinking tokens) top20_vals, top20_ids = torch.topk(logits[0], 20) - top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids[:5], top20_vals[:5])]) + top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top20_ids[:5], top20_vals[:5])]) # Check if thinking tokens are in top-20 thinking_in_top20 = any(tid.item() in [128821, 128822] for tid in top20_ids) top20_ids_set = set(top20_ids.tolist())