[Sampler] Support returning all logprobs or logits (#21792)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -156,6 +156,7 @@ class SamplingParams(
|
||||
Note that the implementation follows the OpenAI API: The API will
|
||||
always return the log probability of the sampled token, so there
|
||||
may be up to `logprobs+1` elements in the response.
|
||||
When set to -1, return all `vocab_size` log probabilities.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
detokenize: Whether to detokenize the output. Defaults to True.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
@@ -414,9 +415,10 @@ class SamplingParams(
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
if (self.logprobs is not None and self.logprobs != -1
|
||||
and self.logprobs < 0):
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
f"logprobs must be non-negative or -1, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
|
||||
Reference in New Issue
Block a user