Fix top5_ids variable name in decode logging
This commit is contained in:
@@ -886,7 +886,7 @@ def main():
|
||||
# Top-5 predictions for debugging
|
||||
# Top-20 predictions for debugging (includes thinking tokens)
|
||||
top20_vals, top20_ids = torch.topk(logits[0], 20)
|
||||
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top5_ids[:5], top20_vals[:5])])
|
||||
top5_str = ' '.join([f'{tokenizer.decode([tid.item()])}({val.item():.1f})' for tid, val in zip(top20_ids[:5], top20_vals[:5])])
|
||||
# Check if thinking tokens are in top-20
|
||||
thinking_in_top20 = any(tid.item() in [128821, 128822] for tid in top20_ids)
|
||||
top20_ids_set = set(top20_ids.tolist())
|
||||
|
||||
Reference in New Issue
Block a user