diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 433fe961a..adcd488a0 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -388,7 +388,9 @@ class OpenAIServingChat(OpenAIServing): max_tokens = get_max_tokens( self.max_model_len, - request, + request.max_completion_tokens + if request.max_completion_tokens is not None + else request.max_tokens, self._extract_prompt_len(engine_prompt), self.default_sampling_params, ) diff --git a/vllm/entrypoints/openai/completion/serving.py b/vllm/entrypoints/openai/completion/serving.py index d2fa2f931..beb3c2c53 100644 --- a/vllm/entrypoints/openai/completion/serving.py +++ b/vllm/entrypoints/openai/completion/serving.py @@ -164,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing): max_tokens = get_max_tokens( self.max_model_len, - request, + request.max_tokens, self._extract_prompt_len(engine_prompt), self.default_sampling_params, ) diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index ebd629afa..2fabc5999 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -1176,7 +1176,7 @@ class OpenAIServing: sampling_params.max_tokens = get_max_tokens( self.max_model_len, - context.request, + context.request.max_output_tokens, self._extract_prompt_len(engine_prompt), self.default_sampling_params, # type: ignore ) diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index 500401468..9f54a8081 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -441,7 +441,7 @@ class OpenAIServingResponses(OpenAIServing): default_max_tokens = get_max_tokens( self.max_model_len, - request, + request.max_output_tokens, self._extract_prompt_len(engine_prompt), self.default_sampling_params, ) diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 4900bfa7d..98822b9c6 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -22,23 +22,11 @@ from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.entrypoints.openai.chat_completion.protocol import ( - ChatCompletionRequest, - ) - from vllm.entrypoints.openai.completion.protocol import ( - CompletionRequest, - ) - from vllm.entrypoints.openai.engine.protocol import ( - StreamOptions, - ) + from vllm.entrypoints.openai.engine.protocol import StreamOptions from vllm.entrypoints.openai.models.protocol import LoRAModulePath - from vllm.entrypoints.openai.responses.protocol import ResponsesRequest else: - ChatCompletionRequest = object - CompletionRequest = object StreamOptions = object LoRAModulePath = object - ResponsesRequest = object logger = init_logger(__name__) @@ -186,22 +174,10 @@ def cli_env_setup(): def get_max_tokens( max_model_len: int, - request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", + max_tokens: int | None, input_length: int, default_sampling_params: dict, ) -> int: - # NOTE: Avoid isinstance() for better efficiency - max_tokens: int | None = None - if max_tokens is None: - # ChatCompletionRequest - max_tokens = getattr(request, "max_completion_tokens", None) - if max_tokens is None: - # ResponsesRequest - max_tokens = getattr(request, "max_output_tokens", None) - if max_tokens is None: - # CompletionRequest (also a fallback for ChatCompletionRequest) - max_tokens = getattr(request, "max_tokens", None) - default_max_tokens = max_model_len - input_length max_output_tokens = current_platform.get_max_output_tokens(input_length)