Fix performance when --generation-config is not None (#14223)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -244,6 +244,7 @@ class LLM:
|
|||||||
engine_args, usage_context=UsageContext.LLM_CLASS)
|
engine_args, usage_context=UsageContext.LLM_CLASS)
|
||||||
|
|
||||||
self.request_counter = Counter()
|
self.request_counter = Counter()
|
||||||
|
self.default_sampling_params: Union[dict[str, Any], None] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_engine_class() -> type[LLMEngine]:
|
def get_engine_class() -> type[LLMEngine]:
|
||||||
@@ -268,10 +269,11 @@ class LLM:
|
|||||||
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
|
tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
|
||||||
|
|
||||||
def get_default_sampling_params(self) -> SamplingParams:
|
def get_default_sampling_params(self) -> SamplingParams:
|
||||||
diff_sampling_param = (
|
if self.default_sampling_params is None:
|
||||||
self.llm_engine.model_config.get_diff_sampling_param())
|
self.default_sampling_params = (
|
||||||
if diff_sampling_param:
|
self.llm_engine.model_config.get_diff_sampling_param())
|
||||||
return SamplingParams.from_optional(**diff_sampling_param)
|
if self.default_sampling_params:
|
||||||
|
return SamplingParams.from_optional(**self.default_sampling_params)
|
||||||
return SamplingParams()
|
return SamplingParams()
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
|||||||
@@ -105,10 +105,11 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
"been registered") from e
|
"been registered") from e
|
||||||
|
|
||||||
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
self.enable_prompt_tokens_details = enable_prompt_tokens_details
|
||||||
diff_sampling_param = self.model_config.get_diff_sampling_param()
|
self.default_sampling_params = (
|
||||||
if diff_sampling_param:
|
self.model_config.get_diff_sampling_param())
|
||||||
|
if self.default_sampling_params:
|
||||||
logger.info("Overwriting default chat sampling param with: %s",
|
logger.info("Overwriting default chat sampling param with: %s",
|
||||||
diff_sampling_param)
|
self.default_sampling_params)
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self,
|
self,
|
||||||
@@ -210,17 +211,14 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
default_max_tokens = self.max_model_len - len(
|
default_max_tokens = self.max_model_len - len(
|
||||||
engine_prompt["prompt_token_ids"])
|
engine_prompt["prompt_token_ids"])
|
||||||
# Build default sampling params
|
|
||||||
default_sampling_params = (
|
|
||||||
self.model_config.get_diff_sampling_param())
|
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
default_max_tokens, default_sampling_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
else:
|
else:
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens,
|
default_max_tokens,
|
||||||
self.model_config.logits_processor_pattern,
|
self.model_config.logits_processor_pattern,
|
||||||
default_sampling_params)
|
self.default_sampling_params)
|
||||||
|
|
||||||
self._log_inputs(request_id,
|
self._log_inputs(request_id,
|
||||||
request_prompts[i],
|
request_prompts[i],
|
||||||
|
|||||||
@@ -51,11 +51,12 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
models=models,
|
models=models,
|
||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||||
diff_sampling_param = self.model_config.get_diff_sampling_param()
|
self.default_sampling_params = (
|
||||||
if diff_sampling_param:
|
self.model_config.get_diff_sampling_param())
|
||||||
|
if self.default_sampling_params:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Overwriting default completion sampling param with: %s",
|
"Overwriting default completion sampling param with: %s",
|
||||||
diff_sampling_param)
|
self.default_sampling_params)
|
||||||
|
|
||||||
async def create_completion(
|
async def create_completion(
|
||||||
self,
|
self,
|
||||||
@@ -119,17 +120,14 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||||
default_max_tokens = self.max_model_len - len(
|
default_max_tokens = self.max_model_len - len(
|
||||||
engine_prompt["prompt_token_ids"])
|
engine_prompt["prompt_token_ids"])
|
||||||
# Build default sampling params
|
|
||||||
default_sampling_params = (
|
|
||||||
self.model_config.get_diff_sampling_param())
|
|
||||||
if request.use_beam_search:
|
if request.use_beam_search:
|
||||||
sampling_params = request.to_beam_search_params(
|
sampling_params = request.to_beam_search_params(
|
||||||
default_max_tokens, default_sampling_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
else:
|
else:
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens,
|
default_max_tokens,
|
||||||
self.model_config.logits_processor_pattern,
|
self.model_config.logits_processor_pattern,
|
||||||
default_sampling_params)
|
self.default_sampling_params)
|
||||||
|
|
||||||
request_id_item = f"{request_id}-{i}"
|
request_id_item = f"{request_id}-{i}"
|
||||||
|
|
||||||
|
|||||||
@@ -161,11 +161,12 @@ class OpenAIServingTranscription(OpenAIServing):
|
|||||||
request_logger=request_logger,
|
request_logger=request_logger,
|
||||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||||
|
|
||||||
diff_sampling_param = self.model_config.get_diff_sampling_param()
|
self.default_sampling_params = (
|
||||||
if diff_sampling_param:
|
self.model_config.get_diff_sampling_param())
|
||||||
|
if self.default_sampling_params:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Overwriting default completion sampling param with: %s",
|
"Overwriting default completion sampling param with: %s",
|
||||||
diff_sampling_param)
|
self.default_sampling_params)
|
||||||
|
|
||||||
async def _preprocess_transcription(
|
async def _preprocess_transcription(
|
||||||
self,
|
self,
|
||||||
@@ -273,9 +274,8 @@ class OpenAIServingTranscription(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
# TODO(rob): subtract len of tokenized prompt.
|
# TODO(rob): subtract len of tokenized prompt.
|
||||||
default_max_tokens = self.model_config.max_model_len
|
default_max_tokens = self.model_config.max_model_len
|
||||||
default_params = self.model_config.get_diff_sampling_param()
|
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
default_max_tokens, default_params)
|
default_max_tokens, self.default_sampling_params)
|
||||||
|
|
||||||
self._log_inputs(
|
self._log_inputs(
|
||||||
request_id,
|
request_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user