[Bugfix] Set SamplingParams.max_tokens for OpenAI requests if not provided by user (#6954)
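If the user does not set max_tokens, the sampling parameters for each prompt are now built with default_max_tokens = max_model_len - len(prompt_token_ids), i.e. the remaining context window. SamplingParams construction therefore moves inside the per-prompt loop, where the tokenized prompt length is known. Along the way, the duplicated guided-decoding setup is folded into the shared _guided_decode_logits_processor helper, and prompt truncation reads request.truncate_prompt_tokens directly instead of going through SamplingParams.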
@@ -24,8 +24,6 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                      OpenAIServing,
                                                      PromptAdapterPath)
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
@@ -95,31 +93,24 @@ class OpenAIServingCompletion(OpenAIServing):

             tokenizer = await self.engine.get_tokenizer(lora_request)

-            sampling_params = request.to_sampling_params(tokenizer)
-            decoding_config = await self.engine.get_decoding_config()
-            guided_decoding_backend = request.guided_decoding_backend \
-                or decoding_config.guided_decoding_backend
-            guided_decode_logit_processor = (
-                await
-                get_guided_decoding_logits_processor(guided_decoding_backend,
-                                                     request, tokenizer))
-            if guided_decode_logit_processor is not None:
-                if sampling_params.logits_processors is None:
-                    sampling_params.logits_processors = []
-                sampling_params.logits_processors.append(
-                    guided_decode_logit_processor)
+            guided_decode_logits_processor = (
+                await self._guided_decode_logits_processor(request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     request.prompt,
-                    truncate_prompt_tokens=sampling_params.
-                    truncate_prompt_tokens,
+                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))

             for i, prompt_inputs in enumerate(prompts):
+                sampling_params = request.to_sampling_params(
+                    tokenizer,
+                    guided_decode_logits_processor,
+                    default_max_tokens=self.max_model_len -
+                    len(prompt_inputs["prompt_token_ids"]))

                 request_id_item = f"{request_id}-{i}"

                 self._log_inputs(request_id_item,
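For context, a minimal sketch of the defaulting behavior this change introduces. The helper below is illustrative only, not vLLM's actual to_sampling_params implementation; it assumes the request carries an optional user-supplied max_tokens and that default_max_tokens is the remaining context window computed in the loop above.

    from typing import Optional

    def resolve_max_tokens(requested: Optional[int],
                           default_max_tokens: int) -> int:
        # Illustrative only: if the user omitted max_tokens, fall back
        # to the remaining context window (max_model_len - prompt
        # length) instead of an arbitrary fixed default.
        return requested if requested is not None else default_max_tokens

    # A 4096-token model with a 1000-token prompt leaves 3096 tokens to
    # generate when the request omits max_tokens; an explicit value wins.
    assert resolve_max_tokens(None, 4096 - 1000) == 3096
    assert resolve_max_tokens(256, 4096 - 1000) == 256

In effect, an OpenAI-style completion request that omits max_tokens should now be able to use the full remaining context rather than being cut off at a small fixed default.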