[Misc] Simplify get_max_tokens (#34036)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-07 16:59:49 +08:00
committed by GitHub
parent 15a0b9e570
commit 11a4c9d30d
5 changed files with 8 additions and 30 deletions

View File

@@ -388,7 +388,9 @@ class OpenAIServingChat(OpenAIServing):
max_tokens = get_max_tokens(
self.max_model_len,
request,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)

View File

@@ -164,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
max_tokens = get_max_tokens(
self.max_model_len,
request,
request.max_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)

View File

@@ -1176,7 +1176,7 @@ class OpenAIServing:
sampling_params.max_tokens = get_max_tokens(
self.max_model_len,
context.request,
context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params, # type: ignore
)

View File

@@ -441,7 +441,7 @@ class OpenAIServingResponses(OpenAIServing):
default_max_tokens = get_max_tokens(
self.max_model_len,
request,
request.max_output_tokens,
self._extract_prompt_len(engine_prompt),
self.default_sampling_params,
)

View File

@@ -22,23 +22,11 @@ from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.entrypoints.openai.completion.protocol import (
CompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
StreamOptions,
)
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
else:
ChatCompletionRequest = object
CompletionRequest = object
StreamOptions = object
LoRAModulePath = object
ResponsesRequest = object
logger = init_logger(__name__)
@@ -186,22 +174,10 @@ def cli_env_setup():
def get_max_tokens(
max_model_len: int,
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
max_tokens: int | None,
input_length: int,
default_sampling_params: dict,
) -> int:
# NOTE: Avoid isinstance() for better efficiency
max_tokens: int | None = None
if max_tokens is None:
# ChatCompletionRequest
max_tokens = getattr(request, "max_completion_tokens", None)
if max_tokens is None:
# ResponsesRequest
max_tokens = getattr(request, "max_output_tokens", None)
if max_tokens is None:
# CompletionRequest (also a fallback for ChatCompletionRequest)
max_tokens = getattr(request, "max_tokens", None)
default_max_tokens = max_model_len - input_length
max_output_tokens = current_platform.get_max_output_tokens(input_length)