vllm/vllm/entrypoints/openai/serving_engine.py
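"""Shared logic for the OpenAI-compatible API endpoints.

Defines `OpenAIServing`, the common base class for the OpenAI-compatible
request handlers (chat, completions, embeddings, tokenization): it registers
the served model names and any LoRA / prompt-adapter modules, and provides
request validation, prompt tokenization and error-response helpers.
"""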


import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import Field
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              ModelCard, ModelList,
                                              ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import Logprob

logger = init_logger(__name__)


@dataclass
class PromptAdapterPath:
    name: str
    local_path: str


@dataclass
class LoRAModulePath:
    name: str
    local_path: str


class OpenAIServing:
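    """Shared functionality for the OpenAI-compatible request handlers.

    Keeps track of the served model names and any registered LoRA /
    prompt-adapter modules, and provides helpers for model validation,
    prompt tokenization and length checks, and error responses.
    """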
    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        lora_modules: Optional[List[LoRAModulePath]],
        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
    ):
        super().__init__()

        self.engine = engine
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len
        self.served_model_names = served_model_names

        self.lora_requests = []
        if lora_modules is not None:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_local_path=lora.local_path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.prompt_adapter_requests = []
        if prompt_adapters is not None:
            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
                with pathlib.Path(prompt_adapter.local_path,
                                  "adapter_config.json").open() as f:
                    adapter_config = json.load(f)
                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
                self.prompt_adapter_requests.append(
                    PromptAdapterRequest(
                        prompt_adapter_name=prompt_adapter.name,
                        prompt_adapter_id=i,
                        prompt_adapter_local_path=prompt_adapter.local_path,
                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))

    async def show_available_models(self) -> ModelList:
        """Show the available models: the served model names plus any
        registered LoRA adapters and prompt adapters."""
        model_cards = [
            ModelCard(id=served_model_name,
                      max_model_len=self.max_model_len,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for served_model_name in self.served_model_names
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        prompt_adapter_cards = [
            ModelCard(id=prompt_adapter.prompt_adapter_name,
                      root=self.served_model_names[0],
                      permission=[ModelPermission()])
            for prompt_adapter in self.prompt_adapter_requests
        ]
        model_cards.extend(lora_cards)
        model_cards.extend(prompt_adapter_cards)
        return ModelList(data=model_cards)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
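        """Build an ErrorResponse with the given message, error type and
        HTTP status code."""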
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
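        """Like create_error_response, but wraps the error in an
        {"error": ...} object and serializes it to a JSON string, as emitted
        on streaming responses."""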
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self, request: Union[ChatCompletionRequest, CompletionRequest,
                             DetokenizeRequest, EmbeddingRequest,
                             TokenizeRequest]
    ) -> Optional[ErrorResponse]:
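        """Return None if `request.model` names a served model, a registered
        LoRA adapter or a prompt adapter; otherwise return a NotFoundError
        error response."""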
        if request.model in self.served_model_names:
            return None
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.prompt_adapter_requests
        ]:
            return None

        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapter(
        self, request: Union[CompletionRequest, ChatCompletionRequest,
                             EmbeddingRequest, TokenizeRequest,
                             DetokenizeRequest]
    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
                                             PromptAdapterRequest]]]:
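        """Resolve `request.model` to an adapter request.

        Returns (None, None) for the base model, ('LoRA', LoRARequest) or
        ('PromptAdapter', PromptAdapterRequest) for a registered adapter,
        and raises ValueError for an unknown model name.
        """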
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return 'LoRA', lora
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return 'PromptAdapter', prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    async def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest,
                           DetokenizeRequest, EmbeddingRequest,
                           TokenizeRequest],
            tokenizer: "PreTrainedTokenizer",
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int,
                                                       Field(ge=1)]] = None,
            add_special_tokens: Optional[bool] = True
    ) -> Tuple[List[int], str]:
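        """Tokenize `prompt` (or accept pre-tokenized `prompt_ids`), apply
        optional truncation, and validate the length against the model's
        context window.

        For completion-style requests, a missing `max_tokens` is filled in
        with the remaining context budget. Returns (input_ids, input_text).
        """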
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if prompt and prompt_ids:
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        if prompt_ids is None:
            # When using OpenAIServingChat for chat completions, for
            # most models the special tokens (e.g., BOS) have already
            # been added by the chat template. Therefore, we do not
            # need to add them again.
            # Set add_special_tokens to False (by default) to avoid
            # adding the BOS tokens again.
            tokenizer_kwargs: Dict[str, Any] = {
                "add_special_tokens": add_special_tokens
            }
            if truncate_prompt_tokens is not None:
                tokenizer_kwargs.update({
                    "truncation": True,
                    "max_length": truncate_prompt_tokens,
                })
            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
        elif truncate_prompt_tokens is not None:
            input_ids = prompt_ids[-truncate_prompt_tokens:]
        else:
            input_ids = prompt_ids

        input_text = prompt if prompt is not None else tokenizer.decode(
            input_ids)
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, EmbeddingRequest):
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.")
            return input_ids, input_text

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and don't require model context length validation
        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            return input_ids, input_text

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")
        else:
            return input_ids, input_text

    @staticmethod
    def _get_decoded_token(logprob: Logprob, token_id: int,
                           tokenizer: PreTrainedTokenizer) -> str:
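        """Return the decoded text for a token, decoding from `token_id` only
        when the Logprob does not already carry the decoded string."""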
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)
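

# Illustrative sketch (not part of the module): a concrete endpoint built on
# OpenAIServing would typically check the requested model, resolve any
# adapter, and tokenize the prompt before handing off to the engine, roughly
# as follows. The `handle_request` name and wiring below are hypothetical.
#
#     class OpenAIServingExample(OpenAIServing):
#
#         async def handle_request(self, request: CompletionRequest,
#                                  tokenizer: PreTrainedTokenizer):
#             error_check_ret = await self._check_model(request)
#             if error_check_ret is not None:
#                 return error_check_ret
#             adapter_type, adapter_request = self._maybe_get_adapter(request)
#             input_ids, input_text = await self._validate_prompt_and_tokenize(
#                 request, tokenizer, prompt=request.prompt)
#             # ... pass input_ids and adapter_request to the engine ...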