diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 6a22bece6..aa79e9da3 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -68,6 +68,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
 from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.inputs.data import TokensPrompt
+from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
@@ -374,20 +375,18 @@ class OpenAIServingChat(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                prompt_text, _, _ = self._get_prompt_components(engine_prompt)
+                prompt_text, _, _ = get_prompt_components(engine_prompt)
+
                 # If we are creating sub requests for multiple prompts, ensure that they
                 # have unique request ids.
                 sub_request_id = (
                     request_id if len(engine_prompts) == 1 else f"{request_id}_{i}"
                 )

-                if self.default_sampling_params is None:
-                    self.default_sampling_params = {}
-
                 max_tokens = get_max_tokens(
                     max_model_len=self.max_model_len,
                     request=request,
-                    input_length=len(engine_prompt["prompt_token_ids"]),
+                    prompt=engine_prompt,
                     default_sampling_params=self.default_sampling_params,
                 )

diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py
index 92156c7f2..24cf486a6 100644
--- a/vllm/entrypoints/openai/completion/serving.py
+++ b/vllm/entrypoints/openai/completion/serving.py
@@ -36,6 +36,7 @@ from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import get_max_tokens, should_include_usage
 from vllm.exceptions import VLLMValidationError
 from vllm.inputs.data import EmbedsPrompt, TokensPrompt, is_embeds_prompt
+from vllm.inputs.parse import get_prompt_components
 from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import RequestOutput
@@ -162,25 +163,12 @@ class OpenAIServingCompletion(OpenAIServing):
         generators: list[AsyncGenerator[RequestOutput, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
-                prompt_text, prompt_token_ids, prompt_embeds = (
-                    self._get_prompt_components(engine_prompt)
-                )
-
-                input_length = None
-                if prompt_token_ids is not None:
-                    input_length = len(prompt_token_ids)
-                elif prompt_embeds is not None:
-                    input_length = len(prompt_embeds)
-                else:
-                    raise NotImplementedError
-
-                if self.default_sampling_params is None:
-                    self.default_sampling_params = {}
+                prompt_text, _, _ = get_prompt_components(engine_prompt)

                 max_tokens = get_max_tokens(
                     max_model_len=self.max_model_len,
                     request=request,
-                    input_length=input_length,
+                    prompt=engine_prompt,
                     default_sampling_params=self.default_sampling_params,
                 )

diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index f18dba5ee..0433f28d9 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -94,11 +94,14 @@ from vllm.entrypoints.serve.tokenize.protocol import (
     TokenizeCompletionRequest,
     TokenizeResponse,
 )
-from vllm.entrypoints.utils import _validate_truncation_size, sanitize_message
+from vllm.entrypoints.utils import (
+    _validate_truncation_size,
+    get_max_tokens,
+    sanitize_message,
+)
 from vllm.exceptions import VLLMValidationError
 from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.inputs.parse import (
-    PromptComponents,
     get_prompt_components,
     is_explicit_encoder_decoder_prompt,
 )
@@ -1287,7 +1290,7 @@ class OpenAIServing:
         priority: int = 0,
         **kwargs,
     ):
-        prompt_text, _, _ = self._get_prompt_components(engine_prompt)
+        prompt_text, _, _ = get_prompt_components(engine_prompt)

         orig_priority = priority
         sub_request = 0
@@ -1338,10 +1341,12 @@ class OpenAIServing:
             # yield context

             # Create inputs for the next turn.
-            # Render the next prompt token ids.
+            # Render the next prompt token ids and update sampling_params.
             if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
-                prompt_token_ids = context.render_for_completion()
-                engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
+                token_ids = context.render_for_completion()
+                engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
+
+                sampling_params.max_tokens = self.max_model_len - len(token_ids)
             elif isinstance(context, ParsableContext):
                 engine_prompts = await self._render_next_turn(
                     context.request,
@@ -1353,19 +1358,19 @@ class OpenAIServing:
                     context.chat_template_content_format,
                 )
                 engine_prompt = engine_prompts[0]
-                prompt_text, _, _ = self._get_prompt_components(engine_prompt)
+                prompt_text, _, _ = get_prompt_components(engine_prompt)
+
+                sampling_params.max_tokens = get_max_tokens(
+                    self.max_model_len,
+                    context.request,
+                    engine_prompt,
+                    self.default_sampling_params,  # type: ignore
+                )

-            # Update the sampling params.
-            sampling_params.max_tokens = self.max_model_len - len(
-                engine_prompt["prompt_token_ids"]
-            )
             # OPTIMIZATION
             priority = orig_priority - 1
             sub_request += 1

-    def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
-        return get_prompt_components(prompt)
-
     def _log_inputs(
         self,
         request_id: str,
@@ -1376,7 +1381,7 @@ class OpenAIServing:
         if self.request_logger is None:
             return

-        prompt, prompt_token_ids, prompt_embeds = self._get_prompt_components(inputs)
+        prompt, prompt_token_ids, prompt_embeds = get_prompt_components(inputs)

         self.request_logger.log_inputs(
             request_id,
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index 702167a24..ebde2b063 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -116,6 +116,7 @@ from vllm.entrypoints.openai.responses.utils import (
     extract_tool_types,
     should_continue_final_message,
 )
+from vllm.entrypoints.utils import get_max_tokens
 from vllm.exceptions import VLLMValidationError
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
@@ -435,8 +436,11 @@ class OpenAIServingResponses(OpenAIServing):
         if maybe_error is not None:
             return maybe_error

-        default_max_tokens = self.max_model_len - len(
-            engine_prompt["prompt_token_ids"]
+        default_max_tokens = get_max_tokens(
+            self.max_model_len,
+            request,
+            engine_prompt,
+            self.default_sampling_params,
         )

         sampling_params = request.to_sampling_params(
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index a42f08f7f..c329e7a19 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -17,8 +17,10 @@ from starlette.background import BackgroundTask, BackgroundTasks

 from vllm import envs
 from vllm.engine.arg_utils import EngineArgs
+from vllm.inputs import EmbedsPrompt, TokensPrompt
 from vllm.logger import current_formatter_type, init_logger
 from vllm.platforms import current_platform
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.utils.argparse_utils import FlexibleArgumentParser

 if TYPE_CHECKING:
@@ -32,11 +34,15 @@ if TYPE_CHECKING:
         StreamOptions,
     )
     from vllm.entrypoints.openai.models.protocol import LoRAModulePath
+    from vllm.entrypoints.openai.responses.protocol import (
+        ResponsesRequest,
+    )
 else:
     ChatCompletionRequest = object
     CompletionRequest = object
     StreamOptions = object
     LoRAModulePath = object
+    ResponsesRequest = object

 logger = init_logger(__name__)

@@ -211,11 +217,26 @@ def _validate_truncation_size(


 def get_max_tokens(
     max_model_len: int,
-    request: "ChatCompletionRequest | CompletionRequest",
-    input_length: int,
+    request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
+    prompt: TokensPrompt | EmbedsPrompt,
     default_sampling_params: dict,
 ) -> int:
-    max_tokens = getattr(request, "max_completion_tokens", None) or request.max_tokens
+    # NOTE: Avoid isinstance() for better efficiency
+    max_tokens: int | None = None
+    if max_tokens is None:
+        # ChatCompletionRequest
+        max_tokens = getattr(request, "max_completion_tokens", None)
+    if max_tokens is None:
+        # ResponsesRequest
+        max_tokens = getattr(request, "max_output_tokens", None)
+    if max_tokens is None:
+        # CompletionRequest (also a fallback for ChatCompletionRequest)
+        max_tokens = getattr(request, "max_tokens", None)
+
+    input_length = length_from_prompt_token_ids_or_embeds(
+        prompt.get("prompt_token_ids"),  # type: ignore[arg-type]
+        prompt.get("prompt_embeds"),  # type: ignore[arg-type]
+    )
     default_max_tokens = max_model_len - input_length
     max_output_tokens = current_platform.get_max_output_tokens(input_length)
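For orientation, here is a minimal, self-contained sketch (not part of the patch) of the resolution order the reworked `get_max_tokens` uses: the per-request cap is probed as `max_completion_tokens`, then `max_output_tokens`, then `max_tokens`, and the prompt length now comes from the prompt dict itself rather than a precomputed `input_length`. The request class and both helpers below are hypothetical stand-ins, and the final clamping against the platform limit (`current_platform.get_max_output_tokens`) is omitted since that part of the hunk is truncated above.

```python
# Illustrative sketch only; FakeChatRequest and both helpers are stand-ins,
# not vLLM APIs. They mirror the resolution order of the patched get_max_tokens().
from dataclasses import dataclass


@dataclass
class FakeChatRequest:
    # Plays the role of ChatCompletionRequest for this example.
    max_completion_tokens: int | None = None
    max_tokens: int | None = None


def resolve_requested_max_tokens(request: object) -> int | None:
    # Same attribute-probing order as the patched helper:
    # ChatCompletionRequest -> ResponsesRequest -> CompletionRequest.
    for attr in ("max_completion_tokens", "max_output_tokens", "max_tokens"):
        value = getattr(request, attr, None)
        if value is not None:
            return value
    return None


def prompt_length(prompt: dict) -> int:
    # Stand-in for length_from_prompt_token_ids_or_embeds(): prefer token ids,
    # otherwise count embedding rows.
    token_ids = prompt.get("prompt_token_ids")
    if token_ids is not None:
        return len(token_ids)
    return len(prompt["prompt_embeds"])


if __name__ == "__main__":
    engine_prompt = {"prompt_token_ids": [1, 2, 3, 4]}
    request = FakeChatRequest(max_completion_tokens=128)
    max_model_len = 4096

    print("requested cap:", resolve_requested_max_tokens(request))        # 128
    print("model budget:", max_model_len - prompt_length(engine_prompt))   # 4092
```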