Implement stop strings and best_of (#114)
This commit is contained in:
@@ -283,20 +283,20 @@ def _sample_from_prompt(
|
||||
) -> List[int]:
|
||||
if sampling_params.use_beam_search:
|
||||
# Beam search.
|
||||
beam_width = sampling_params.n
|
||||
beam_width = sampling_params.best_of
|
||||
_, next_token_ids = torch.topk(prob, beam_width)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
elif sampling_params.temperature == 0.0:
|
||||
# Greedy sampling.
|
||||
assert sampling_params.n == 1
|
||||
assert sampling_params.best_of == 1
|
||||
next_token_id = torch.argmax(prob)
|
||||
next_token_ids = [next_token_id.item()]
|
||||
else:
|
||||
# Random sampling.
|
||||
# Sample n tokens for the prompt.
|
||||
n = sampling_params.n
|
||||
# Sample `best_of` tokens for the prompt.
|
||||
num_seqs = sampling_params.best_of
|
||||
next_token_ids = torch.multinomial(
|
||||
prob, num_samples=n, replacement=True)
|
||||
prob, num_samples=num_seqs, replacement=True)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return next_token_ids
|
||||
|
||||
@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
|
||||
seq_logprobs: List[float],
|
||||
sampling_params: SamplingParams,
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
# NOTE(woosuk): sampling_params.n can be greater than
|
||||
# NOTE(woosuk): sampling_params.best_of can be greater than
|
||||
# len(seq_ids) because some sequences in the group might have
|
||||
# been already terminated.
|
||||
if sampling_params.use_beam_search:
|
||||
@@ -372,7 +372,7 @@ def _sample(
|
||||
seq_ids, sampling_params = seq_group
|
||||
if i < input_metadata.num_prompts:
|
||||
# Generate the next tokens for a prompt input.
|
||||
assert len(seq_ids) == sampling_params.n
|
||||
assert len(seq_ids) == sampling_params.best_of
|
||||
prob = probs[idx]
|
||||
logprob = logprobs[idx]
|
||||
idx += 1
|
||||
@@ -397,7 +397,7 @@ def _sample(
|
||||
|
||||
# Sample the next tokens.
|
||||
seq_logprobs = [
|
||||
input_metadata.seq_data[seq_id].cumulative_logprobs
|
||||
input_metadata.seq_data[seq_id].cumulative_logprob
|
||||
for seq_id in seq_ids]
|
||||
parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
|
||||
seq_ids, prob, logprob, seq_logprobs, sampling_params)
|
||||
|
||||
Reference in New Issue
Block a user