[mypy] Enable following imports for entrypoints (#7248)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Fei <dfdfcai4@gmail.com>
Cyrus Leung
2024-08-21 14:28:21 +08:00
committed by GitHub
parent 4506641212
commit baaedfdb2d
26 changed files with 480 additions and 320 deletions
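
For context, "following imports" refers to mypy's follow_imports setting: once it is relaxed from skip/silent to normal for the entrypoints package, mypy analyzes the vllm modules these files import instead of treating them as Any, which is what surfaces the annotation fixes in the diff below. A minimal sketch of such a per-module override, assuming a mypy.ini-style config; the actual config location and module patterns in vLLM's repo are not taken from this commit:

    # hypothetical mypy.ini excerpt
    [mypy-vllm.entrypoints.*]
    follow_imports = normal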


@@ -3,10 +3,9 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
-from typing import Tuple, cast
+from typing import Tuple, Union, cast
from fastapi import Request
-from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
@@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
-UsageInfo)
+ErrorResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing,
@@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
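
The new AnyTokenizer import replaces the transformers-specific hint used below. Judging by the name and how it is used in this file, it is presumably a plain Union alias along these lines (a sketch, not the actual vllm.transformers_utils.tokenizer source, which may cover more tokenizer types):

    # presumed shape of the alias (assumption, not taken from this commit)
    from typing import Union

    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

    AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]

Both variants expose decode(), so the call sites below keep working unchanged.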
@@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
-async def create_completion(self, request: CompletionRequest,
-raw_request: Request):
+async def create_completion(
+self,
+request: CompletionRequest,
+raw_request: Request,
+) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
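
Since create_completion now advertises a Union, callers must narrow the result before using it. Roughly the pattern the FastAPI route layer would follow; the handler and the openai_serving_completion name are illustrative, not the actual api_server code:

    from fastapi import APIRouter, Request
    from fastapi.responses import JSONResponse, StreamingResponse

    from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                                  CompletionResponse,
                                                  ErrorResponse)

    router = APIRouter()

    @router.post("/v1/completions")
    async def create_completion(request: CompletionRequest, raw_request: Request):
        # openai_serving_completion is assumed to be an OpenAIServingCompletion instance
        generator = await openai_serving_completion.create_completion(
            request, raw_request)
        if isinstance(generator, ErrorResponse):
            # error path: return the error body with its status code
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.code)
        if isinstance(generator, CompletionResponse):
            # non-streaming path: full response body
            return JSONResponse(content=generator.model_dump())
        # remaining case: an async generator of SSE strings for streaming
        return StreamingResponse(content=generator, media_type="text/event-stream")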
@@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
-if request.prompt_logprobs is not None:
-if request.stream and request.prompt_logprobs > 0:
-return self.create_error_response(
-"Prompt_logprobs are not available when stream is enabled")
-elif request.prompt_logprobs < 0:
-return self.create_error_response(
-f"Prompt_logprobs set to invalid negative "
-f"value: {request.prompt_logprobs}")
# Schedule the request and get the result generator.
generators: List[AsyncGenerator[RequestOutput, None]] = []
try:
@@ -153,9 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
-result_generator: AsyncIterator[Tuple[
-int, RequestOutput]] = merge_async_iterators(
-*generators, is_cancelled=raw_request.is_disconnected)
+result_generator = merge_async_iterators(
+*generators, is_cancelled=raw_request.is_disconnected)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
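
Dropping the hand-written AsyncIterator[Tuple[int, RequestOutput]] annotation is safe because, once mypy follows imports into vllm.utils, it can infer that type from merge_async_iterators itself, whose signature presumably looks roughly like this (a sketch of the assumed helper, not the exact source):

    from typing import (AsyncGenerator, AsyncIterator, Awaitable, Callable,
                        Tuple, TypeVar)

    T = TypeVar("T")

    def merge_async_iterators(
        *iterators: AsyncGenerator[T, None],
        is_cancelled: Callable[[], Awaitable[bool]],
    ) -> AsyncIterator[Tuple[int, T]]:
        """Merge several async generators, yielding (producer_index, item) pairs."""
        ...

With generators holding RequestOutput items, T binds to RequestOutput and the explicit annotation becomes redundant.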
@@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
created_time: int,
model_name: str,
num_prompts: int,
-tokenizer: PreTrainedTokenizer,
+tokenizer: AnyTokenizer,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
@@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
try:
async for prompt_idx, res in result_generator:
+prompt_token_ids = res.prompt_token_ids
+prompt_logprobs = res.prompt_logprobs
+prompt_text = res.prompt
+delta_token_ids: GenericSequence[int]
+out_logprobs: Optional[GenericSequence[Optional[Dict[
+int, Logprob]]]]
for output in res.outputs:
i = output.index + prompt_idx * num_choices
@@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
+assert prompt_text is not None
# only return the prompt
-delta_text = res.prompt
-delta_token_ids = res.prompt_token_ids
-out_logprobs = res.prompt_logprobs
+delta_text = prompt_text
+delta_token_ids = prompt_token_ids
+out_logprobs = prompt_logprobs
has_echoed[i] = True
elif (request.echo and request.max_tokens > 0
and not has_echoed[i]):
+assert prompt_text is not None
+assert prompt_logprobs is not None
# echo the prompt and first token
-delta_text = res.prompt + output.text
-delta_token_ids = (res.prompt_token_ids +
-output.token_ids)
-out_logprobs = res.prompt_logprobs + (output.logprobs
-or [])
+delta_text = prompt_text + output.text
+delta_token_ids = [
+*prompt_token_ids, *output.token_ids
+]
+out_logprobs = [
+*prompt_logprobs,
+*(output.logprobs or []),
+]
has_echoed[i] = True
else:
# return just the delta
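
The move from + concatenation to [*a, *b] literals is what makes the echo branches type-check: the prompt and output token ids are Sequence-typed (and a list plus a tuple does not even concatenate at runtime), whereas unpacking both into a fresh list always produces a List of the expected element type. A standalone illustration, not vLLM code:

    from typing import List, Sequence

    def concat_ids(a: Sequence[int], b: Sequence[int]) -> List[int]:
        # a + b would not type-check (Sequence defines no __add__),
        # and list + tuple raises TypeError at runtime anyway.
        return [*a, *b]

    print(concat_ids([1, 2], (3, 4)))  # [1, 2, 3, 4]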
@@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
and request.stream_options.include_usage):
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
-prompt_tokens = len(res.prompt_token_ids)
+prompt_tokens = len(prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
@@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
request_id: str,
created_time: int,
model_name: str,
-tokenizer: PreTrainedTokenizer,
+tokenizer: AnyTokenizer,
) -> CompletionResponse:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
@@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
+token_ids: GenericSequence[int]
+out_logprobs: Optional[GenericSequence[Optional[Dict[int,
+Logprob]]]]
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo and request.max_tokens == 0:
+assert prompt_text is not None
token_ids = prompt_token_ids
out_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
-token_ids = prompt_token_ids + list(output.token_ids)
-out_logprobs = (prompt_logprobs + output.logprobs
-if request.logprobs is not None else None)
+assert prompt_text is not None
+token_ids = [*prompt_token_ids, *output.token_ids]
+if request.logprobs is None:
+out_logprobs = None
+else:
+assert prompt_logprobs is not None
+assert output.logprobs is not None
+out_logprobs = [
+*prompt_logprobs,
+*output.logprobs,
+]
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
@@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
-tokenizer: PreTrainedTokenizer,
+tokenizer: AnyTokenizer,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
@@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
token = tokenizer.decode(token_id)
if self.return_tokens_as_token_ids:
token = f"token_id:{token_id}"
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
+step_token = step_top_logprobs[token_id]
token = self._get_decoded_token(
-step_top_logprobs[token_id],
+step_token,
token_id,
tokenizer,
-return_as_token_id=self.return_tokens_as_token_ids)
-token_logprob = max(step_top_logprobs[token_id].logprob,
--9999.0)
+return_as_token_id=self.return_tokens_as_token_ids,
+)
+token_logprob = max(step_token.logprob, -9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)