Enhance SamplingParams (#96)

This commit is contained in:
Woosuk Kwon
2023-05-11 15:45:30 -07:00
committed by GitHub
parent 55f8b0a5de
commit 42f1042e1c
7 changed files with 36 additions and 54 deletions

View File

@@ -367,7 +367,7 @@ def _sample(
next_token_ids = _sample_from_prompt(prob, sampling_params)
# Get top-k log probabilities for the next tokens.
next_logprobs = _get_topk_logprobs(
logprob, sampling_params.num_logprobs)
logprob, sampling_params.logprobs)
# Build the output.
for seq_id, next_token_id in zip(seq_ids, next_token_ids):
@@ -392,7 +392,7 @@ def _sample(
next_logprobs: Dict[int, Dict[int, float]] = {}
for i, seq_id in enumerate(seq_ids):
next_logprobs[seq_id] = _get_topk_logprobs(
logprob[i], sampling_params.num_logprobs)
logprob[i], sampling_params.logprobs)
# Build the output.
for seq_id, parent_seq_id, next_token_id in zip(