Implement prompt logprobs & Batched topk for computing logprobs (#1328)
Co-authored-by: Yunmo Chen <16273544+wanmok@users.noreply.github.com>
This commit is contained in:
@@ -60,6 +60,12 @@ class SamplingParams:
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
Note that the implementation follows the OpenAI API: The return
|
||||
result includes the log probabilities on the `logprobs` most likely
|
||||
tokens, as well the chosen tokens. The API will always return the
|
||||
log probability of the sampled token, so there may be up to
|
||||
`logprobs+1` elements in the response.
|
||||
prompt_logprobs: Number of log probabilities to return per prompt token.
|
||||
skip_special_tokens: Whether to skip special tokens in the output.
|
||||
"""
|
||||
|
||||
@@ -80,6 +86,7 @@ class SamplingParams:
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: int = 16,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
) -> None:
|
||||
self.n = n
|
||||
@@ -105,6 +112,7 @@ class SamplingParams:
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.logprobs = logprobs
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
self._verify_args()
|
||||
@@ -142,6 +150,9 @@ class SamplingParams:
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
|
||||
def _verify_beam_search(self) -> None:
|
||||
if self.best_of == 1:
|
||||
@@ -200,4 +211,5 @@ class SamplingParams:
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens})")
|
||||
|
||||
Reference in New Issue
Block a user