[Frontend] API support for beam search (#9087)
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
|
||||
SamplingParams)
|
||||
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
|
||||
RequestOutputKind, SamplingParams)
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
def to_beam_search_params(self,
|
||||
default_max_tokens: int) -> BeamSearchParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
n = self.n if self.n is not None else 1
|
||||
temperature = self.temperature if self.temperature is not None else 0.0
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
def to_beam_search_params(self,
|
||||
default_max_tokens: int) -> BeamSearchParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
n = self.n if self.n is not None else 1
|
||||
temperature = self.temperature if self.temperature is not None else 0.0
|
||||
|
||||
return BeamSearchParams(
|
||||
beam_width=n,
|
||||
max_tokens=max_tokens,
|
||||
ignore_eos=self.ignore_eos,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Union
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ConversationMessage,
|
||||
apply_hf_chat_template,
|
||||
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
assert prompt_inputs is not None
|
||||
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens=self.max_model_len -
|
||||
len(prompt_inputs["prompt_token_ids"]))
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens)
|
||||
|
||||
self._log_inputs(request_id,
|
||||
prompt_inputs,
|
||||
@@ -227,15 +235,26 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and contains_trace_headers(raw_request.headers)):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_inputs,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
if not isinstance(self.engine_client, AsyncLLMEngine):
|
||||
raise ValueError(
|
||||
"Beam search in the API server is only supported with"
|
||||
" AsyncLLMEngine. please add "
|
||||
"`--disable-frontend-multiprocessing` to "
|
||||
"use beam search.")
|
||||
result_generator = self.engine_client.beam_search(
|
||||
engine_inputs['prompt_token_ids'], request_id,
|
||||
sampling_params)
|
||||
else:
|
||||
result_generator = self.engine_client.generate(
|
||||
engine_inputs,
|
||||
sampling_params,
|
||||
request_id,
|
||||
lora_request=lora_request,
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
)
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
PromptAdapterPath)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
))
|
||||
|
||||
for i, prompt_inputs in enumerate(prompts):
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens=self.max_model_len -
|
||||
len(prompt_inputs["prompt_token_ids"]))
|
||||
sampling_params: Union[SamplingParams, BeamSearchParams]
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
prompt_inputs["prompt_token_ids"])
|
||||
if request.use_beam_search:
|
||||
sampling_params = request.to_beam_search_params(
|
||||
default_max_tokens)
|
||||
else:
|
||||
sampling_params = request.to_sampling_params(
|
||||
default_max_tokens)
|
||||
|
||||
request_id_item = f"{request_id}-{i}"
|
||||
|
||||
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
raw_request.headers):
|
||||
log_tracing_disabled_warning()
|
||||
|
||||
generator = self.engine_client.generate(
|
||||
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
if isinstance(sampling_params, BeamSearchParams):
|
||||
if not isinstance(self.engine_client, AsyncLLMEngine):
|
||||
raise ValueError(
|
||||
"Beam search in the API server is only supported"
|
||||
" with AsyncLLMEngine. please add "
|
||||
"`--disable-frontend-multiprocessing` to "
|
||||
"use beam search.")
|
||||
generator = self.engine_client.beam_search(
|
||||
prompt_inputs["prompt_token_ids"], request_id_item,
|
||||
sampling_params)
|
||||
else:
|
||||
generator = self.engine_client.generate(
|
||||
{
|
||||
"prompt_token_ids":
|
||||
prompt_inputs["prompt_token_ids"]
|
||||
},
|
||||
sampling_params,
|
||||
request_id_item,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=request.priority,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -29,7 +29,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import AtomicCounter
|
||||
@@ -371,7 +371,8 @@ class OpenAIServing:
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: Union[str, List[int], TextTokensPrompt],
|
||||
params: Optional[Union[SamplingParams, PoolingParams]],
|
||||
params: Optional[Union[SamplingParams, PoolingParams,
|
||||
BeamSearchParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user