Support Min P Sampler (#1642)

This commit is contained in:
Roy
2023-11-18 08:20:49 +08:00
committed by GitHub
parent dcc543a298
commit e87557b069
2 changed files with 40 additions and 4 deletions

View File

@@ -52,6 +52,9 @@ class SamplingParams:
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
top_k: Integer that controls the number of top tokens to consider. Set
to -1 to consider all tokens.
min_p: Float that represents the minimum probability for a token to be
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
@@ -94,6 +97,7 @@ class SamplingParams:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: int = 0.0,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
@@ -115,6 +119,7 @@ class SamplingParams:
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.min_p = min_p
self.use_beam_search = use_beam_search
self.length_penalty = length_penalty
self.early_stopping = early_stopping
@@ -167,6 +172,9 @@ class SamplingParams:
if self.top_k < -1 or self.top_k == 0:
raise ValueError(f"top_k must be -1 (disable), or at least 1, "
f"got {self.top_k}.")
if not 0.0 <= self.min_p <= 1.0:
raise ValueError("min_p must be in [0, 1], got "
f"{self.min_p}.")
if self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
@@ -228,6 +236,7 @@ class SamplingParams:
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "