[Frontend] don't block event loop in tokenization (preprocess) in OpenAI compatible server (#10635)

Signed-off-by: Tomer Asida <tomera@ai21.com>
Author: tomeras91
Date: 2024-11-27 23:21:10 +02:00
Committed by: GitHub
Parent: 9b4b150395
Commit: 395b1c7454
7 changed files with 206 additions and 56 deletions
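
This change moves the CPU-bound tokenization step of the OpenAI-compatible server's request preprocessing onto a dedicated single-worker ThreadPoolExecutor, so the asyncio event loop that serves HTTP traffic is not blocked while long prompts are tokenized. A minimal sketch of the pattern follows; blocking_tokenize and handle_request are illustrative names, not vLLM APIs.

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial

    # Illustrative stand-in for a CPU-bound tokenizer call (not vLLM code).
    def blocking_tokenize(text: str) -> list:
        return [ord(c) for c in text]

    _tokenizer_executor = ThreadPoolExecutor(max_workers=1)

    async def handle_request(prompt: str) -> list:
        loop = asyncio.get_running_loop()
        # The heavy call runs on the worker thread; the event loop stays free
        # to accept and stream other requests while we await the result.
        return await loop.run_in_executor(
            _tokenizer_executor, partial(blocking_tokenize, prompt))

    if __name__ == "__main__":
        print(asyncio.run(handle_request("hello")))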

View File

@@ -101,7 +101,7 @@ class OpenAIServingCompletion(OpenAIServing):
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
-            request_prompts, engine_prompts = self._preprocess_completion(
+            request_prompts, engine_prompts = await self._preprocess_completion(
                 request,
                 tokenizer,
                 request.prompt,

View File

@@ -156,13 +156,14 @@ class OpenAIServingEmbedding(OpenAIServing):
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
-                    request,
-                    tokenizer,
-                    request.input,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                )
+                (request_prompts,
+                 engine_prompts) = await self._preprocess_completion(
+                     request,
+                     tokenizer,
+                     request.input,
+                     truncate_prompt_tokens=truncate_prompt_tokens,
+                     add_special_tokens=request.add_special_tokens,
+                 )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

View File

@@ -1,5 +1,6 @@
 import json
 import pathlib
+from concurrent.futures.thread import ThreadPoolExecutor
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
@@ -46,7 +47,7 @@ from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import AtomicCounter, is_list_of
+from vllm.utils import AtomicCounter, is_list_of, make_async
 
 logger = init_logger(__name__)
@@ -140,6 +141,14 @@ class OpenAIServing:
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
+        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
+
+        self._tokenize_prompt_input_async = make_async(
+            self._tokenize_prompt_input, executor=self._tokenizer_executor)
+        self._tokenize_prompt_input_or_inputs_async = make_async(
+            self._tokenize_prompt_input_or_inputs,
+            executor=self._tokenizer_executor)
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
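
The _tokenize_prompt_input* helpers themselves stay synchronous; make_async from vllm.utils wraps them so that a call returns an awaitable which runs the work on the dedicated executor. A rough sketch of such a wrapper, assuming it amounts to functools.partial plus loop.run_in_executor (the real vllm.utils.make_async may differ in details):

    import asyncio
    from concurrent.futures import Executor
    from functools import partial
    from typing import Awaitable, Callable, Optional, TypeVar

    T = TypeVar("T")

    def make_async_sketch(
            func: Callable[..., T],
            executor: Optional[Executor] = None) -> Callable[..., Awaitable[T]]:
        """Sketch of a make_async-style wrapper: calling the returned function
        schedules the blocking call on the executor and returns an awaitable."""

        def _async_wrapper(*args, **kwargs) -> Awaitable[T]:
            loop = asyncio.get_event_loop()
            return loop.run_in_executor(executor, partial(func, *args, **kwargs))

        return _async_wrapper

With max_workers=1, all tokenization is serialized on a single background thread, so concurrent requests do not oversubscribe the CPU while the event loop stays free.
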
@@ -368,7 +377,7 @@ class OpenAIServing:
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Iterator[TextTokensPrompt]:
+    ) -> List[TextTokensPrompt]:
         """
         Tokenize/detokenize depending on the input format.
@@ -376,45 +385,41 @@ class OpenAIServing:
         , each input can be a string or array of tokens. Note that each request
         can pass one or more inputs.
         """
-        for prompt_input in parse_and_batch_prompt(input_or_inputs):
-            # Although our type checking is based on mypy,
-            # VSCode Pyright extension should still work properly
-            # "is True" is required for Pyright to perform type narrowing
-            # See: https://github.com/microsoft/pyright/issues/7672
-            if prompt_input["is_tokens"] is False:
-                yield self._normalize_prompt_text_to_input(
-                    request,
-                    tokenizer,
-                    prompt=prompt_input["content"],
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=add_special_tokens,
-                )
-            else:
-                yield self._normalize_prompt_tokens_to_input(
-                    request,
-                    tokenizer,
-                    prompt_ids=prompt_input["content"],
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                )
+        # Although our type checking is based on mypy,
+        # VSCode Pyright extension should still work properly
+        # "is True" is required for Pyright to perform type narrowing
+        # See: https://github.com/microsoft/pyright/issues/7672
+        return [
+            self._normalize_prompt_text_to_input(
+                request,
+                tokenizer,
+                prompt=prompt_input["content"],
+                truncate_prompt_tokens=truncate_prompt_tokens,
+                add_special_tokens=add_special_tokens)
+            if prompt_input["is_tokens"] is False else
+            self._normalize_prompt_tokens_to_input(
+                request,
+                tokenizer,
+                prompt_ids=prompt_input["content"],
+                truncate_prompt_tokens=truncate_prompt_tokens)
+            for prompt_input in parse_and_batch_prompt(input_or_inputs)
+        ]
 
-    def _preprocess_completion(
+    async def _preprocess_completion(
         self,
         request: CompletionLikeRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
-    ) -> Tuple[Sequence[TextTokensPrompt], List[TokensPrompt]]:
-        request_prompts = [
-            request_prompt
-            for request_prompt in self._tokenize_prompt_input_or_inputs(
-                request,
-                tokenizer,
-                input_or_inputs,
-                truncate_prompt_tokens=truncate_prompt_tokens,
-                add_special_tokens=add_special_tokens,
-            )
-        ]
+    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
+        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
+            request,
+            tokenizer,
+            input_or_inputs,
+            truncate_prompt_tokens=truncate_prompt_tokens,
+            add_special_tokens=add_special_tokens,
+        )
 
         engine_prompts = [
             TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
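
_tokenize_prompt_input_or_inputs also changes from a generator (Iterator[TextTokensPrompt] built with yield) to returning a List: a generator's body only runs when it is iterated, which would have happened back on the event loop, so only an eagerly built list actually moves the tokenization work onto the executor thread. A toy contrast, with hypothetical names rather than vLLM code:

    import asyncio
    import time
    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=1)

    def slow_tokenize(text: str) -> list:
        time.sleep(0.01)  # stand-in for CPU-bound tokenization
        return [ord(c) for c in text]

    def tokenize_lazy(prompts):
        # Generator: nothing runs here; the sleeps happen only when the caller
        # iterates the result, i.e. back on the event loop thread.
        return (slow_tokenize(p) for p in prompts)

    def tokenize_eager(prompts):
        # List: all the work happens inside this call, so run_in_executor
        # really does move it onto the worker thread.
        return [slow_tokenize(p) for p in prompts]

    async def main():
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(executor, tokenize_eager,
                                             ["foo", "bar"])
        print(len(results))

    asyncio.run(main())
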
@@ -493,7 +498,7 @@ class OpenAIServing:
                 request=request)
 
         if isinstance(request_prompt, str):
-            prompt_inputs = self._tokenize_prompt_input(
+            prompt_inputs = await self._tokenize_prompt_input_async(
                 request,
                 tokenizer,
                 request_prompt,

View File

@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import make_async, merge_async_iterators, random_uuid
 
 logger = init_logger(__name__)
@@ -145,9 +145,11 @@ class OpenAIServingScores(OpenAIServing):
tokenization_kwargs["truncation"] = True
tokenization_kwargs["max_length"] = truncate_prompt_tokens
prompt_inputs = tokenizer(text=q,
text_pair=t,
**tokenization_kwargs)
tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
text_pair=t,
**tokenization_kwargs)
engine_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"],
token_type_ids=prompt_inputs.get("token_type_ids"))

View File

@@ -81,12 +81,13 @@ class OpenAIServingTokenization(OpenAIServing):
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
-                request_prompts, engine_prompts = self._preprocess_completion(
-                    request,
-                    tokenizer,
-                    request.prompt,
-                    add_special_tokens=request.add_special_tokens,
-                )
+                (request_prompts,
+                 engine_prompts) = await self._preprocess_completion(
+                     request,
+                     tokenizer,
+                     request.prompt,
+                     add_special_tokens=request.add_special_tokens,
+                 )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
@@ -134,7 +135,7 @@ class OpenAIServingTokenization(OpenAIServing):
         # Silently ignore prompt adapter since it does not affect tokenization
         # (Unlike in Embeddings API where an error is raised)
 
-        prompt_input = self._tokenize_prompt_input(
+        prompt_input = await self._tokenize_prompt_input_async(
             request,
             tokenizer,
             request.tokens,