[Misc] Simplify get_max_tokens (#34036)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-07 16:59:49 +08:00
committed by GitHub
parent 15a0b9e570
commit 11a4c9d30d
5 changed files with 8 additions and 30 deletions

View File

@@ -388,7 +388,9 @@ class OpenAIServingChat(OpenAIServing):
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
request, request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
) )

View File

@@ -164,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
request, request.max_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
) )

View File

@@ -1176,7 +1176,7 @@ class OpenAIServing:
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
context.request, context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore self.default_sampling_params, # type: ignore
) )

View File

@@ -441,7 +441,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens( default_max_tokens = get_max_tokens(
self.max_model_len, self.max_model_len,
request, request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_prompt),
self.default_sampling_params, self.default_sampling_params,
) )

View File

@@ -22,23 +22,11 @@ from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ( from vllm.entrypoints.openai.engine.protocol import StreamOptions
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
StreamOptions,
)
from vllm.entrypoints.openai.models.protocol import LoRAModulePath from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
else: else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object StreamOptions = object
LoRAModulePath = object LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -186,22 +174,10 @@ def cli_env_setup():
def get_max_tokens( def get_max_tokens(
max_model_len: int, max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest", max_tokens: int | None,
input_length: int, input_length: int,
default_sampling_params: dict, default_sampling_params: dict,
) -> int: ) -> int:
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
default_max_tokens = max_model_len - input_length default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length) max_output_tokens = current_platform.get_max_output_tokens(input_length)