Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com> Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -182,7 +182,8 @@ class SamplingParams(
|
||||
optionally prompt tokens as a first argument."""
|
||||
include_stop_str_in_output: bool = False
|
||||
"""Whether to include the stop strings in output text."""
|
||||
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
|
||||
truncate_prompt_tokens: Optional[Annotated[int,
|
||||
msgspec.Meta(ge=-1)]] = None
|
||||
"""If set to -1, will use the truncation size supported by the model. If
|
||||
set to an integer k, will use only the last k tokens from the prompt
|
||||
(i.e., left truncation). If set to `None`, truncation is disabled."""
|
||||
@@ -241,7 +242,8 @@ class SamplingParams(
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[list[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int,
|
||||
msgspec.Meta(ge=1)]] = None,
|
||||
msgspec.Meta(
|
||||
ge=-1)]] = None,
|
||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
|
||||
@@ -411,9 +413,11 @@ class SamplingParams(
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
if (self.truncate_prompt_tokens is not None
|
||||
and self.truncate_prompt_tokens < 1):
|
||||
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
|
||||
f"got {self.truncate_prompt_tokens}")
|
||||
and (self.truncate_prompt_tokens == 0
|
||||
or self.truncate_prompt_tokens < -1)):
|
||||
raise ValueError(
|
||||
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
|
||||
f"got {self.truncate_prompt_tokens}")
|
||||
assert isinstance(self.stop_token_ids, list)
|
||||
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
|
||||
raise ValueError(f"stop_token_ids must contain only integers, "
|
||||
|
||||
Reference in New Issue
Block a user