[Misc] Simplify get_max_tokens (#34036)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -388,7 +388,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
request,
|
||||
request.max_completion_tokens
|
||||
if request.max_completion_tokens is not None
|
||||
else request.max_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
@@ -164,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
request,
|
||||
request.max_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
@@ -1176,7 +1176,7 @@ class OpenAIServing:
|
||||
|
||||
sampling_params.max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
context.request,
|
||||
context.request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params, # type: ignore
|
||||
)
|
||||
|
||||
@@ -441,7 +441,7 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
|
||||
default_max_tokens = get_max_tokens(
|
||||
self.max_model_len,
|
||||
request,
|
||||
request.max_output_tokens,
|
||||
self._extract_prompt_len(engine_prompt),
|
||||
self.default_sampling_params,
|
||||
)
|
||||
|
||||
@@ -22,23 +22,11 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.completion.protocol import (
|
||||
CompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
StreamOptions,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
else:
|
||||
ChatCompletionRequest = object
|
||||
CompletionRequest = object
|
||||
StreamOptions = object
|
||||
LoRAModulePath = object
|
||||
ResponsesRequest = object
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -186,22 +174,10 @@ def cli_env_setup():
|
||||
|
||||
def get_max_tokens(
|
||||
max_model_len: int,
|
||||
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
|
||||
max_tokens: int | None,
|
||||
input_length: int,
|
||||
default_sampling_params: dict,
|
||||
) -> int:
|
||||
# NOTE: Avoid isinstance() for better efficiency
|
||||
max_tokens: int | None = None
|
||||
if max_tokens is None:
|
||||
# ChatCompletionRequest
|
||||
max_tokens = getattr(request, "max_completion_tokens", None)
|
||||
if max_tokens is None:
|
||||
# ResponsesRequest
|
||||
max_tokens = getattr(request, "max_output_tokens", None)
|
||||
if max_tokens is None:
|
||||
# CompletionRequest (also a fallback for ChatCompletionRequest)
|
||||
max_tokens = getattr(request, "max_tokens", None)
|
||||
|
||||
default_max_tokens = max_model_len - input_length
|
||||
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user