[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954)

zifeitong
2024-07-31 21:13:34 -07:00
committed by GitHub
parent 0437492ea9
commit 3c10591ef2
5 changed files with 92 additions and 44 deletions
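The bug: when an OpenAI-compatible completions request omits max_tokens, SamplingParams.max_tokens was left unset instead of defaulting to the remaining context window. The fix threads a per-prompt default_max_tokens = max_model_len - len(prompt_token_ids) into request.to_sampling_params(). A minimal sketch of that arithmetic, using illustrative numbers rather than values from the diff:

    # Sketch of the default-max_tokens computation introduced by this commit.
    # The concrete numbers are illustrative assumptions.
    max_model_len = 4096                  # engine's context window
    prompt_token_ids = list(range(1000))  # a 1000-token prompt
    user_max_tokens = None                # user omitted max_tokens

    # Default to whatever context the prompt leaves unused.
    default_max_tokens = max_model_len - len(prompt_token_ids)
    max_tokens = user_max_tokens if user_max_tokens is not None \
        else default_max_tokens
    assert max_tokens == 3096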

vllm/entrypoints/openai/serving_completion.py

@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                      OpenAIServing,
                                                      PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -95,31 +93,24 @@ class OpenAIServingCompletion(OpenAIServing):
 
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=sampling_params.
-                    truncate_prompt_tokens,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))
+
                 request_id_item = f"{request_id}-{i}"
 
                 self._log_inputs(request_id_item,
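
Two things worth noting about the new shape of this code. First, the inline guided-decoding wiring (fetching the backend from the decoding config and appending the logits processor by hand) moves behind a self._guided_decode_logits_processor() helper, shared with the other serving paths judging by the removed import. Second, sampling_params is now built inside the per-prompt loop, because default_max_tokens depends on each prompt's length. A self-contained sketch of the fallback contract that to_sampling_params is expected to honor; the class below is a hypothetical stand-in, not vLLM's actual request model:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _RequestStandIn:
        """Hypothetical stand-in for the OpenAI request model."""
        max_tokens: Optional[int] = None

        def resolve_max_tokens(self, default_max_tokens: int) -> int:
            # Mirrors the fix: honor an explicit user value, otherwise
            # fall back to the remaining-context default.
            if self.max_tokens is not None:
                return self.max_tokens
            return default_max_tokens

    # User omitted max_tokens: the per-prompt default applies.
    assert _RequestStandIn().resolve_max_tokens(3096) == 3096
    # User set max_tokens explicitly: it is preserved.
    assert _RequestStandIn(max_tokens=16).resolve_max_tokens(3096) == 16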