[Sampler] Support returning all logprobs or logits (#21792)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn
2025-08-04 03:04:12 -07:00
committed by GitHub
parent fed5849d3f
commit 54de71d0df
6 changed files with 45 additions and 9 deletions

View File

@@ -156,6 +156,7 @@ class SamplingParams(
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
When set to -1, return all `vocab_size` log probabilities.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
@@ -414,9 +415,10 @@ class SamplingParams(
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
if (self.logprobs is not None and self.logprobs != -1
and self.logprobs < 0):
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
f"logprobs must be non-negative or -1, got {self.logprobs}.")
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")