Support per-request seed (#2514)
This commit is contained in:
@@ -11,7 +11,8 @@ _SAMPLING_EPS = 1e-5
|
||||
class SamplingType(IntEnum):
|
||||
GREEDY = 0
|
||||
RANDOM = 1
|
||||
- BEAM = 2
|
||||
+ RANDOM_SEED = 2
|
||||
+ BEAM = 3
|
||||
|
||||
|
||||
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
|
||||
@@ -56,6 +57,7 @@ class SamplingParams:
|
||||
min_p: Float that represents the minimum probability for a token to be
|
||||
considered, relative to the probability of the most likely token.
|
||||
Must be in [0, 1]. Set to 0 to disable this.
|
||||
+ seed: Random seed to use for the generation.
|
||||
use_beam_search: Whether to use beam search instead of sampling.
|
||||
length_penalty: Float that penalizes sequences based on their length.
|
||||
Used in beam search.
|
||||
@@ -101,6 +103,7 @@ class SamplingParams:
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
+ seed: Optional[int] = None,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
@@ -124,6 +127,7 @@ class SamplingParams:
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.min_p = min_p
|
||||
+ self.seed = seed
|
||||
self.use_beam_search = use_beam_search
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
@@ -229,6 +233,8 @@ class SamplingParams:
|
||||
return SamplingType.BEAM
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
return SamplingType.GREEDY
|
||||
+ if self.seed is not None:
|
||||
+ return SamplingType.RANDOM_SEED
|
||||
return SamplingType.RANDOM
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -242,6 +248,7 @@ class SamplingParams:
|
||||
f"top_p={self.top_p}, "
|
||||
f"top_k={self.top_k}, "
|
||||
f"min_p={self.min_p}, "
|
||||
+ f"seed={self.seed}, "
|
||||
f"use_beam_search={self.use_beam_search}, "
|
||||
f"length_penalty={self.length_penalty}, "
|
||||
f"early_stopping={self.early_stopping}, "
|
||||
|
||||
Reference in New Issue
Block a user