Implement stop strings and best_of (#114)

This commit is contained in:
Woosuk Kwon
2023-05-21 11:18:00 -07:00
committed by GitHub
parent c3442c1f6f
commit f746ced08d
9 changed files with 162 additions and 116 deletions

View File

@@ -283,20 +283,20 @@ def _sample_from_prompt(
) -> List[int]:
if sampling_params.use_beam_search:
# Beam search.
beam_width = sampling_params.n
beam_width = sampling_params.best_of
_, next_token_ids = torch.topk(prob, beam_width)
next_token_ids = next_token_ids.tolist()
elif sampling_params.temperature == 0.0:
# Greedy sampling.
assert sampling_params.n == 1
assert sampling_params.best_of == 1
next_token_id = torch.argmax(prob)
next_token_ids = [next_token_id.item()]
else:
# Random sampling.
# Sample n tokens for the prompt.
n = sampling_params.n
# Sample `best_of` tokens for the prompt.
num_seqs = sampling_params.best_of
next_token_ids = torch.multinomial(
prob, num_samples=n, replacement=True)
prob, num_samples=num_seqs, replacement=True)
next_token_ids = next_token_ids.tolist()
return next_token_ids
@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
seq_logprobs: List[float],
sampling_params: SamplingParams,
) -> Tuple[List[int], List[int]]:
# NOTE(woosuk): sampling_params.n can be greater than
# NOTE(woosuk): sampling_params.best_of can be greater than
# len(seq_ids) because some sequences in the group might have
# been already terminated.
if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
seq_ids, sampling_params = seq_group
if i < input_metadata.num_prompts:
# Generate the next tokens for a prompt input.
assert len(seq_ids) == sampling_params.n
assert len(seq_ids) == sampling_params.best_of
prob = probs[idx]
logprob = logprobs[idx]
idx += 1
@@ -397,7 +397,7 @@ def _sample(
# Sample the next tokens.
seq_logprobs = [
input_metadata.seq_data[seq_id].cumulative_logprobs
input_metadata.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids]
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
seq_ids, prob, logprob, seq_logprobs, sampling_params)