[Frontend] Refactor prompt processing (#4028)

Co-authored-by: Roger Wang <ywang@roblox.com>
Author: Cyrus Leung
Date: 2024-07-23 01:13:53 +08:00
Committed by: GitHub
Parent: 89c1c6a196
Commit: 739b61a348
24 changed files with 699 additions and 391 deletions

vllm/entrypoints/openai/serving_chat.py

@@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
     ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@@ -20,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                    OpenAIServing)
+                                                    OpenAIServing,
+                                                    PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
@@ -37,17 +39,24 @@ logger = init_logger(__name__)
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 model_config: ModelConfig,
-                 served_model_names: List[str],
-                 response_role: str,
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
         self.response_role = response_role
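
With the signature above, everything after the bare `*` must now be passed by keyword. A minimal wiring sketch (illustrative, not part of this commit), assuming an AsyncLLMEngine `engine` and its `model_config` have already been created elsewhere; the argument values are placeholders:

from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

# Hypothetical setup: `engine` is an AsyncLLMEngine and `model_config` its
# ModelConfig, both produced by initialization code not shown here.
openai_serving_chat = OpenAIServingChat(
    engine,
    model_config,
    served_model_names=["my-model"],
    response_role="assistant",
    # Keyword-only from here on:
    lora_modules=None,
    prompt_adapters=None,
    request_logger=None,
    chat_template=None,
)
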
@@ -74,7 +83,12 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
-            _, lora_request = self._maybe_get_adapter(request)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_config = self.model_config
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
             conversation: List[ConversationMessage] = []
@@ -82,7 +96,7 @@ class OpenAIServingChat(OpenAIServing):
 
             for msg in request.messages:
                 chat_parsed_result = parse_chat_message_content(
-                    msg, self.model_config, tokenizer)
+                    msg, model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -116,14 +130,8 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error in loading multi-modal data: %s", e)
             return self.create_error_response(str(e))
 
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"chat-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
-                request,
-                tokenizer,
-                prompt=prompt,
-                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
@@ -137,31 +145,47 @@ class OpenAIServingChat(OpenAIServing):
                     sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logits_processor)
+
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+
+            self._log_inputs(request_id,
+                             prompt_inputs,
+                             params=sampling_params,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            engine_inputs: PromptInputs = {
+                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
+            }
+            if mm_data is not None:
+                engine_inputs["multi_modal_data"] = mm_data
+
+            is_tracing_enabled = await self.engine.is_tracing_enabled()
+            trace_headers = None
+            if is_tracing_enabled and raw_request:
+                trace_headers = extract_trace_headers(raw_request.headers)
+            if (not is_tracing_enabled and raw_request
+                    and contains_trace_headers(raw_request.headers)):
+                log_tracing_disabled_warning()
+
+            result_generator = self.engine.generate(
+                engine_inputs,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+            )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        inputs: PromptInputs = {
-            "prompt": prompt_text,
-            "prompt_token_ids": prompt_ids,
-        }
-        if mm_data:
-            inputs["multi_modal_data"] = mm_data
-
-        is_tracing_enabled = await self.engine.is_tracing_enabled()
-        trace_headers = None
-        if is_tracing_enabled and raw_request:
-            trace_headers = extract_trace_headers(raw_request.headers)
-        if not is_tracing_enabled and raw_request and contains_trace_headers(
-                raw_request.headers):
-            log_tracing_disabled_warning()
-
-        result_generator = self.engine.generate(
-            inputs,
-            sampling_params,
-            request_id,
-            lora_request,
-            trace_headers=trace_headers,
-        )
 
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
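
The hunk above boils down to "tokenize once, log the inputs, then hand the engine a token-id prompt". A condensed, hypothetical sketch of that flow (not part of the diff), assuming an AsyncLLMEngine `engine`, a Hugging Face tokenizer, and a SamplingParams instance already exist; adapters, request logging, and tracing are omitted:

from vllm.inputs import PromptInputs
from vllm.utils import random_uuid


async def stream_chat_outputs(engine, tokenizer, prompt: str, sampling_params):
    # Tokenize up front so the engine only ever sees token ids, mirroring
    # _tokenize_prompt_input -> engine_inputs in the code above.
    prompt_token_ids = tokenizer(prompt).input_ids
    engine_inputs: PromptInputs = {"prompt_token_ids": prompt_token_ids}
    request_id = f"chat-{random_uuid()}"
    # AsyncLLMEngine.generate is an async generator of RequestOutput objects.
    async for output in engine.generate(engine_inputs, sampling_params,
                                        request_id):
        yield output
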
@@ -195,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
 
         # Send response for each token for each request.n (index)
-        assert request.n is not None
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices
+        previous_num_tokens = [0] * num_choices
+        finish_reason_sent = [False] * num_choices
+
         try:
             async for res in result_generator:
                 # We need to do it here, because if there are exceptions in
@@ -208,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
-                    for i in range(request.n):
+                    for i in range(num_choices):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(role=role),
@@ -236,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
                             last_msg_content = conversation[-1]["content"]
 
                         if last_msg_content:
-                            for i in range(request.n):
+                            for i in range(num_choices):
                                 choice_data = (
                                     ChatCompletionResponseStreamChoice(
                                         index=i,
                                         delta=DeltaMessage(
                                             content=last_msg_content),
                                         logprobs=None,
                                         finish_reason=None))
                                 chunk = ChatCompletionStreamResponse(
                                     id=request_id,
                                     object=chunk_object_type,
                                     created=created_time,
                                     choices=[choice_data],
                                     logprobs=None,
                                     model=model_name)
                                 if (request.stream_options and
                                         request.stream_options.include_usage):
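
The streaming hunks above share one pattern: derive a concrete choice count once instead of asserting on `request.n`, then size every per-choice buffer from it. A tiny standalone illustration with placeholder names (not part of the diff):

from typing import List, Optional


def init_choice_state(n: Optional[int]):
    # Treat a missing `n` as a single choice rather than asserting on it.
    num_choices = 1 if n is None else n
    previous_texts: List[str] = [""] * num_choices
    previous_num_tokens: List[int] = [0] * num_choices
    finish_reason_sent: List[bool] = [False] * num_choices
    return num_choices, previous_texts, previous_num_tokens, finish_reason_sent
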