Add temperature sampling + repetition penalty to fix degenerate repetition
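With --temperature 0.7 --repetition-penalty 1.2, the model should generate more diverse text instead of repeating 'France' endlessly. Note that the repetition penalty is applied only when --temperature > 0, and only to token ids among the most recent 64 entries of the token history; with the default --temperature 0.0, decoding remains plain greedy argmax and --repetition-penalty has no effect.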
This commit is contained in:
@@ -19,6 +19,8 @@ log = logging.getLogger("single_shot")
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('--max-tokens', type=int, default=128)
|
||||
p.add_argument('--temperature', type=float, default=0.0, help='Sampling temperature (0=greedy)')
|
||||
p.add_argument('--repetition-penalty', type=float, default=1.2, help='Repetition penalty factor')
|
||||
p.add_argument('--prompt', type=str, default=None)
|
||||
p.add_argument('--seed', type=int, default=42)
|
||||
p.add_argument('--verbose', type=int, default=1)
|
||||
@@ -898,7 +900,17 @@ def main():
|
||||
x_out = hc_head.forward(X) if hc_head is not None else X[:, 0, :]
|
||||
if final_norm_w is not None: x_out = rmsnorm(x_out, final_norm_w)
|
||||
logits = F.linear(x_out, lm_w)
|
||||
next_id = torch.argmax(logits, -1).item(); all_tokens.append(next_id)
|
||||
# Sampling with repetition penalty
|
||||
if _args.temperature > 0:
|
||||
# Apply repetition penalty
|
||||
if len(all_tokens) > 0:
|
||||
for tid_pen in set(all_tokens[-64:]):
|
||||
logits[0, tid_pen] /= _args.repetition_penalty
|
||||
probs = torch.softmax(logits.float() / _args.temperature, -1)
|
||||
next_id = torch.multinomial(probs, 1).item()
|
||||
else:
|
||||
next_id = torch.argmax(logits, -1).item()
|
||||
all_tokens.append(next_id)
|
||||
dt = time.time() - t1
|
||||
has_nan = torch.isnan(logits.float()).any().item()
|
||||
if step % 1 == 0 or has_nan:
|
||||
|
||||
Reference in New Issue
Block a user