[Frontend] Chat-based Embeddings API (#9759)

Author: Cyrus Leung
Date:   2024-11-01 16:13:35 +08:00 (committed by GitHub)
Parent: d3aa2a8b2f
Commit: 06386a64dd
21 changed files with 846 additions and 408 deletions


@@ -1,7 +1,6 @@
 import asyncio
 import time
-from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
-                    Optional)
+from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Tuple, Union, cast
@@ -30,18 +29,11 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
-from vllm.tracing import (contains_trace_headers, extract_trace_headers,
-                          log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import merge_async_iterators, random_uuid
 
 logger = init_logger(__name__)
 
-TypeTokenIDs = List[int]
-TypeTopLogProbs = List[Optional[Dict[int, float]]]
-TypeCreateLogProbsFn = Callable[
-    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 class OpenAIServingCompletion(OpenAIServing):
@@ -101,8 +93,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if raw_request:
             raw_request.state.request_metadata = request_metadata
 
-        # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
             (
                 lora_request,
@@ -111,19 +101,24 @@ class OpenAIServingCompletion(OpenAIServing):
 
             tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
-            prompts = list(
-                self._tokenize_prompt_input_or_inputs(
-                    request,
-                    tokenizer,
-                    request.prompt,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                ))
+            request_prompts, engine_prompts = self._preprocess_completion(
+                request,
+                tokenizer,
+                request.prompt,
+                truncate_prompt_tokens=request.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
 
-            for i, prompt_inputs in enumerate(prompts):
+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[RequestOutput, None]] = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
                 sampling_params: Union[SamplingParams, BeamSearchParams]
                 default_max_tokens = self.max_model_len - len(
-                    prompt_inputs["prompt_token_ids"])
+                    engine_prompt["prompt_token_ids"])
                 if request.use_beam_search:
                     sampling_params = request.to_beam_search_params(
                         default_max_tokens)
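
Note on the hunk above: _preprocess_completion returns two parallel sequences, the original request prompts and the tokenized engine prompts, and the per-prompt budget default_max_tokens is simply the model context length minus the prompt length. A minimal standalone sketch of that relationship (illustrative data only; this is not the vLLM helper itself):

from typing import List, TypedDict


class EnginePrompt(TypedDict):
    # Shape assumed from the diff: engine prompts carry the tokenized input.
    prompt_token_ids: List[int]


def default_max_tokens(max_model_len: int, engine_prompt: EnginePrompt) -> int:
    # Mirrors the computation in the loop above.
    return max_model_len - len(engine_prompt["prompt_token_ids"])


# Hypothetical preprocessed batch: request_prompts[i] pairs with engine_prompts[i].
request_prompts: List[str] = ["Hello", "Write a haiku about GPUs"]
engine_prompts: List[EnginePrompt] = [
    {"prompt_token_ids": [9906]},                    # made-up token IDs
    {"prompt_token_ids": [8144, 264, 6520, 39342]},
]

for prompt, engine_prompt in zip(request_prompts, engine_prompts):
    print(f"{prompt!r}: up to {default_max_tokens(4096, engine_prompt)} new tokens")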
@@ -134,36 +129,24 @@ class OpenAIServingCompletion(OpenAIServing):
                 request_id_item = f"{request_id}-{i}"
 
                 self._log_inputs(request_id_item,
-                                 prompt_inputs,
+                                 request_prompts[i],
                                  params=sampling_params,
                                  lora_request=lora_request,
                                  prompt_adapter_request=prompt_adapter_request)
 
-                is_tracing_enabled = (await
-                                      self.engine_client.is_tracing_enabled())
-                trace_headers = None
-                if is_tracing_enabled:
-                    trace_headers = extract_trace_headers(raw_request.headers)
-                if not is_tracing_enabled and contains_trace_headers(
-                        raw_request.headers):
-                    log_tracing_disabled_warning()
+                trace_headers = (await
+                                 self._get_trace_headers(raw_request.headers))
 
                 if isinstance(sampling_params, BeamSearchParams):
                     generator = self.engine_client.beam_search(
-                        prompt={
-                            "prompt_token_ids":
-                            prompt_inputs["prompt_token_ids"]
-                        },
+                        prompt=engine_prompt,
                         model_config=self.model_config,
                         request_id=request_id,
                         params=sampling_params,
                     )
                 else:
                     generator = self.engine_client.generate(
-                        {
-                            "prompt_token_ids":
-                            prompt_inputs["prompt_token_ids"]
-                        },
+                        engine_prompt,
                         sampling_params,
                         request_id_item,
                         lora_request=lora_request,
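
Note on the hunk above: the inline tracing logic (checking is_tracing_enabled, extracting headers, and warning when tracing is disabled) is collapsed into a single self._get_trace_headers(...) call on the serving base class. A sketch of what that helper presumably does, written as a free function and reconstructed from the removed lines rather than copied from the actual method:

from typing import Mapping, Optional

from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)


async def get_trace_headers(
        engine_client, headers: Mapping[str, str]) -> Optional[Mapping[str, str]]:
    # Same behavior as the removed inline block: only forward trace headers
    # when tracing is enabled, otherwise warn if the client sent any.
    if await engine_client.is_tracing_enabled():
        return extract_trace_headers(headers)
    if contains_trace_headers(headers):
        log_tracing_disabled_warning()
    return None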
@@ -180,6 +163,8 @@ class OpenAIServingCompletion(OpenAIServing):
         result_generator = merge_async_iterators(
             *generators, is_cancelled=raw_request.is_disconnected)
 
+        num_prompts = len(engine_prompts)
+
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use
         # beam search.
@@ -195,16 +180,22 @@ class OpenAIServingCompletion(OpenAIServing):
                 request_id,
                 created_time,
                 model_name,
-                num_prompts=len(prompts),
+                num_prompts=num_prompts,
                 tokenizer=tokenizer,
                 request_metadata=request_metadata)
 
         # Non-streaming response
-        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
+        final_res_batch: List[Optional[RequestOutput]] = [None] * num_prompts
         try:
             async for i, res in result_generator:
                 final_res_batch[i] = res
+        except asyncio.CancelledError:
+            return self.create_error_response("Client disconnected")
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
 
+        try:
             for i, final_res in enumerate(final_res_batch):
                 assert final_res is not None
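
Note on the hunk above: the non-streaming path merges all per-prompt generators into one stream of (index, output) pairs and scatters each result back into final_res_batch by prompt index, so outputs can arrive in any order. A self-contained toy illustration of that pattern (this is not vllm.utils.merge_async_iterators and it omits the is_cancelled hook):

import asyncio
from typing import AsyncGenerator, List, Optional, Tuple


async def fake_engine_stream(i: int) -> AsyncGenerator[str, None]:
    # Stand-in for the per-prompt RequestOutput stream from the engine client.
    await asyncio.sleep(0.01 * (3 - i))  # later prompts may finish first
    yield f"final output for prompt {i}"


async def merge_by_index(
    *gens: AsyncGenerator[str, None],
) -> AsyncGenerator[Tuple[int, str], None]:
    # Toy merge: drain every stream concurrently and yield
    # (prompt index, item) pairs as they become available.
    queue: "asyncio.Queue[Tuple[int, str]]" = asyncio.Queue()

    async def pump(i: int, gen: AsyncGenerator[str, None]) -> None:
        async for item in gen:
            await queue.put((i, item))

    tasks = [asyncio.create_task(pump(i, g)) for i, g in enumerate(gens)]
    finished = asyncio.gather(*tasks)
    while not (finished.done() and queue.empty()):
        try:
            yield await asyncio.wait_for(queue.get(), timeout=0.05)
        except asyncio.TimeoutError:
            continue
    await finished  # propagate any errors from the pump tasks


async def main() -> None:
    generators = [fake_engine_stream(i) for i in range(3)]
    # Same scatter-by-index pattern as the non-streaming path in the diff.
    final_res_batch: List[Optional[str]] = [None] * len(generators)
    async for i, res in merge_by_index(*generators):
        final_res_batch[i] = res
    print(final_res_batch)


asyncio.run(main())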
@@ -212,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 # We did not pass it into vLLM engine to avoid being redundant
                 # with the inputs token IDs
                 if final_res.prompt is None:
-                    final_res.prompt = prompts[i]["prompt"]
+                    final_res.prompt = request_prompts[i]["prompt"]
 
             final_res_batch_checked = cast(List[RequestOutput],
                                            final_res_batch)
@@ -226,8 +217,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 tokenizer,
                 request_metadata,
             )
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
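
For context on the commit title: this file's hunks only cover the completion-serving refactor that the new API builds on; the embeddings-side changes live in other files of the commit. Assuming the chat-based embeddings route mirrors the Chat Completions request shape (an assumption based on the title, not on the diff above), a client call might look roughly like this:

# Hypothetical request against a locally running vllm serve instance.
import requests

response = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={
        "model": "my-embedding-model",  # placeholder model name
        "messages": [                   # chat-style input instead of a plain string
            {"role": "user", "content": "What is the capital of France?"},
        ],
    },
)
print(response.json())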