Support per-request seed (#2514)

This commit is contained in:
Nick Hill
2024-02-21 11:47:00 -08:00
committed by GitHub
parent dc903e70ac
commit 7d2dcce175
10 changed files with 296 additions and 91 deletions

View File

@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
-BEAM = 2
+RANDOM_SEED = 2
+BEAM = 3
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
@@ -56,6 +57,7 @@ class SamplingParams:
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
+seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
@@ -101,6 +103,7 @@ class SamplingParams:
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
+seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
@@ -124,6 +127,7 @@ class SamplingParams:
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
+self.seed = seed
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
@@ -229,6 +233,8 @@ class SamplingParams:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
+if self.seed is not None:
+return SamplingType.RANDOM_SEED
return SamplingType.RANDOM
def __repr__(self) -> str:
@@ -242,6 +248,7 @@ class SamplingParams:
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
+f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "