Add option to completion API to truncate prompt tokens (#3144)
This commit is contained in:
@@ -5,6 +5,7 @@ from functools import cached_property
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import conint
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@@ -94,6 +95,9 @@ class SamplingParams:
|
||||
tokens in the output. Defaults to True.
|
||||
logits_processors: List of functions that modify logits based on
|
||||
previously generated tokens.
|
||||
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
||||
tokens from the prompt (i.e., left truncation). Defaults to None
|
||||
(i.e., no truncation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -123,6 +127,7 @@ class SamplingParams:
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
|
||||
) -> None:
|
||||
self.n = n
|
||||
self.best_of = best_of if best_of is not None else n
|
||||
@@ -160,6 +165,7 @@ class SamplingParams:
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.logits_processors = logits_processors
|
||||
self.include_stop_str_in_output = include_stop_str_in_output
|
||||
self.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
self._verify_args()
|
||||
if self.use_beam_search:
|
||||
self._verify_beam_search()
|
||||
@@ -216,6 +222,10 @@ class SamplingParams:
|
||||
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
|
||||
raise ValueError(f"prompt_logprobs must be non-negative, got "
|
||||
f"{self.prompt_logprobs}.")
|
||||
if (self.truncate_prompt_tokens is not None
|
||||
and self.truncate_prompt_tokens < 1):
|
||||
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
|
||||
f"got {self.truncate_prompt_tokens}")
|
||||
if self.stop and not self.detokenize:
|
||||
raise ValueError(
|
||||
"stop strings are only supported when detokenize is True. "
|
||||
@@ -300,4 +310,5 @@ class SamplingParams:
|
||||
f"prompt_logprobs={self.prompt_logprobs}, "
|
||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||
"spaces_between_special_tokens="
|
||||
f"{self.spaces_between_special_tokens})")
|
||||
f"{self.spaces_between_special_tokens}, "
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
|
||||
|
||||
Reference in New Issue
Block a user