[Feature] Add load generation config from model (#11164)
Signed-off-by: liuyanyi <wolfsonliu@163.com>
Signed-off-by: Yanyi Liu <wolfsonliu@163.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -258,6 +258,13 @@ class LLM:
         else:
             tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
 
+    def get_default_sampling_params(self) -> SamplingParams:
+        diff_sampling_param = (
+            self.llm_engine.model_config.get_diff_sampling_param())
+        if diff_sampling_param:
+            return SamplingParams.from_optional(**diff_sampling_param)
+        return SamplingParams()
+
     @overload
     def generate(
         self,
@@ -441,7 +448,7 @@ class LLM:
 
         if sampling_params is None:
             # Use default sampling params.
-            sampling_params = SamplingParams()
+            sampling_params = self.get_default_sampling_params()
 
         self._validate_and_add_requests(
             prompts=parsed_prompts,
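With this change, `LLM.generate()` called without explicit `SamplingParams` falls back to the new `get_default_sampling_params()`, which seeds its defaults from the model's `generation_config.json` when that differs from vLLM's built-in values. A minimal usage sketch (the model name is illustrative, and it assumes the engine is configured to load the model's generation config so that `get_diff_sampling_param()` is non-empty):

```python
from vllm import LLM

# Illustrative model; any checkpoint whose generation_config.json carries
# non-default sampling values (temperature, top_p, ...) behaves the same,
# assuming the engine is set up to load that generation config.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

# New helper added in this commit: SamplingParams seeded from the model's
# generation config instead of vLLM's hard-coded defaults.
params = llm.get_default_sampling_params()
print(params.temperature, params.top_p)

# generate() without sampling_params now uses those model-derived defaults.
outputs = llm.generate(["The capital of France is"])
print(outputs[0].outputs[0].text)
```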
@@ -211,8 +211,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
     stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     tools: Optional[List[ChatCompletionToolsParam]] = None
     tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                 ChatCompletionNamedToolChoiceParam]] = "none"
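The reason the field defaults move from concrete values (`= 1.0`) to `None`: with a concrete default the server cannot distinguish "client omitted the field" from "client explicitly sent 1.0", so a model-level override could never apply safely. A self-contained sketch of that distinction (plain pydantic, independent of vLLM; the model default value is hypothetical):

```python
from typing import Optional
from pydantic import BaseModel

class Req(BaseModel):
    temperature: Optional[float] = None  # None now means "not provided"

MODEL_DEFAULT = {"temperature": 0.6}  # hypothetical generation_config value

def effective_temperature(req: Req) -> float:
    if req.temperature is not None:      # explicitly set by the client
        return req.temperature
    # omitted: fall back to the model default, then to vLLM's 1.0
    return MODEL_DEFAULT.get("temperature", 1.0)

assert effective_temperature(Req()) == 0.6                 # omitted field
assert effective_temperature(Req(temperature=1.0)) == 1.0  # explicit 1.0 wins
```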
@@ -224,9 +224,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -348,15 +348,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for chat completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return BeamSearchParams(
             beam_width=n,
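The resolution order introduced here is: explicit request value, then the model-supplied default (`default_sampling_params`, derived from the model's generation config), then the class-level `_DEFAULT_SAMPLING_PARAMS` constant. A standalone sketch of that chain, runnable without vLLM:

```python
from typing import Optional

# Mirrors the class-level constant above (subset shown).
_DEFAULT_SAMPLING_PARAMS = {"temperature": 1.0, "top_p": 1.0}

def resolve(name: str,
            request_value: Optional[float],
            default_sampling_params: Optional[dict] = None) -> float:
    """Request value wins; else the model's generation-config value;
    else the hard-coded fallback."""
    if request_value is not None:
        return request_value
    return (default_sampling_params or {}).get(
        name, _DEFAULT_SAMPLING_PARAMS[name])

assert resolve("temperature", 0.2, {"temperature": 0.6}) == 0.2  # request wins
assert resolve("temperature", None, {"temperature": 0.6}) == 0.6  # model default
assert resolve("temperature", None, None) == 1.0                  # built-in
```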
@@ -367,13 +384,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(
-            self, default_max_tokens: int,
-            logits_processor_pattern: Optional[str]) -> SamplingParams:
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
         # TODO(#9845): remove max_tokens when field is removed from OpenAI API
         max_tokens = self.max_completion_tokens or self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
+            )
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
         prompt_logprobs = self.prompt_logprobs
         if prompt_logprobs is None and self.echo:
             prompt_logprobs = self.top_logprobs
@@ -403,11 +443,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
@@ -584,15 +624,15 @@ class CompletionRequest(OpenAIBaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
     user: Optional[str] = None
 
     # doc: begin-completion-sampling-params
     use_beam_search: bool = False
-    top_k: int = -1
-    min_p: float = 0.0
-    repetition_penalty: float = 1.0
+    top_k: Optional[int] = None
+    min_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
     include_stop_str_in_output: bool = False
@@ -669,14 +709,30 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
-    def to_beam_search_params(self,
-                              default_max_tokens: int) -> BeamSearchParams:
+    # Default sampling parameters for completion requests
+    _DEFAULT_SAMPLING_PARAMS: dict = {
+        "repetition_penalty": 1.0,
+        "temperature": 1.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "min_p": 0.0,
+    }
+
+    def to_beam_search_params(
+            self,
+            default_max_tokens: int,
+            default_sampling_params: Optional[dict] = None
+    ) -> BeamSearchParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
         n = self.n if self.n is not None else 1
-        temperature = self.temperature if self.temperature is not None else 0.0
+
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get("temperature", 1.0)
 
         return BeamSearchParams(
             beam_width=n,
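Note that this completion-side variant inlines the literal `1.0` fallback (`default_sampling_params.get("temperature", 1.0)`) where the chat-side variant reads `self._DEFAULT_SAMPLING_PARAMS["temperature"]`; the resulting value is the same. The walrus-operator pattern used throughout the diff is equivalent to this standalone sketch:

```python
from typing import Optional

def resolve_temperature(request_temperature: Optional[float],
                        default_sampling_params: Optional[dict]) -> float:
    if default_sampling_params is None:
        default_sampling_params = {}
    # Same shape as the diff: bind the request value, and only if it is
    # None fall back to the model default, then to the constant 1.0.
    if (temperature := request_temperature) is None:
        temperature = default_sampling_params.get("temperature", 1.0)
    return temperature

assert resolve_temperature(0.2, {"temperature": 0.6}) == 0.2
assert resolve_temperature(None, {"temperature": 0.6}) == 0.6
assert resolve_temperature(None, None) == 1.0
```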
@@ -687,12 +743,35 @@ class CompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(
-            self, default_max_tokens: int,
-            logits_processor_pattern: Optional[str]) -> SamplingParams:
+            self,
+            default_max_tokens: int,
+            logits_processor_pattern: Optional[str],
+            default_sampling_params: Optional[dict] = None) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
 
+        if default_sampling_params is None:
+            default_sampling_params = {}
+        # Default parameters
+        if (repetition_penalty := self.repetition_penalty) is None:
+            repetition_penalty = default_sampling_params.get(
+                "repetition_penalty",
+                self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"],
+            )
+        if (temperature := self.temperature) is None:
+            temperature = default_sampling_params.get(
+                "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
+        if (top_p := self.top_p) is None:
+            top_p = default_sampling_params.get(
+                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
+        if (top_k := self.top_k) is None:
+            top_k = default_sampling_params.get(
+                "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"])
+        if (min_p := self.min_p) is None:
+            min_p = default_sampling_params.get(
+                "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"])
+
         prompt_logprobs = self.prompt_logprobs
         if prompt_logprobs is None and self.echo:
             prompt_logprobs = self.logprobs
@@ -718,11 +797,11 @@ class CompletionRequest(OpenAIBaseModel):
             best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
-            repetition_penalty=self.repetition_penalty,
-            temperature=self.temperature,
-            top_p=self.top_p,
-            top_k=self.top_k,
-            min_p=self.min_p,
+            repetition_penalty=repetition_penalty,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            min_p=min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
@@ -91,6 +91,10 @@ class OpenAIServingChat(OpenAIServing):
                     "been registered") from e
 
         self.enable_prompt_tokens_details = enable_prompt_tokens_details
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info("Overwriting default chat sampling param with: %s",
+                        diff_sampling_param)
 
     async def create_chat_completion(
         self,
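For context, `get_diff_sampling_param()` (defined on `ModelConfig`, not shown in this diff) returns the sampling-related entries of the model's `generation_config.json` that differ from vLLM's defaults. A hypothetical return value and the startup log it would produce (the concrete numbers are illustrative, not taken from a real model):

```python
# Hypothetical example; real contents depend on the model's
# generation_config.json.
diff_sampling_param = {
    "temperature": 0.7,
    "top_p": 0.8,
    "repetition_penalty": 1.05,
}
# At server startup this would log, e.g.:
#   Overwriting default chat sampling param with:
#   {'temperature': 0.7, 'top_p': 0.8, 'repetition_penalty': 1.05}
```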
@@ -191,13 +195,17 @@ class OpenAIServingChat(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+                # Build default sampling params
+                default_sampling_params = (
+                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens)
+                        default_max_tokens, default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
-                        self.model_config.logits_processor_pattern)
+                        self.model_config.logits_processor_pattern,
+                        default_sampling_params)
 
                 self._log_inputs(request_id,
                                  request_prompts[i],
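The end-to-end effect on the OpenAI-compatible server, as a sketch (assumes a running server whose model carries non-default generation-config values; the URL and served model name are placeholders): omitted sampling fields now inherit the model's defaults, while explicitly set fields still take precedence.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# temperature/top_p omitted -> they arrive as None and the server fills
# them from the model's generation config.
resp = client.chat.completions.create(
    model="my-model",  # placeholder served model name
    messages=[{"role": "user", "content": "Hi"}],
)

# Explicit values always win over the model defaults.
resp = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hi"}],
    temperature=0.0,
)
```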
@@ -55,6 +55,11 @@ class OpenAIServingCompletion(OpenAIServing):
                          prompt_adapters=prompt_adapters,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
+        diff_sampling_param = self.model_config.get_diff_sampling_param()
+        if diff_sampling_param:
+            logger.info(
+                "Overwriting default completion sampling param with: %s",
+                diff_sampling_param)
 
     async def create_completion(
         self,
@@ -118,13 +123,17 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
+                # Build default sampling params
+                default_sampling_params = (
+                    self.model_config.get_diff_sampling_param())
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
-                        default_max_tokens)
+                        default_max_tokens, default_sampling_params)
                 else:
                     sampling_params = request.to_sampling_params(
                         default_max_tokens,
-                        self.model_config.logits_processor_pattern)
+                        self.model_config.logits_processor_pattern,
+                        default_sampling_params)
 
                 request_id_item = f"{request_id}-{i}"