Add option to completion API to truncate prompt tokens (#3144)

This commit is contained in:
Thomas Parnell
2024-04-05 19:15:42 +02:00
committed by GitHub
parent cfaf49a167
commit 1d7c940d74
4 changed files with 41 additions and 8 deletions

View File

@@ -5,6 +5,7 @@ from functools import cached_property
from typing import Callable, List, Optional, Union
import torch
from pydantic import conint
_SAMPLING_EPS = 1e-5
@@ -94,6 +95,9 @@ class SamplingParams:
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
"""
def __init__(
@@ -123,6 +127,7 @@ class SamplingParams:
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -160,6 +165,7 @@ class SamplingParams:
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -216,6 +222,10 @@ class SamplingParams:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
if self.stop and not self.detokenize:
raise ValueError(
"stop strings are only supported when detokenize is True. "
@@ -300,4 +310,5 @@ class SamplingParams:
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})")
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")