[Frontend] re-enable multi-modality input in the new beam search implementation (#9427)

Signed-off-by: Qishuai <Ferdinandzhong@gmail.com>
Author: Zhong Qishuai
Date: 2024-10-29 19:49:47 +08:00
Committed by: GitHub
Parent: eae3d48181
Commit: ef7865b4f9
7 changed files with 150 additions and 40 deletions
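Before this change, both OpenAI-compatible frontends called engine_client.beam_search() positionally with only the prompt's token IDs, so any multi_modal_data attached to the request was dropped before beam search ran. The hunks below forward the full prompt dict (plus the model config) as keyword arguments instead. A minimal sketch of the resulting call shape, assuming placeholder engine_client, model_config, and image objects:

from vllm.sampling_params import BeamSearchParams

# The full prompt dict is forwarded, so the multi_modal_data key now
# reaches beam search instead of being stripped down to token IDs.
prompt = {
    "prompt_token_ids": [1, 2, 3],         # tokenized text (illustrative)
    "multi_modal_data": {"image": image},  # placeholder image object
}
generator = engine_client.beam_search(
    prompt=prompt,
    model_config=model_config,  # newly part of the beam_search signature
    request_id="cmpl-example-0",
    params=BeamSearchParams(beam_width=4, max_tokens=32),
)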

vllm/entrypoints/openai/protocol.py

@@ -308,7 +308,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
@@ -606,7 +606,7 @@ class CompletionRequest(OpenAIBaseModel):
             ignore_eos=self.ignore_eos,
             temperature=temperature,
             length_penalty=self.length_penalty,
-        )
+            include_stop_str_in_output=self.include_stop_str_in_output)
 
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
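Both protocol.py hunks also start forwarding the request's include_stop_str_in_output flag into the beam search parameters, which previously fell off at the closing parenthesis. A hedged sketch of the parameter object this builds (beam_width and max_tokens values are illustrative):

from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(
    beam_width=4,
    max_tokens=64,
    ignore_eos=False,
    temperature=0.0,
    length_penalty=1.0,
    include_stop_str_in_output=True,  # now carried over from the request
)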

vllm/entrypoints/openai/serving_chat.py

@@ -236,9 +236,10 @@ class OpenAIServingChat(OpenAIServing):
             if isinstance(sampling_params, BeamSearchParams):
                 result_generator = self.engine_client.beam_search(
-                    engine_inputs['prompt_token_ids'],
-                    request_id,
-                    sampling_params,
+                    prompt=engine_inputs,
+                    model_config=self.model_config,
+                    request_id=request_id,
+                    params=sampling_params,
                 )
             else:
                 result_generator = self.engine_client.generate(
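In the chat path, engine_inputs is the full prompt dict built during request preprocessing, so passing it whole rather than engine_inputs['prompt_token_ids'] is what lets image and other multi-modal data reach beam search again; the switch to keyword arguments also keeps the call unambiguous now that model_config joins the signature.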

vllm/entrypoints/openai/serving_completion.py

@@ -150,9 +150,13 @@ class OpenAIServingCompletion(OpenAIServing):
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt_inputs["prompt_token_ids"],
-                        request_id_item,
-                        sampling_params,
+                        prompt={
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        model_config=self.model_config,
+                        request_id=request_id,
+                        params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(
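The completion path carries no multi-modal data of its own, so it wraps the token IDs in a one-key prompt dict purely to satisfy the new prompt-based signature; note the call also moves from the per-item request_id_item to the batch-level request_id.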