[Frontend] Refactor prompt processing (#4028)
Co-authored-by: Roger Wang <ywang@roblox.com>
@@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
     ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@@ -20,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                    OpenAIServing)
+                                                    OpenAIServing,
+                                                    PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
@@ -37,17 +39,24 @@ logger = init_logger(__name__)
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 model_config: ModelConfig,
-                 served_model_names: List[str],
-                 response_role: str,
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
         self.response_role = response_role
 
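Review note: the constructor's optional collaborators become keyword-only arguments with no defaults, so every call site must name each one explicitly. A self-contained sketch of that pattern (hypothetical class and fields, not vLLM code):

```python
from typing import List, Optional

class ServingChat:

    def __init__(
        self,
        response_role: str,
        *,  # everything below is keyword-only and has no default
        lora_modules: Optional[List[str]],
        request_logger: Optional[object],
    ):
        self.response_role = response_role
        self.lora_modules = lora_modules
        self.request_logger = request_logger

ok = ServingChat("assistant", lora_modules=None, request_logger=None)
# ServingChat("assistant", None, None) raises TypeError: positional
# arguments cannot satisfy keyword-only parameters.
```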
@@ -74,7 +83,12 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
-            _, lora_request = self._maybe_get_adapter(request)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_config = self.model_config
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
             conversation: List[ConversationMessage] = []
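Review note: `_maybe_get_adapter` (which discarded its first return value) is replaced by `_maybe_get_adapters`, resolving the LoRA request and the prompt-adapter request in one call. A minimal sketch of such a resolver, assuming adapters are keyed by the requested model name (hypothetical helper, not the actual vLLM implementation):

```python
from typing import Optional, Tuple

def maybe_get_adapters(
    model_name: str,
    loras: dict,            # hypothetical: adapter name -> LoRA request
    prompt_adapters: dict,  # hypothetical: name -> prompt-adapter request
) -> Tuple[Optional[object], Optional[object]]:
    """Return (lora_request, prompt_adapter_request); either may be None."""
    if model_name in loras:
        return loras[model_name], None
    if model_name in prompt_adapters:
        return None, prompt_adapters[model_name]
    return None, None  # base model: no adapter matched
```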
@@ -82,7 +96,7 @@ class OpenAIServingChat(OpenAIServing):
 
             for msg in request.messages:
                 chat_parsed_result = parse_chat_message_content(
-                    msg, self.model_config, tokenizer)
+                    msg, model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -116,14 +130,8 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error in loading multi-modal data: %s", e)
             return self.create_error_response(str(e))
 
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"chat-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
-                request,
-                tokenizer,
-                prompt=prompt,
-                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
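Review note: chat request ids now carry a `chat-` prefix instead of `cmpl-`, and the eager `_validate_prompt_and_tokenize` call is dropped in favor of the `_tokenize_prompt_input` call introduced in the next hunk. A rough sketch of what string-or-token-list prompt handling looks like, assuming a HuggingFace-style tokenizer (hypothetical function, not the vLLM helper):

```python
from typing import List, Union

def tokenize_prompt_input(tokenizer,
                          prompt: Union[str, List[int]],
                          add_special_tokens: bool = True) -> dict:
    # String prompt: encode it to token IDs.
    if isinstance(prompt, str):
        ids = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
        return {"prompt": prompt, "prompt_token_ids": ids}
    # Token-list prompt: keep the IDs; recover text only for logging/echo.
    return {"prompt": tokenizer.decode(prompt), "prompt_token_ids": prompt}
```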
@@ -137,31 +145,47 @@ class OpenAIServingChat(OpenAIServing):
             sampling_params.logits_processors = []
             sampling_params.logits_processors.append(
                 guided_decode_logits_processor)
 
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+
+            self._log_inputs(request_id,
+                             prompt_inputs,
+                             params=sampling_params,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            engine_inputs: PromptInputs = {
+                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
+            }
+            if mm_data is not None:
+                engine_inputs["multi_modal_data"] = mm_data
+
+            is_tracing_enabled = await self.engine.is_tracing_enabled()
+            trace_headers = None
+            if is_tracing_enabled and raw_request:
+                trace_headers = extract_trace_headers(raw_request.headers)
+            if (not is_tracing_enabled and raw_request
+                    and contains_trace_headers(raw_request.headers)):
+                log_tracing_disabled_warning()
+
+            result_generator = self.engine.generate(
+                engine_inputs,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+            )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        inputs: PromptInputs = {
-            "prompt": prompt_text,
-            "prompt_token_ids": prompt_ids,
-        }
-        if mm_data:
-            inputs["multi_modal_data"] = mm_data
-
-        is_tracing_enabled = await self.engine.is_tracing_enabled()
-        trace_headers = None
-        if is_tracing_enabled and raw_request:
-            trace_headers = extract_trace_headers(raw_request.headers)
-        if not is_tracing_enabled and raw_request and contains_trace_headers(
-                raw_request.headers):
-            log_tracing_disabled_warning()
-
-        result_generator = self.engine.generate(
-            inputs,
-            sampling_params,
-            request_id,
-            lora_request,
-            trace_headers=trace_headers,
-        )
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
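Review note: besides moving input preparation inside the `try` block, two small behavioral details change here: `mm_data` is gated with `is not None` instead of truthiness, and `lora_request` is passed to `generate` by keyword. The first matters because an empty-but-present multi-modal mapping is falsy; a tiny self-contained illustration:

```python
mm_data: dict = {}  # hypothetical: multi-modal payload present but empty
engine_inputs = {"prompt_token_ids": [1, 2, 3]}  # placeholder token IDs

if mm_data:  # old truthiness check: silently skips the empty dict
    engine_inputs["old_style"] = mm_data
if mm_data is not None:  # new check: attaches it whenever it was provided
    engine_inputs["multi_modal_data"] = mm_data

assert "old_style" not in engine_inputs
assert "multi_modal_data" in engine_inputs
```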
@@ -195,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
 
         # Send response for each token for each request.n (index)
-        assert request.n is not None
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices
+        previous_num_tokens = [0] * num_choices
+        finish_reason_sent = [False] * num_choices
 
         try:
             async for res in result_generator:
                 # We need to do it here, because if there are exceptions in
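Review note: the hard `assert request.n is not None` is replaced by defaulting to a single choice, so a request that omits `n` no longer trips an assertion while the stream is being set up. The guard in isolation:

```python
def num_choices_for(n):
    # Mirrors the diff: a missing `n` now means one choice.
    return 1 if n is None else n

assert num_choices_for(None) == 1  # previously hit `assert n is not None`
assert num_choices_for(3) == 3
```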
@@ -208,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
                 # Send first response for each request.n (index) with
                 # the role
                 role = self.get_chat_request_role(request)
-                for i in range(request.n):
+                for i in range(num_choices):
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
                         delta=DeltaMessage(role=role),
@@ -236,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
                         last_msg_content = conversation[-1]["content"]
 
                     if last_msg_content:
-                        for i in range(request.n):
+                        for i in range(num_choices):
                             choice_data = (
                                 ChatCompletionResponseStreamChoice(
                                     index=i,
                                     delta=DeltaMessage(
                                         content=last_msg_content),
                                     logprobs=None,
                                     finish_reason=None))
                             chunk = ChatCompletionStreamResponse(
                                 id=request_id,
                                 object=chunk_object_type,
                                 created=created_time,
                                 choices=[choice_data],
                                 logprobs=None,
                                 model=model_name)
                             if (request.stream_options and
                                     request.stream_options.include_usage):