[Misc] Simplify get_max_tokens (#34036)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -388,7 +388,9 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
max_tokens = get_max_tokens(
|
max_tokens = get_max_tokens(
|
||||||
self.max_model_len,
|
self.max_model_len,
|
||||||
request,
|
request.max_completion_tokens
|
||||||
|
if request.max_completion_tokens is not None
|
||||||
|
else request.max_tokens,
|
||||||
self._extract_prompt_len(engine_prompt),
|
self._extract_prompt_len(engine_prompt),
|
||||||
self.default_sampling_params,
|
self.default_sampling_params,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
|
|
||||||
max_tokens = get_max_tokens(
|
max_tokens = get_max_tokens(
|
||||||
self.max_model_len,
|
self.max_model_len,
|
||||||
request,
|
request.max_tokens,
|
||||||
self._extract_prompt_len(engine_prompt),
|
self._extract_prompt_len(engine_prompt),
|
||||||
self.default_sampling_params,
|
self.default_sampling_params,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1176,7 +1176,7 @@ class OpenAIServing:
|
|||||||
|
|
||||||
sampling_params.max_tokens = get_max_tokens(
|
sampling_params.max_tokens = get_max_tokens(
|
||||||
self.max_model_len,
|
self.max_model_len,
|
||||||
context.request,
|
context.request.max_output_tokens,
|
||||||
self._extract_prompt_len(engine_prompt),
|
self._extract_prompt_len(engine_prompt),
|
||||||
self.default_sampling_params, # type: ignore
|
self.default_sampling_params, # type: ignore
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -441,7 +441,7 @@ class OpenAIServingResponses(OpenAIServing):
|
|||||||
|
|
||||||
default_max_tokens = get_max_tokens(
|
default_max_tokens = get_max_tokens(
|
||||||
self.max_model_len,
|
self.max_model_len,
|
||||||
request,
|
request.max_output_tokens,
|
||||||
self._extract_prompt_len(engine_prompt),
|
self._extract_prompt_len(engine_prompt),
|
||||||
self.default_sampling_params,
|
self.default_sampling_params,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,23 +22,11 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
||||||
ChatCompletionRequest,
|
|
||||||
)
|
|
||||||
from vllm.entrypoints.openai.completion.protocol import (
|
|
||||||
CompletionRequest,
|
|
||||||
)
|
|
||||||
from vllm.entrypoints.openai.engine.protocol import (
|
|
||||||
StreamOptions,
|
|
||||||
)
|
|
||||||
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
|
||||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
|
||||||
else:
|
else:
|
||||||
ChatCompletionRequest = object
|
|
||||||
CompletionRequest = object
|
|
||||||
StreamOptions = object
|
StreamOptions = object
|
||||||
LoRAModulePath = object
|
LoRAModulePath = object
|
||||||
ResponsesRequest = object
|
|
||||||
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -186,22 +174,10 @@ def cli_env_setup():
|
|||||||
|
|
||||||
def get_max_tokens(
|
def get_max_tokens(
|
||||||
max_model_len: int,
|
max_model_len: int,
|
||||||
request: "CompletionRequest | ChatCompletionRequest | ResponsesRequest",
|
max_tokens: int | None,
|
||||||
input_length: int,
|
input_length: int,
|
||||||
default_sampling_params: dict,
|
default_sampling_params: dict,
|
||||||
) -> int:
|
) -> int:
|
||||||
# NOTE: Avoid isinstance() for better efficiency
|
|
||||||
max_tokens: int | None = None
|
|
||||||
if max_tokens is None:
|
|
||||||
# ChatCompletionRequest
|
|
||||||
max_tokens = getattr(request, "max_completion_tokens", None)
|
|
||||||
if max_tokens is None:
|
|
||||||
# ResponsesRequest
|
|
||||||
max_tokens = getattr(request, "max_output_tokens", None)
|
|
||||||
if max_tokens is None:
|
|
||||||
# CompletionRequest (also a fallback for ChatCompletionRequest)
|
|
||||||
max_tokens = getattr(request, "max_tokens", None)
|
|
||||||
|
|
||||||
default_max_tokens = max_model_len - input_length
|
default_max_tokens = max_model_len - input_length
|
||||||
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
max_output_tokens = current_platform.get_max_output_tokens(input_length)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user