[Frontend] Chat-based Embeddings API (#9759)

Cyrus Leung
2024-11-01 16:13:35 +08:00
committed by GitHub
parent d3aa2a8b2f
commit 06386a64dd
21 changed files with 846 additions and 408 deletions
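For orientation (not part of the diff below): after this change, the OpenAI-compatible /v1/embeddings endpoint can accept chat-format "messages" in addition to the existing "input" field. A minimal client-side sketch, assuming a vLLM server on localhost:8000 serving an embedding model; the model name and payload below are illustrative, not taken from this commit:

# Sketch only: POST a chat-style embeddings request to a running vLLM server.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "intfloat/e5-mistral-7b-instruct",  # hypothetical example model
        # New in this commit: the embeddings endpoint can take chat-format
        # "messages" (EmbeddingChatRequest) instead of a plain "input" string.
        "messages": [{"role": "user", "content": "A sentence to embed"}],
        "encoding_format": "float",
    },
)
response.raise_for_status()
print(response.json()["data"][0]["embedding"][:8])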


@@ -2,28 +2,38 @@ import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ConversationMessage,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
                                          parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              ModelCard, ModelList,
                                              ModelPermission,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TokenizeRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -31,8 +41,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of
logger = init_logger(__name__)
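The newly imported EmbeddingChatRequest and EmbeddingCompletionRequest are defined in vllm/entrypoints/openai/protocol.py, which this commit also touches but which is not shown in this hunk. A rough sketch of the shape those names imply (field names simplified and not copied from the actual definitions):

# Illustrative sketch only; see protocol.py in this commit for the real models.
from typing import List, Union

from pydantic import BaseModel


class EmbeddingCompletionRequest(BaseModel):
    # Pre-existing style: embed raw text or token IDs.
    model: str
    input: Union[str, List[str], List[int], List[List[int]]]
    encoding_format: str = "float"


class EmbeddingChatRequest(BaseModel):
    # New style added here: embed the rendered chat conversation.
    model: str
    messages: List[dict]  # simplified; really ChatCompletionMessageParam
    encoding_format: str = "float"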
@@ -56,8 +68,14 @@ class LoRAModulePath:
    base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
                   EmbeddingRequest, TokenizeRequest]
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest,
                              TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
class TextTokensPrompt(TypedDict):
@@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict):
    prompt_token_ids: List[int]
RequestPrompt = Union[List[int], str, TextTokensPrompt]
class OpenAIServing:
    def __init__(
@@ -246,7 +267,8 @@ class OpenAIServing:
        token_num = len(input_ids)
        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest)):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
@@ -373,10 +395,115 @@ class OpenAIServing:
            truncate_prompt_tokens=truncate_prompt_tokens,
        )

    def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
        request_prompts = [
            request_prompt
            for request_prompt in self._tokenize_prompt_input_or_inputs(
                request,
                tokenizer,
                input_or_inputs,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        ]

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
        )

        request_prompt: Union[str, List[int]]
        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
        if is_mistral_tokenizer:
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tool_dicts,
                documents=documents,
                **(chat_template_kwargs or {}),
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tool_dicts,
                documents=documents,
                **(chat_template_kwargs or {}),
            )

        mm_data = await mm_data_future

        if tool_parser is not None:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = self._tokenize_prompt_input(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data

        return conversation, [request_prompt], [engine_prompt]

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
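For context, a hedged sketch of how an embeddings handler built on these shared helpers (elsewhere in this commit) can dispatch between the two request styles. Names such as handler, tokenizer, and the request attributes are assumptions for illustration, not copied from the actual serving code:

# Sketch only (not from this diff): dispatching an embeddings request to the
# preprocessing helpers added above.
async def build_engine_prompts(handler, request, tokenizer):
    if isinstance(request, EmbeddingChatRequest):
        # Chat-style input: render the messages through the chat template.
        _, request_prompts, engine_prompts = await handler._preprocess_chat(
            request,
            tokenizer,
            request.messages,
            add_special_tokens=False,
        )
    else:
        # Classic input: plain text or token IDs, as before this commit.
        request_prompts, engine_prompts = handler._preprocess_completion(
            request,
            tokenizer,
            request.input,
            add_special_tokens=True,
        )
    return request_prompts, engine_prompts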
@@ -404,6 +531,20 @@ class OpenAIServing:
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,