[mypy] Enable following imports for entrypoints (#7248)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Fei <dfdfcai4@gmail.com>
This commit is contained in:
@@ -153,6 +153,68 @@ class SamplingParams(
|
||||
output_text_buffer_length: int = 0
|
||||
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
|
||||
|
||||
@staticmethod
|
||||
def from_optional(
|
||||
n: Optional[int] = 1,
|
||||
best_of: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = 0.0,
|
||||
frequency_penalty: Optional[float] = 0.0,
|
||||
repetition_penalty: Optional[float] = 1.0,
|
||||
temperature: Optional[float] = 1.0,
|
||||
top_p: Optional[float] = 1.0,
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
seed: Optional[int] = None,
|
||||
use_beam_search: bool = False,
|
||||
length_penalty: float = 1.0,
|
||||
early_stopping: Union[bool, str] = False,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_stop_str_in_output: bool = False,
|
||||
ignore_eos: bool = False,
|
||||
max_tokens: Optional[int] = 16,
|
||||
min_tokens: int = 0,
|
||||
logprobs: Optional[int] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
detokenize: bool = True,
|
||||
skip_special_tokens: bool = True,
|
||||
spaces_between_special_tokens: bool = True,
|
||||
logits_processors: Optional[List[LogitsProcessor]] = None,
|
||||
truncate_prompt_tokens: Optional[Annotated[int,
|
||||
msgspec.Meta(ge=1)]] = None,
|
||||
) -> "SamplingParams":
|
||||
return SamplingParams(
|
||||
n=1 if n is None else n,
|
||||
best_of=best_of,
|
||||
presence_penalty=0.0
|
||||
if presence_penalty is None else presence_penalty,
|
||||
frequency_penalty=0.0
|
||||
if frequency_penalty is None else frequency_penalty,
|
||||
repetition_penalty=1.0
|
||||
if repetition_penalty is None else repetition_penalty,
|
||||
temperature=1.0 if temperature is None else temperature,
|
||||
top_p=1.0 if top_p is None else top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
seed=seed,
|
||||
use_beam_search=use_beam_search,
|
||||
length_penalty=length_penalty,
|
||||
early_stopping=early_stopping,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
ignore_eos=ignore_eos,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
detokenize=detokenize,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
spaces_between_special_tokens=spaces_between_special_tokens,
|
||||
logits_processors=logits_processors,
|
||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.best_of = self.best_of or self.n
|
||||
if 0 < self.temperature < _MAX_TEMP:
|
||||
|
||||
Reference in New Issue
Block a user