[Frontend] Cleanup serving engine (#33103)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-27 12:47:26 +08:00
committed by GitHub
parent 3b8f0fe59e
commit e0b005d9cf
5 changed files with 57 additions and 40 deletions

View File

@@ -68,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
@@ -374,20 +375,18 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
)
if self.default_sampling_params is None:
self.default_sampling_params = {}
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=len(engine_prompt["prompt_token_ids"]),
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
)

View File

@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
from vllm.inputs.parse import get_prompt_components
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, prompt_token_ids, prompt_embeds = (
self._get_prompt_components(engine_prompt)
)
input_length = None
if prompt_token_ids is not None:
input_length = len(prompt_token_ids)
elif prompt_embeds is not None:
input_length = len(prompt_embeds)
else:
raise NotImplementedError
if self.default_sampling_params is None:
self.default_sampling_params = {}
prompt_text, _, _ = get_prompt_components(engine_prompt)
max_tokens = get_max_tokens(
max_model_len=self.max_model_len,
request=request,
input_length=input_length,
prompt=engine_prompt,
default_sampling_params=self.default_sampling_params,
)

View File

@@ -94,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeCompletionRequest,
TokenizeResponse,
)
from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
from vllm.entrypoints.utils import (
_validate_truncation_size,
get_max_tokens,
sanitize_message,
)
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
get_prompt_components,
is_explicit_encoder_decoder_prompt,
)
@@ -1287,7 +1290,7 @@ class OpenAIServing:
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
orig_priority = priority
sub_request = 0
@@ -1338,10 +1341,12 @@ class OpenAIServing:
# yield context
# Create inputs for the next turn.
# Render the next prompt token ids.
# Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
token_ids = context.render_for_completion()
engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
sampling_params.max_tokens = self.max_model_len - len(token_ids)
elif isinstance(context, ParsableContext):
engine_prompts = await self._render_next_turn(
context.request,
@@ -1353,19 +1358,19 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = get_prompt_components(engine_prompt)
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
engine_prompt,
self.default_sampling_params, # type: ignore
)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
)
# OPTIMIZATION
priority = orig_priority - 1
sub_request += 1
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs(
self,
request_id: str,
@@ -1376,7 +1381,7 @@ class OpenAIServing:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)
self.request_logger.log_inputs(
request_id,

View File

@@ -116,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
extract_tool_types,
should_continue_final_message,
)
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
@@ -435,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
if maybe_error is not None:
return maybe_error
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"]
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
engine_prompt,
self.default_sampling_params,
)
sampling_params = request.to_sampling_params(

View File

@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import EmbedsPrompt, TokensPrompt
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
@@ -32,11 +34,15 @@ if TYPE_CHECKING:
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import (
ResponsesRequest,
)
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__)
@@ -211,11 +217,26 @@ def _validate_truncation_size(
def get_max_tokens(
max_model_len: int,
request: "ChatCompletionRequest | CompletionRequest",
input_length: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
prompt: TokensPrompt | EmbedsPrompt,
default_sampling_params: dict,
) -> int:
max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
input_length = length_from_prompt_token_ids_or_embeds(
prompt.get("prompt_token_ids"), # type: ignore[arg-type]
prompt.get("prompt_embeds"), # type: ignore[arg-type]
)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)