feat: implement the min_tokens sampling parameter (#3124)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
@@ -79,6 +79,8 @@ class SamplingParams:
|
||||
ignore_eos: Whether to ignore the EOS token and continue generating
|
||||
tokens after the EOS token is generated.
|
||||
max_tokens: Maximum number of tokens to generate per output sequence.
|
||||
min_tokens: Minimum number of tokens to generate per output sequence
|
||||
before EOS or stop_token_ids can be generated.
|
||||
logprobs: Number of log probabilities to return per output token.
|
||||
Note that the implementation follows the OpenAI API: The return
|
||||
result includes the log probabilities on the `logprobs` most likely
|
||||
@@ -113,6 +115,7 @@ class SamplingParams:
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: Optional[int] = 16,
|
||||
min_tokens: int = 0,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
skip_special_tokens: bool = True,
|
||||
@@ -144,6 +147,7 @@ class SamplingParams:
|
||||
self.stop_token_ids = list(stop_token_ids)
|
||||
self.ignore_eos = ignore_eos
|
||||
self.max_tokens = max_tokens
|
||||
self.min_tokens = min_tokens
|
||||
self.logprobs = logprobs
|
||||
self.prompt_logprobs = prompt_logprobs
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
@@ -161,6 +165,8 @@ class SamplingParams:
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
# injected by the engine
|
||||
self.eos_token_id = None
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.n < 1:
|
||||
@@ -191,6 +197,13 @@ class SamplingParams:
|
||||
if self.max_tokens is not None and self.max_tokens < 1:
|
||||
raise ValueError(
|
||||
f"max_tokens must be at least 1, got {self.max_tokens}.")
|
||||
if self.min_tokens < 0:
|
||||
raise ValueError(f"min_tokens must be greater than or equal to 0, "
|
||||
f"got {self.min_tokens}.")
|
||||
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
|
||||
raise ValueError(
|
||||
f"min_tokens must be less than or equal to "
|
||||
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
|
||||
if self.logprobs is not None and self.logprobs < 0:
|
||||
raise ValueError(
|
||||
f"logprobs must be non-negative, got {self.logprobs}.")
|
||||
@@ -272,6 +285,7 @@ class SamplingParams:
|
||||
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
|
||||
f"ignore_eos={self.ignore_eos}, "
|
||||
f"max_tokens={self.max_tokens}, "
|
||||
f"min_tokens={self.min_tokens}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
|
||||
Reference in New Issue
Block a user