[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954)

Author:    zifeitong
Date:      2024-07-31 21:13:34 -07:00
Committer: GitHub
Parent:    0437492ea9
Commit:    3c10591ef2

5 changed files with 92 additions and 44 deletions


@@ -25,9 +25,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.inputs import parse_and_batch_prompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer_group import AnyTokenizer
@@ -150,6 +152,15 @@ class OpenAIServing:
         })
         return json_str
 
+    async def _guided_decode_logits_processor(
+            self, request: Union[ChatCompletionRequest, CompletionRequest],
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+        decoding_config = await self.engine.get_decoding_config()
+        guided_decoding_backend = request.guided_decoding_backend \
+            or decoding_config.guided_decoding_backend
+        return await get_guided_decoding_logits_processor(
+            guided_decoding_backend, request, tokenizer)
+
     async def _check_model(
         self,
         request: AnyRequest,
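For context, the helper added above is meant to be awaited once per request by the chat/completion serving paths, with the returned processor wired into the request's sampling parameters. The snippet below is only a minimal, self-contained sketch of that call pattern; the stub objects and the handle_request name are illustrative and are not vLLM's actual serving code.

import asyncio
from typing import Callable, List, Optional

# Stand-in for vllm.sampling_params.LogitsProcessor.
LogitsProcessor = Callable


async def _guided_decode_logits_processor(request, tokenizer) -> Optional[LogitsProcessor]:
    # Placeholder mirroring the helper in the diff: returns None when the
    # request carries no guided-decoding constraints.
    return None


async def handle_request(request, tokenizer) -> List[LogitsProcessor]:
    # Await the helper once per request and collect the processor (if any)
    # so it can be passed along when building the sampling parameters.
    processor = await _guided_decode_logits_processor(request, tokenizer)
    return [processor] if processor is not None else []


print(asyncio.run(handle_request(request=None, tokenizer=None)))  # -> []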
@@ -254,9 +265,7 @@ class OpenAIServing:
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
elif token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "