[Frontend] Chat-based Embeddings API (#9759)

Author: Cyrus Leung
Date: 2024-11-01 16:13:35 +08:00
Committed by: GitHub
Parent: d3aa2a8b2f
Commit: 06386a64dd
21 changed files with 846 additions and 408 deletions
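Summary: this commit extends the OpenAI-compatible /v1/embeddings endpoint so that, besides the existing input-based request (EmbeddingCompletionRequest), clients can send chat-style messages (EmbeddingChatRequest), which the server renders through the model's chat template before pooling. It also makes the per-API serving handlers optional, gated on model_config.task. The sketch below shows both client request shapes; the server URL and model name are placeholders, not part of this commit.

# Sketch: calling /v1/embeddings with the old and new request shapes.
# Assumes a vLLM OpenAI-compatible server running an embedding-task model
# at http://localhost:8000; "my-embedding-model" is a placeholder.
import requests

BASE_URL = "http://localhost:8000"

# Existing input-based request (unchanged by this commit).
resp = requests.post(
    f"{BASE_URL}/v1/embeddings",
    json={
        "model": "my-embedding-model",
        "input": "The quick brown fox",
    },
)
print(len(resp.json()["data"][0]["embedding"]))

# New chat-based request: `messages` is rendered through the model's
# chat template before the pooled embedding is computed.
resp = requests.post(
    f"{BASE_URL}/v1/embeddings",
    json={
        "model": "my-embedding-model",
        "messages": [
            {"role": "user", "content": "Represent this sentence for retrieval."},
        ],
    },
)
print(len(resp.json()["data"][0]["embedding"]))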

View File

@@ -11,7 +11,7 @@ from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Set
from typing import AsyncIterator, Optional, Set
import uvloop
from fastapi import APIRouter, FastAPI, Request
@@ -51,7 +51,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -248,22 +248,27 @@ def mount_metrics(app: FastAPI):
app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
def base(request: Request) -> OpenAIServing:
# Reuse the existing instance
return tokenization(request)
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
def completion(request: Request) -> Optional[OpenAIServingCompletion]:
return request.app.state.openai_serving_completion
def embedding(request: Request) -> Optional[OpenAIServingEmbedding]:
return request.app.state.openai_serving_embedding
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@@ -277,7 +282,9 @@ async def health(raw_request: Request) -> Response:
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
handler = tokenization(raw_request)
generator = await handler.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -289,7 +296,9 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
handler = tokenization(raw_request)
generator = await handler.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -301,7 +310,9 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
handler = base(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
@@ -314,9 +325,12 @@ async def show_version():
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
handler = chat(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Chat Completions API")
generator = await chat(raw_request).create_chat_completion(
request, raw_request)
generator = await handler.create_chat_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
@@ -330,8 +344,12 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await completion(raw_request).create_completion(
request, raw_request)
handler = completion(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -343,8 +361,12 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await embedding(raw_request).create_embedding(
request, raw_request)
handler = embedding(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@@ -382,30 +404,26 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@@ -501,7 +519,8 @@ def init_app_state(
chat_template=args.chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
tool_parser=args.tool_call_parser,
) if model_config.task == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
@@ -510,13 +529,14 @@ def init_app_state(
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
) if model_config.task == "generate" else None
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
request_logger=request_logger,
)
chat_template=args.chat_template,
) if model_config.task == "embedding" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
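Note: the handler accessors above now return Optional values, and init_app_state only constructs each OpenAIServing* instance when model_config.task matches ("generate" for chat and completions, "embedding" for embeddings). The following self-contained sketch isolates that dispatch pattern with stand-in classes; it is illustrative, not the real server code.

# Illustrative sketch of the optional-handler dispatch: an accessor may
# return None, and the route falls back to the always-available base
# handler for the error response. Stand-in classes, not vLLM code.
import asyncio
from typing import Optional


class DummyServing:
    def create_error_response(self, message: str) -> dict:
        return {"object": "error", "message": message}


class DummyServingChat(DummyServing):
    async def create_chat_completion(self, request: dict) -> dict:
        return {"object": "chat.completion", "echo": request}


class AppState:
    # None when model_config.task != "generate", mirroring init_app_state.
    openai_serving_chat: Optional[DummyServingChat] = None
    openai_serving_tokenization: DummyServing = DummyServing()


async def create_chat_completion(state: AppState, request: dict) -> dict:
    handler = state.openai_serving_chat
    if handler is None:
        return state.openai_serving_tokenization.create_error_response(
            message="The model does not support Chat Completions API")
    return await handler.create_chat_completion(request)


print(asyncio.run(create_chat_completion(AppState(), {"messages": []})))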

View File

@@ -708,7 +708,7 @@ class CompletionRequest(OpenAIBaseModel):
return data
class EmbeddingRequest(OpenAIBaseModel):
class EmbeddingCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
@@ -720,10 +720,15 @@ class EmbeddingRequest(OpenAIBaseModel):
# doc: begin-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
priority: int = Field(
default=0,
description=(
@@ -737,6 +742,82 @@ class EmbeddingRequest(OpenAIBaseModel):
return PoolingParams(additional_data=self.additional_data)
class EmbeddingChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
encoding_format: Literal["float", "base64"] = "float"
dimensions: Optional[int] = None
user: Optional[str] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: begin-chat-embedding-pooling-params
additional_data: Optional[Any] = None
# doc: end-chat-embedding-pooling-params
# doc: begin-chat-embedding-extra-params
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."))
# doc: end-chat-embedding-extra-params
@model_validator(mode="before")
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get(
"add_generation_prompt"):
raise ValueError("Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True.")
return data
def to_pooling_params(self):
return PoolingParams(additional_data=self.additional_data)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
@@ -799,7 +880,7 @@ class EmbeddingResponseData(OpenAIBaseModel):
class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
id: str = Field(default_factory=lambda: f"embd-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
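Note: EmbeddingRequest is now a Union of the two shapes defined above, and EmbeddingChatRequest's check_generation_prompt validator rejects continue_final_message combined with add_generation_prompt. Below is a hedged sketch of constructing both shapes, assuming the classes are importable from vllm.entrypoints.openai.protocol as the diff shows; the model name is a placeholder.

# Sketch: the two request shapes accepted by /v1/embeddings after this
# commit.
from pydantic import ValidationError

from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                              EmbeddingCompletionRequest)

# Input-based shape (pre-existing behaviour).
completion_req = EmbeddingCompletionRequest(
    model="my-embedding-model",
    input="The quick brown fox",
)

# Chat-based shape (new): rendered via the chat template server-side.
chat_req = EmbeddingChatRequest(
    model="my-embedding-model",
    messages=[{"role": "user", "content": "Represent this for retrieval."}],
)

# check_generation_prompt rejects the contradictory combination.
try:
    EmbeddingChatRequest(
        model="my-embedding-model",
        messages=[{"role": "user", "content": "hi"}],
        add_generation_prompt=True,
        continue_final_message=True,
    )
except ValidationError as exc:
    print(exc)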

View File

@@ -217,13 +217,14 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
)
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
request_logger=request_logger,
)
chat_template=None,
) if model_config.task == "embedding" else None
tracker = BatchProgressTracker()
logger.info("Reading batch from %s...", args.input_file)
@@ -240,14 +241,31 @@ async def main(args):
# Determine the type of request and run it.
if request.url == "/v1/chat/completions":
response_futures.append(
run_request(openai_serving_chat.create_chat_completion,
request, tracker))
handler_fn = (None if openai_serving_chat is None else
openai_serving_chat.create_chat_completion)
if handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg=
"The model does not support Chat Completions API",
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
elif request.url == "/v1/embeddings":
response_futures.append(
run_request(openai_serving_embedding.create_embedding, request,
tracker))
handler_fn = (None if openai_serving_embedding is None else
openai_serving_embedding.create_embedding)
if handler_fn is None:
response_futures.append(
make_async_error_request_output(
request,
error_msg="The model does not support Embeddings API",
))
continue
response_futures.append(run_request(handler_fn, request, tracker))
tracker.submitted()
else:
response_futures.append(
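Note: run_batch now emits a per-line error result instead of failing when the loaded model cannot serve the requested API. The sketch below writes one chat-style embedding request for the /v1/embeddings branch handled above; the custom_id/method/url/body field names follow the OpenAI-style batch format this entrypoint mirrors and are an assumption here, as is the model name.

# Sketch: writing a batch input file with one chat-style embedding request
# per line.
import json

batch_line = {
    "custom_id": "embed-1",             # assumed OpenAI-style batch field
    "method": "POST",
    "url": "/v1/embeddings",
    "body": {
        "model": "my-embedding-model",  # placeholder model name
        "messages": [
            {"role": "user", "content": "Represent this for retrieval."},
        ],
    },
}

with open("batch_input.jsonl", "w") as f:
    f.write(json.dumps(batch_line) + "\n")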

View File

@@ -10,11 +10,7 @@ from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -27,16 +23,12 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation
@@ -94,12 +86,12 @@ class OpenAIServingChat(OpenAIServing):
raw_request: Optional[Request] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""Completion API similar to OpenAI's API.
"""
Chat Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI
ChatCompletion API.
Chat Completion API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
@@ -118,143 +110,106 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)
and not isinstance(tokenizer, MistralTokenizer)):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
tool_dicts = None if request.tools is None else [
tool.model_dump() for tool in request.tools
]
prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.exception("Error in applying chat template from request")
(
conversation,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
try:
mm_data = await mm_data_future
except Exception as e:
logger.exception("Error in loading multi-modal data")
return self.create_error_response(str(e))
# validation for OpenAI tools
# tool_choice = "required" is not supported
if request.tool_choice == "required":
return self.create_error_response(
"tool_choice = \"required\" is not supported!")
if not is_mistral_tokenizer and request.tool_choice == "auto" and not (
self.enable_auto_tools and self.tool_parser is not None):
# for hf tokenizers, "auto" tools requires
# --enable-auto-tool-choice and --tool-call-parser
return self.create_error_response(
"\"auto\" tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set")
request_id = f"chat-{request.request_id}"
request_id = f"chatcmpl-{request.request_id}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
engine_prompt["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
assert isinstance(prompt, list) and isinstance(
prompt[0], int
), "Prompt has to be either a string or a list of token ids"
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(prompt), prompt_token_ids=prompt)
self._log_inputs(request_id,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
assert prompt_inputs is not None
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
else:
sampling_params = request.to_sampling_params(
default_max_tokens)
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
generator = self.engine_client.generate(
engine_prompt,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if (not is_tracing_enabled and raw_request
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
if isinstance(sampling_params, BeamSearchParams):
result_generator = self.engine_client.beam_search(
prompt=engine_inputs,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
result_generator = self.engine_client.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert len(generators) == 1
result_generator, = generators
if raw_request:
result_generator = iterate_with_cancellation(
result_generator, raw_request.is_disconnected)
@@ -626,6 +581,9 @@ class OpenAIServingChat(OpenAIServing):
final_res = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
assert final_res is not None

View File

@@ -1,7 +1,6 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Tuple, Union, cast
@@ -30,18 +29,11 @@ from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
class OpenAIServingCompletion(OpenAIServing):
@@ -101,8 +93,6 @@ class OpenAIServingCompletion(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
(
lora_request,
@@ -111,19 +101,24 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
))
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
for i, prompt_inputs in enumerate(prompts):
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
sampling_params: Union[SamplingParams, BeamSearchParams]
default_max_tokens = self.max_model_len - len(
prompt_inputs["prompt_token_ids"])
engine_prompt["prompt_token_ids"])
if request.use_beam_search:
sampling_params = request.to_beam_search_params(
default_max_tokens)
@@ -134,36 +129,24 @@ class OpenAIServingCompletion(OpenAIServing):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
request_prompts[i],
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = (await
self.engine_client.is_tracing_enabled())
trace_headers = None
if is_tracing_enabled:
trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and contains_trace_headers(
raw_request.headers):
log_tracing_disabled_warning()
trace_headers = (await
self._get_trace_headers(raw_request.headers))
if isinstance(sampling_params, BeamSearchParams):
generator = self.engine_client.beam_search(
prompt={
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
prompt=engine_prompt,
model_config=self.model_config,
request_id=request_id,
params=sampling_params,
)
else:
generator = self.engine_client.generate(
{
"prompt_token_ids":
prompt_inputs["prompt_token_ids"]
},
engine_prompt,
sampling_params,
request_id_item,
lora_request=lora_request,
@@ -180,6 +163,8 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(
*generators, is_cancelled=raw_request.is_disconnected)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
@@ -195,16 +180,22 @@ class OpenAIServingCompletion(OpenAIServing):
request_id,
created_time,
model_name,
num_prompts=len(prompts),
num_prompts=num_prompts,
tokenizer=tokenizer,
request_metadata=request_metadata)
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
try:
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
@@ -212,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = prompts[i]["prompt"]
final_res.prompt = request_prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
@@ -226,8 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer,
request_metadata,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

View File

@@ -9,8 +9,10 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
@@ -21,8 +23,6 @@ from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
TypeTokenIDs = List[int]
def _get_embedding(
output: EmbeddingOutput,
@@ -76,6 +76,7 @@ class OpenAIServingEmbedding(OpenAIServing):
base_model_paths: List[BaseModelPath],
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine_client=engine_client,
model_config=model_config,
@@ -83,21 +84,20 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._enabled = self._check_embedding_mode(
model_config.task == "embedding")
self.chat_template = load_chat_template(chat_template)
async def create_embedding(
self,
request: EmbeddingRequest,
raw_request: Optional[Request] = None,
) -> Union[EmbeddingResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
"""
Embedding API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
"""
if not self._enabled:
return self.create_error_response("Embedding API disabled")
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
@@ -122,8 +122,6 @@ class OpenAIServingEmbedding(OpenAIServing):
"greater than max_model_len."
" Please, select a smaller truncation size.")
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
(
lora_request,
@@ -132,32 +130,60 @@ class OpenAIServingEmbedding(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for embedding models")
if isinstance(request, EmbeddingChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
prompts = list(
self._tokenize_prompt_input_or_inputs(request, tokenizer,
request.input,
truncate_prompt_tokens))
for i, prompt_inputs in enumerate(prompts):
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
request_prompts[i],
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported "
"for embedding models")
trace_headers = (None if raw_request is None else await
self._get_trace_headers(raw_request.headers))
generator = self.engine_client.encode(
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
engine_prompt,
pooling_params,
request_id_item,
lora_request=lora_request,
trace_headers=trace_headers,
priority=request.priority,
)
@@ -171,13 +197,18 @@ class OpenAIServingEmbedding(OpenAIServing):
is_cancelled=raw_request.is_disconnected if raw_request else None,
)
num_prompts = len(engine_prompts)
# Non-streaming response
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
final_res_batch = [None] * num_prompts
try:
async for i, res in result_generator:
final_res_batch[i] = res
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
try:
for final_res in final_res_batch:
assert final_res is not None
@@ -187,18 +218,8 @@ class OpenAIServingEmbedding(OpenAIServing):
response = request_output_to_embedding_response(
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
if not embedding_mode:
logger.warning(
"embedding_mode is False. Embedding API will not work.")
else:
logger.info("Activating the server engine with embedding enabled.")
return embedding_mode

View File

@@ -2,28 +2,38 @@ import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -31,8 +41,10 @@ from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of
logger = init_logger(__name__)
@@ -56,8 +68,14 @@ class LoRAModulePath:
base_model_name: Optional[str] = None
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest,
TokenizeCompletionRequest]
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
class TextTokensPrompt(TypedDict):
@@ -65,6 +83,9 @@ class TextTokensPrompt(TypedDict):
prompt_token_ids: List[int]
RequestPrompt = Union[List[int], str, TextTokensPrompt]
class OpenAIServing:
def __init__(
@@ -246,7 +267,8 @@ class OpenAIServing:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
if isinstance(request, EmbeddingRequest):
if isinstance(request,
(EmbeddingChatRequest, EmbeddingCompletionRequest)):
if token_num > self.max_model_len:
raise ValueError(
f"This model's maximum context length is "
@@ -373,10 +395,115 @@ class OpenAIServing:
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
request_prompts = [
request_prompt
for request_prompt in self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
]
engine_prompts = [
TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
for request_prompt in request_prompts
]
return request_prompts, engine_prompts
async def _preprocess_chat(
self,
request: ChatLikeRequest,
tokenizer: AnyTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tool_dicts: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[Dict[str, str]]] = None,
chat_template_kwargs: Optional[Dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = False,
) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
List[TokensPrompt]]:
conversation, mm_data_future = parse_chat_messages_futures(
messages,
self.model_config,
tokenizer,
)
request_prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
request_prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
**(chat_template_kwargs or {}),
)
else:
request_prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tool_dicts,
documents=documents,
**(chat_template_kwargs or {}),
)
mm_data = await mm_data_future
if tool_parser is not None:
if not isinstance(request, ChatCompletionRequest):
msg = "Tool usage is only supported for Chat Completions API"
raise NotImplementedError(msg)
request = tool_parser(tokenizer).adjust_request(request=request)
if isinstance(request_prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
request_prompt,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
# For MistralTokenizer
assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids")
prompt_inputs = TextTokensPrompt(
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"])
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
return conversation, [request_prompt], [engine_prompt]
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
inputs: RequestPrompt,
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
@@ -404,6 +531,20 @@ class OpenAIServing:
prompt_adapter_request=prompt_adapter_request,
)
async def _get_trace_headers(
self,
headers: Headers,
) -> Optional[Mapping[str, str]]:
is_tracing_enabled = await self.engine_client.is_tracing_enabled()
if is_tracing_enabled:
return extract_trace_headers(headers)
if contains_trace_headers(headers):
log_tracing_disabled_warning()
return None
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
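Note: _preprocess_completion and _preprocess_chat centralize the tokenization and chat-template handling that serving_chat, serving_completion, serving_embedding, and serving_tokenization previously duplicated: each returns the human-readable request_prompts used for logging alongside the TokensPrompt engine_prompts handed to the engine. The sketch below illustrates the calling convention on a hypothetical, never-instantiated subclass; it is heavily simplified and is not the real handler code.

# Illustrative calling convention for the new helpers (keyword arguments
# match the diff; the handler body is simplified for illustration only).
from vllm.entrypoints.openai.serving_engine import OpenAIServing


class ChatEmbeddingSketch(OpenAIServing):
    # Hypothetical subclass used only to show the call shape; never
    # instantiated here.

    async def _embed_from_chat(self, request, tokenizer, request_id,
                               lora_request, prompt_adapter_request):
        # Render the chat template, tokenize, and gather any multi-modal
        # data in one place.
        conversation, request_prompts, engine_prompts = (
            await self._preprocess_chat(
                request,
                tokenizer,
                request.messages,
                chat_template=request.chat_template or self.chat_template,
                add_generation_prompt=request.add_generation_prompt,
                continue_final_message=request.continue_final_message,
                add_special_tokens=request.add_special_tokens,
            ))

        pooling_params = request.to_pooling_params()
        generators = []
        for i, engine_prompt in enumerate(engine_prompts):
            # Log the readable prompt; feed token IDs (plus multi-modal
            # data attached by _preprocess_chat) to the engine.
            self._log_inputs(f"{request_id}-{i}",
                             request_prompts[i],
                             params=pooling_params,
                             lora_request=lora_request,
                             prompt_adapter_request=prompt_adapter_request)
            generators.append(
                self.engine_client.encode(engine_prompt, pooling_params,
                                          f"{request_id}-{i}",
                                          lora_request=lora_request))
        return generators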

View File

@@ -2,10 +2,7 @@ from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -20,7 +17,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@@ -62,59 +58,51 @@ class OpenAIServingTokenization(OpenAIServing):
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation, mm_data_future = parse_chat_messages_futures(
request.messages, model_config, tokenizer)
mm_data = await mm_data_future
if mm_data:
logger.warning(
"Multi-modal inputs are ignored during tokenization")
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
if isinstance(request, TokenizeChatRequest):
(
_,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request,
tokenizer,
messages=request.messages,
request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
add_special_tokens=request.add_special_tokens,
)
else:
prompt = apply_hf_chat_template(
request_prompts, engine_prompts = self._preprocess_completion(
request,
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
request.prompt,
add_special_tokens=request.add_special_tokens,
)
else:
prompt = request.prompt
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
input_ids: List[int] = []
for i, engine_prompt in enumerate(engine_prompts):
self._log_inputs(request_id,
request_prompts[i],
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
# Silently ignore prompt adapter since it does not affect
# tokenization (Unlike in Embeddings API where an error is raised)
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
input_ids.extend(engine_prompt["prompt_token_ids"])
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
@@ -143,9 +131,8 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
# Silently ignore prompt adapter since it does not affect tokenization
# (Unlike in Embeddings API where an error is raised)
prompt_input = self._tokenize_prompt_input(
request,