Fix wrong truncate_prompt_tokens type hint (#22761)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Gabriel Marinho
2025-08-30 17:39:38 -03:00
committed by GitHub
parent 038e9be4eb
commit 5b8077b8ac
14 changed files with 101 additions and 102 deletions

View File

@@ -182,7 +182,8 @@ class SamplingParams(
optionally prompt tokens as a first argument."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=-1)]] = None
"""If set to -1, will use the truncation size supported by the model. If
set to a positive integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled."""
@@ -241,7 +242,8 @@ class SamplingParams(
spaces_between_special_tokens: bool = True,
logits_processors: Optional[list[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
msgspec.Meta(
ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
@@ -411,9 +413,11 @@ class SamplingParams(
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
and (self.truncate_prompt_tokens == 0
or self.truncate_prompt_tokens < -1)):
raise ValueError(
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
raise ValueError(f"stop_token_ids must contain only integers, "