[Frontend] Refactor prompt processing (#4028)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-07-23 01:13:53 +08:00
committed by GitHub
parent 89c1c6a196
commit 739b61a348
24 changed files with 699 additions and 391 deletions

View File

@@ -2,23 +2,33 @@ import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission, TokenizeRequest)
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
logger = init_logger(__name__)
@@ -36,6 +46,17 @@ class LoRAModulePath:
local_path: str
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: List[int]
class OpenAIServing:
def __init__(
@@ -43,8 +64,10 @@ class OpenAIServing:
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__()
@@ -78,6 +101,8 @@ class OpenAIServing:
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
@@ -126,9 +151,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest]
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
@@ -144,64 +168,65 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest, TokenizeRequest,
DetokenizeRequest]
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
def _maybe_get_adapters(
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if request.model in self.served_model_names:
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return 'LoRA', lora
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
tokenizer: "PreTrainedTokenizer",
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None,
add_special_tokens: Optional[bool] = True
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
# When using OpenAIServingChat for chat completions, for
# most models the special tokens (e.g., BOS) have already
# been added by the chat template. Therefore, we do not
# need to add them again.
# Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
}
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt: str,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
else:
input_ids = prompt_ids
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_text = prompt if prompt is not None else tokenizer.decode(
input_ids)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
@@ -211,13 +236,16 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
f"generation. Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
if request.max_tokens is None:
if token_num >= self.max_model_len:
@@ -225,7 +253,7 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", )
f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
@@ -235,13 +263,132 @@ class OpenAIServing:
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _tokenize_prompt_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, List[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
def _tokenize_prompt_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
else:
return input_ids, input_text
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
@staticmethod
def _get_decoded_token(logprob: Logprob, token_id: int,
tokenizer: PreTrainedTokenizer) -> str:
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)