diff --git a/single_shot_inference.py b/single_shot_inference.py index 7ed72262..cc8eafd4 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -19,6 +19,8 @@ log = logging.getLogger("single_shot") def parse_args(): p = argparse.ArgumentParser() p.add_argument('--max-tokens', type=int, default=128) + p.add_argument('--temperature', type=float, default=0.0, help='Sampling temperature (0=greedy)') + p.add_argument('--repetition-penalty', type=float, default=1.2, help='Repetition penalty factor') p.add_argument('--prompt', type=str, default=None) p.add_argument('--seed', type=int, default=42) p.add_argument('--verbose', type=int, default=1) @@ -898,7 +900,17 @@ def main(): x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :] if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w) logits = F.linear(x_out, lm_w) - next_id = torch.argmax(logits, -1).item(); all_tokens.append(next_id) + # Sampling with repetition penalty + if _args.temperature > 0: + # Apply repetition penalty + if len(all_tokens) > 0: + for tid_pen in set(all_tokens[-64:]): + logits[0, tid_pen] /= _args.repetition_penalty + probs = torch.softmax(logits.float() / _args.temperature, -1) + next_id = torch.multinomial(probs, 1).item() + else: + next_id = torch.argmax(logits, -1).item() + all_tokens.append(next_id) dt = time.time() - t1 has_nan = torch.isnan(logits.float()).any().item() if step % 1 == 0 or has_nan: