[Frontend] Cleanup serving engine (#33103)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -68,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.inputs.parse import get_prompt_components
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
@@ -374,20 +375,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
|
||||
prompt_text, _, _ = get_prompt_components(engine_prompt)
|
||||
|
||||
# If we are creating sub requests for multiple prompts, ensure that they
|
||||
# have unique request ids.
|
||||
sub_request_id = (
|
||||
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
|
||||
)
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=len(engine_prompt["prompt_token_ids"]),
|
||||
prompt=engine_prompt,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig
|
||||
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
|
||||
from vllm.inputs.parse import get_prompt_components
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
generators: list[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
for i, engine_prompt in enumerate(engine_prompts):
|
||||
prompt_text, prompt_token_ids, prompt_embeds = (
|
||||
self._get_prompt_components(engine_prompt)
|
||||
)
|
||||
|
||||
input_length = None
|
||||
if prompt_token_ids is not None:
|
||||
input_length = len(prompt_token_ids)
|
||||
elif prompt_embeds is not None:
|
||||
input_length = len(prompt_embeds)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.default_sampling_params is None:
|
||||
self.default_sampling_params = {}
|
||||
prompt_text, _, _ = get_prompt_components(engine_prompt)
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
max_model_len=self.max_model_len,
|
||||
request=request,
|
||||
input_length=input_length,
|
||||
prompt=engine_prompt,
|
||||
default_sampling_params=self.default_sampling_params,
|
||||
)
|
||||
|
||||
|
||||
@@ -94,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import (
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
)
|
||||
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
|
||||
from vllm.entrypoints.utils import (
|
||||
_validate_truncation_size,
|
||||
get_max_tokens,
|
||||
sanitize_message,
|
||||
)
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import PromptType, TokensPrompt
|
||||
from vllm.inputs.parse import (
|
||||
PromptComponents,
|
||||
get_prompt_components,
|
||||
is_explicit_encoder_decoder_prompt,
|
||||
)
|
||||
@@ -1287,7 +1290,7 @@ class OpenAIServing:
|
||||
priority: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
|
||||
prompt_text, _, _ = get_prompt_components(engine_prompt)
|
||||
|
||||
orig_priority = priority
|
||||
sub_request = 0
|
||||
@@ -1338,10 +1341,12 @@ class OpenAIServing:
|
||||
# yield context
|
||||
|
||||
# Create inputs for the next turn.
|
||||
# Render the next prompt token ids.
|
||||
# Render the next prompt token ids and update sampling_params.
|
||||
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
|
||||
prompt_token_ids = context.render_for_completion()
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||
token_ids = context.render_for_completion()
|
||||
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
|
||||
|
||||
sampling_params.max_tokens = self.max_model_len - len(token_ids)
|
||||
elif isinstance(context, ParsableContext):
|
||||
engine_prompts = await self._render_next_turn(
|
||||
context.request,
|
||||
@@ -1353,19 +1358,19 @@ class OpenAIServing:
|
||||
context.chat_template_content_format,
|
||||
)
|
||||
engine_prompt = engine_prompts[0]
|
||||
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
|
||||
prompt_text, _, _ = get_prompt_components(engine_prompt)
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
context.request,
|
||||
engine_prompt,
|
||||
self.default_sampling_params, # type: ignore
|
||||
)
|
||||
|
||||
# Update the sampling params.
|
||||
sampling_params.max_tokens = self.max_model_len - len(
|
||||
engine_prompt["prompt_token_ids"]
|
||||
)
|
||||
# OPTIMIZATION
|
||||
priority = orig_priority - 1
|
||||
sub_request += 1
|
||||
|
||||
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
|
||||
return get_prompt_components(prompt)
|
||||
|
||||
def _log_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1376,7 +1381,7 @@ class OpenAIServing:
|
||||
if self.request_logger is None:
|
||||
return
|
||||
|
||||
prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
|
||||
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
|
||||
|
||||
self.request_logger.log_inputs(
|
||||
request_id,
|
||||
|
||||
@@ -116,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
|
||||
extract_tool_types,
|
||||
should_continue_final_message,
|
||||
)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
@@ -435,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
if maybe_error is not None:
|
||||
return maybe_error
|
||||
|
||||
default_max_tokens = self.max_model_len - len(
|
||||
engine_prompt["prompt_token_ids"]
|
||||
default_max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
request,
|
||||
engine_prompt,
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
sampling_params = request.to_sampling_params(
|
||||
|
||||
@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import EmbedsPrompt, TokensPrompt
|
||||
from vllm.logger import current_formatter_type, init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -32,11 +34,15 @@ if TYPE_CHECKING:
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
)
|
||||
else:
|
||||
ChatCompletionRequest = object
|
||||
CompletionRequest = object
|
||||
StreamOptions = object
|
||||
LoRAModulePath = object
|
||||
ResponsesRequest = object
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -211,11 +217,26 @@ def _validate_truncation_size(
|
||||
|
||||
def get_max_tokens(
|
||||
max_model_len: int,
|
||||
request: "ChatCompletionRequest | CompletionRequest",
|
||||
input_length: int,
|
||||
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
|
||||
prompt: TokensPrompt | EmbedsPrompt,
|
||||
default_sampling_params: dict,
|
||||
) -> int:
|
||||
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
|
||||
# NOTE: Avoid isinstance() for better efficiency
|
||||
max_tokens: int | None = None
|
||||
if max_tokens is None:
|
||||
# ChatCompletionRequest
|
||||
max_tokens = getattr(request, "max_completion_tokens", None)
|
||||
if max_tokens is None:
|
||||
# ResponsesRequest
|
||||
max_tokens = getattr(request, "max_output_tokens", None)
|
||||
if max_tokens is None:
|
||||
# CompletionRequest (also a fallback for ChatCompletionRequest)
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
input_length = length_from_prompt_token_ids_or_embeds(
|
||||
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
|
||||
prompt.get("prompt_embeds"), # type: ignore[arg-type]
|
||||
)
|
||||
default_max_tokens = max_model_len - input_length
|
||||
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user