[Frontend] API support for beam search (#9087)

Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
Brendan Wong
2024-10-05 23:39:03 -07:00
committed by GitHub
parent 23fea8714a
commit 168cab6bbf
12 changed files with 275 additions and 68 deletions

View File

@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def to_beam_search_params(self,
default_max_tokens: int) -> BeamSearchParams:
max_tokens = self.max_tokens
if max_tokens is None:
max_tokens = default_max_tokens
n = self.n if self.n is not None else 1
temperature = self.temperature if self.temperature is not None else 0.0
return BeamSearchParams(
beam_width=n,
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens
if max_tokens is None:

View File

@@ -9,6 +9,7 @@ from typing import Union
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None
sampling_params = request.to_sampling_params(
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
self._log_inputs(request_id,
prompt_inputs,
@@ -227,15 +235,26 @@ class OpenAIServingChat(OpenAIServing):
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
if isinstance(sampling_params, BeamSearchParams):
if not isinstance(self.engine_client, AsyncLLMEngine):
raise ValueError(
"Beam search in the API server is only supported with"
" AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
result_generator = self.engine_client.beam_search(
engine_inputs['prompt_token_ids'], request_id,
sampling_params)
else:
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

View File

@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
PromptAdapterPath)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
))
for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params(
default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"]))
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
request_id_item = f"{request_id}-{i}"
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request.headers):
log_tracing_disabled_warning()
generator = self.engine_client.generate(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
if isinstance(sampling_params, BeamSearchParams):
if not isinstance(self.engine_client, AsyncLLMEngine):
raise ValueError(
"Beam search in the API server is only supported"
" with AsyncLLMEngine. please add "
"`--disable-frontend-multiprocessing` to "
"use beam search.")
generator = self.engine_client.beam_search(
prompt_inputs["prompt_token_ids"], request_id_item,
sampling_params)
else:
generator = self.engine_client.generate(
{
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
sampling_params,
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:

View File

@@ -29,7 +29,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
@@ -371,7 +371,8 @@ class OpenAIServing:
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None: