Add temperature sampling + repetition penalty to fix degenerate repetition

With --temperature 0.7 --repetition-penalty 1.2, the model should generate
more diverse text instead of repeating 'France' endlessly.
This commit is contained in:
2026-06-01 14:54:49 +00:00
parent 6cca16f97a
commit 1d64b863ca

View File

@@ -19,6 +19,8 @@ log = logging.getLogger("single_shot")
def parse_args():
p = argparse.ArgumentParser()
p.add_argument('--max-tokens', type=int, default=128)
p.add_argument('--temperature', type=float, default=0.0, help='Sampling temperature (0=greedy)')
p.add_argument('--repetition-penalty', type=float, default=1.2, help='Repetition penalty factor')
p.add_argument('--prompt', type=str, default=None)
p.add_argument('--seed', type=int, default=42)
p.add_argument('--verbose', type=int, default=1)
@@ -898,7 +900,17 @@ def main():
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
logits = F.linear(x_out, lm_w)
next_id = torch.argmax(logits, -1).item(); all_tokens.append(next_id)
# Sampling with repetition penalty
if _args.temperature > 0:
# Apply repetition penalty
if len(all_tokens) > 0:
for tid_pen in set(all_tokens[-64:]):
logits[0, tid_pen] /= _args.repetition_penalty
probs = torch.softmax(logits.float() / _args.temperature, -1)
next_id = torch.multinomial(probs, 1).item()
else:
next_id = torch.argmax(logits, -1).item()
all_tokens.append(next_id)
dt = time.time() - t1
has_nan = torch.isnan(logits.float()).any().item()
if step % 1 == 0 or has_nan: