[Frontend] API support for beam search (#9087)

Co-authored-by: youkaichao <youkaichao@126.com>
Author: Brendan Wong
Date:   2024-10-05 23:39:03 -07:00
Committed-by: GitHub
Commit: 168cab6bbf
Parent: 23fea8714a
12 changed files with 275 additions and 68 deletions


@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
 from fastapi import Request

 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
                 ))

             for i, prompt_inputs in enumerate(prompts):
-                sampling_params = request.to_sampling_params(
-                    default_max_tokens=self.max_model_len -
-                    len(prompt_inputs["prompt_token_ids"]))
+                sampling_params: Union[SamplingParams, BeamSearchParams]
+                default_max_tokens = self.max_model_len - len(
+                    prompt_inputs["prompt_token_ids"])
+                if request.use_beam_search:
+                    sampling_params = request.to_beam_search_params(
+                        default_max_tokens)
+                else:
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens)

                 request_id_item = f"{request_id}-{i}"
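The dispatch above relies on a `to_beam_search_params` counterpart to `to_sampling_params` on the request model. A minimal sketch of what that helper plausibly does, written as a free function for readability — the exact `BeamSearchParams` fields (`beam_width`, `length_penalty`, ...) and the mapping from the request's `n` are assumptions, not shown in this hunk:

```python
from vllm.sampling_params import BeamSearchParams


def to_beam_search_params(request, default_max_tokens: int) -> BeamSearchParams:
    """Sketch: map an OpenAI-style completion request to beam-search params."""
    # Fall back to the server-computed token budget (max_model_len minus the
    # prompt length) when the client omits max_tokens, mirroring the diff.
    max_tokens = (request.max_tokens
                  if request.max_tokens is not None else default_max_tokens)
    return BeamSearchParams(
        beam_width=request.n or 1,  # assumed: one beam per requested completion
        max_tokens=max_tokens,
        ignore_eos=request.ignore_eos,
        temperature=request.temperature or 0.0,
        length_penalty=request.length_penalty,
    )
```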
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
                         raw_request.headers):
                     log_tracing_disabled_warning()

-                generator = self.engine_client.generate(
-                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
-                    sampling_params,
-                    request_id_item,
-                    lora_request=lora_request,
-                    prompt_adapter_request=prompt_adapter_request,
-                    trace_headers=trace_headers,
-                    priority=request.priority,
-                )
+                if isinstance(sampling_params, BeamSearchParams):
+                    if not isinstance(self.engine_client, AsyncLLMEngine):
+                        raise ValueError(
+                            "Beam search in the API server is only supported"
+                            " with AsyncLLMEngine. please add "
+                            "`--disable-frontend-multiprocessing` to "
+                            "use beam search.")
+                    generator = self.engine_client.beam_search(
+                        prompt_inputs["prompt_token_ids"], request_id_item,
+                        sampling_params)
+                else:
+                    generator = self.engine_client.generate(
+                        {
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        sampling_params,
+                        request_id_item,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )

                 generators.append(generator)
         except ValueError as e:
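End to end, a client can now request beam search through the OpenAI-compatible completions API, provided the server runs the in-process AsyncLLMEngine as the error message above demands. A hedged usage sketch: `use_beam_search` is the vLLM request extension this diff branches on, while the model name and the mapping of `n` onto the beam width are illustrative assumptions:

```python
# Launch the server with the in-process engine, per the check in the diff:
#   vllm serve <model> --disable-frontend-multiprocessing
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model name
    prompt="The capital of France is",
    max_tokens=32,
    n=4,  # assumed to map onto the beam width when beam search is enabled
    # vLLM-specific extension carried in the request body:
    extra_body={"use_beam_search": True},
)
for choice in completion.choices:
    print(choice.index, choice.text)
```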