[Frontend] Chat-based Embeddings API (#9759)

Author: Cyrus Leung
Date: 2024-11-01 16:13:35 +08:00
Committed by: GitHub
parent d3aa2a8b2f
commit 06386a64dd
21 changed files with 846 additions and 408 deletions
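
The headline change: the OpenAI-compatible /v1/embeddings endpoint can now take chat-style messages instead of only a plain prompt, so embedding models can be served through the same chat preprocessing path. A minimal request sketch, not taken from this diff: the server address, model name, and message content are placeholders, and the payload fields are assumed to mirror the Chat Completions API as the PR title suggests.

    # Sketch only: assumes a vLLM server on localhost:8000 serving an embedding
    # model that accepts chat-style input; "my-embedding-model" is a placeholder.
    import requests

    response = requests.post(
        "http://localhost:8000/v1/embeddings",
        json={
            "model": "my-embedding-model",
            # Chat-style "messages" in place of the usual "input" field.
            "messages": [
                {"role": "user", "content": "Represent this sentence for retrieval."},
            ],
            "encoding_format": "float",
        },
    )
    # OpenAI-style response shape: data[0].embedding is the vector.
    print(len(response.json()["data"][0]["embedding"]))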


@@ -2,10 +2,7 @@ from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
-                                         apply_mistral_chat_template,
-                                         load_chat_template,
-                                         parse_chat_messages_futures)
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -20,7 +17,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                      LoRAModulePath,
                                                      OpenAIServing)
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
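
In the hunk below, create_tokenize stops applying chat templates itself and instead unpacks the results of the shared _preprocess_chat / _preprocess_completion helpers, then just gathers the already-tokenized prompts. A rough, self-contained sketch of that contract; the type and function names here are illustrative stand-ins, not vLLM's actual definitions.

    # Illustrative sketch: engine prompts behave like dicts carrying
    # "prompt_token_ids", and the endpoint only needs to concatenate them.
    from typing import List, TypedDict

    class EnginePromptSketch(TypedDict):
        prompt_token_ids: List[int]

    def collect_token_ids(engine_prompts: List[EnginePromptSketch]) -> List[int]:
        input_ids: List[int] = []
        for engine_prompt in engine_prompts:
            input_ids.extend(engine_prompt["prompt_token_ids"])
        return input_ids

    assert collect_token_ids([
        {"prompt_token_ids": [1, 2]},
        {"prompt_token_ids": [3]},
    ]) == [1, 2, 3]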
@@ -62,59 +58,51 @@ class OpenAIServingTokenization(OpenAIServing):
         request_id = f"tokn-{random_uuid()}"
 
-        (
-            lora_request,
-            prompt_adapter_request,
-        ) = self._maybe_get_adapters(request)
+        try:
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
 
-        tokenizer = await self.engine_client.get_tokenizer(lora_request)
+            tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
-        prompt: Union[str, List[int]]
-        if isinstance(request, TokenizeChatRequest):
-            model_config = self.model_config
-            conversation, mm_data_future = parse_chat_messages_futures(
-                request.messages, model_config, tokenizer)
-            mm_data = await mm_data_future
-            if mm_data:
-                logger.warning(
-                    "Multi-modal inputs are ignored during tokenization")
-            if isinstance(tokenizer, MistralTokenizer):
-                prompt = apply_mistral_chat_template(
+            if isinstance(request, TokenizeChatRequest):
+                (
+                    _,
+                    request_prompts,
+                    engine_prompts,
+                ) = await self._preprocess_chat(
+                    request,
                     tokenizer,
-                    messages=request.messages,
+                    request.messages,
                     chat_template=self.chat_template,
                     add_generation_prompt=request.add_generation_prompt,
                     continue_final_message=request.continue_final_message,
+                    add_special_tokens=request.add_special_tokens,
                 )
             else:
-                prompt = apply_hf_chat_template(
+                request_prompts, engine_prompts = self._preprocess_completion(
+                    request,
                     tokenizer,
-                    conversation=conversation,
-                    chat_template=self.chat_template,
-                    add_generation_prompt=request.add_generation_prompt,
-                    continue_final_message=request.continue_final_message,
+                    request.prompt,
                     add_special_tokens=request.add_special_tokens,
                 )
-        else:
-            prompt = request.prompt
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
 
-        self._log_inputs(request_id,
-                         prompt,
-                         params=None,
-                         lora_request=lora_request,
-                         prompt_adapter_request=prompt_adapter_request)
+        input_ids: List[int] = []
+        for i, engine_prompt in enumerate(engine_prompts):
+            self._log_inputs(request_id,
+                             request_prompts[i],
+                             params=None,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
 
-        # Silently ignore prompt adapter since it does not affect tokenization
+            # Silently ignore prompt adapter since it does not affect
+            # tokenization (Unlike in Embeddings API where an error is raised)
 
-        prompt_input = self._tokenize_prompt_input(
-            request,
-            tokenizer,
-            prompt,
-            add_special_tokens=request.add_special_tokens,
-        )
-        input_ids = prompt_input["prompt_token_ids"]
+            input_ids.extend(engine_prompt["prompt_token_ids"])
 
         return TokenizeResponse(tokens=input_ids,
                                 count=len(input_ids),
                                 max_model_len=self.max_model_len)
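
With preprocessing unified, the tokenization endpoint accepts the same chat-style fields used by the Chat and Embeddings APIs. A minimal usage sketch, assuming a server on localhost:8000 and a placeholder model name; the response keys follow the TokenizeResponse fields visible in the hunk above.

    # Sketch only: assumed address and model name; "messages",
    # "add_generation_prompt" and "add_special_tokens" match the request
    # fields handled in the diff above.
    import requests

    resp = requests.post(
        "http://localhost:8000/tokenize",
        json={
            "model": "my-model",
            "messages": [{"role": "user", "content": "Hello!"}],
            "add_generation_prompt": True,
            "add_special_tokens": False,
        },
    )
    # Expected keys: "tokens", "count", "max_model_len".
    print(resp.json())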
@@ -143,9 +131,8 @@ class OpenAIServingTokenization(OpenAIServing):
                          lora_request=lora_request,
                          prompt_adapter_request=prompt_adapter_request)
 
-        if prompt_adapter_request is not None:
-            raise NotImplementedError("Prompt adapter is not supported "
-                                      "for tokenization")
+        # Silently ignore prompt adapter since it does not affect tokenization
+        # (Unlike in Embeddings API where an error is raised)
 
         prompt_input = self._tokenize_prompt_input(
             request,