[mypy] Enable following imports for entrypoints (#7248)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Fei <dfdfcai4@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.routing import Mount
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
@@ -29,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
DetokenizeRequest,
|
||||
DetokenizeResponse,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorResponse,
|
||||
TokenizeRequest,
|
||||
TokenizeResponse)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
|
||||
from vllm.entrypoints.openai.rpc.server import run_rpc_server
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
@@ -90,7 +93,8 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
||||
async def build_async_engine_client(
|
||||
args: Namespace) -> AsyncIterator[AsyncEngineClient]:
|
||||
# Context manager to handle async_engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
global engine_args
|
||||
@@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
||||
logger.info("Started engine process with PID %d",
|
||||
rpc_server_process.pid)
|
||||
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
|
||||
async_engine_client = AsyncEngineRPCClient(rpc_path)
|
||||
# NOTE: Actually, this is not true yet. We still need to support
|
||||
# embedding models via RPC (see TODO above)
|
||||
rpc_client = AsyncEngineRPCClient(rpc_path)
|
||||
async_engine_client = rpc_client # type: ignore
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await async_engine_client.setup()
|
||||
await rpc_client.setup()
|
||||
break
|
||||
except TimeoutError as e:
|
||||
if not rpc_server_process.is_alive():
|
||||
@@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
|
||||
rpc_server_process.terminate()
|
||||
|
||||
# Close all open connections to the backend
|
||||
async_engine_client.close()
|
||||
rpc_client.close()
|
||||
|
||||
# Wait for server process to join
|
||||
rpc_server_process.join()
|
||||
@@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
else:
|
||||
assert isinstance(generator, TokenizeResponse)
|
||||
elif isinstance(generator, TokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.post("/detokenize")
|
||||
async def detokenize(request: DetokenizeRequest):
|
||||
@@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
else:
|
||||
assert isinstance(generator, DetokenizeResponse)
|
||||
elif isinstance(generator, DetokenizeResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models():
|
||||
@@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
if request.stream:
|
||||
return StreamingResponse(content=generator,
|
||||
media_type="text/event-stream")
|
||||
else:
|
||||
assert isinstance(generator, ChatCompletionResponse)
|
||||
elif isinstance(generator, ChatCompletionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
@@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
if request.stream:
|
||||
return StreamingResponse(content=generator,
|
||||
media_type="text/event-stream")
|
||||
else:
|
||||
elif isinstance(generator, CompletionResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
@@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
else:
|
||||
elif isinstance(generator, EmbeddingResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
|
||||
|
||||
def build_app(args: Namespace) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -7,6 +7,7 @@ purposes.
|
||||
import argparse
|
||||
import json
|
||||
import ssl
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
lora_list = []
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
lora_list: List[LoRAModulePath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRAModulePath(name, path))
|
||||
@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
|
||||
|
||||
class PromptAdapterParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
adapter_list = []
|
||||
def __call__(
|
||||
self,
|
||||
parser: argparse.ArgumentParser,
|
||||
namespace: argparse.Namespace,
|
||||
values: Optional[Union[str, Sequence[str]]],
|
||||
option_string: Optional[str] = None,
|
||||
):
|
||||
if values is None:
|
||||
values = []
|
||||
if isinstance(values, str):
|
||||
raise TypeError("Expected values to be a list")
|
||||
|
||||
adapter_list: List[PromptAdapterPath] = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
adapter_list.append(PromptAdapterPath(name, path))
|
||||
|
||||
@@ -2,9 +2,9 @@ from functools import lru_cache, partial
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
|
||||
class AllowedTokenIdsLogitsProcessor:
|
||||
@@ -51,10 +51,11 @@ def logit_bias_logits_processor(
|
||||
|
||||
|
||||
def get_logits_processors(
|
||||
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
|
||||
allowed_token_ids: Optional[List[int]],
|
||||
tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]:
|
||||
logits_processors = []
|
||||
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]],
|
||||
allowed_token_ids: Optional[List[int]],
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> List[LogitsProcessor]:
|
||||
logits_processors: List[LogitsProcessor] = []
|
||||
if logit_bias:
|
||||
try:
|
||||
# Convert token_id to integer
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
from transformers import PreTrainedTokenizer
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
@@ -14,11 +13,13 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import LogitsProcessor, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
# torch is mocked during docs generation,
|
||||
# so we have to provide the values as literals
|
||||
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
|
||||
_LONG_INFO: Union["torch.iinfo", Namespace]
|
||||
|
||||
try:
|
||||
from sphinx.ext.autodoc.mock import _MockModule
|
||||
@@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
def to_sampling_params(
|
||||
self, tokenizer: PreTrainedTokenizer,
|
||||
self, tokenizer: AnyTokenizer,
|
||||
guided_decode_logits_processor: Optional[LogitsProcessor],
|
||||
default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.top_logprobs
|
||||
|
||||
# We now allow logprobs being true without top_logrobs.
|
||||
logits_processors = get_logits_processors(
|
||||
logit_bias=self.logit_bias,
|
||||
@@ -251,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
if guided_decode_logits_processor:
|
||||
logits_processors.append(guided_decode_logits_processor)
|
||||
|
||||
return SamplingParams(
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
presence_penalty=self.presence_penalty,
|
||||
@@ -265,8 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
stop=self.stop,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
logprobs=self.top_logprobs if self.logprobs else None,
|
||||
prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
|
||||
(self.top_logprobs if self.echo else None),
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
ignore_eos=self.ignore_eos,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=self.min_tokens,
|
||||
@@ -280,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||
)
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_stream_options(cls, values):
|
||||
if (values.get('stream_options') is not None
|
||||
and not values.get('stream')):
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError(
|
||||
"stream_options can only be set if stream is true")
|
||||
return values
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and prompt_logprobs > 0:
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.")
|
||||
|
||||
if prompt_logprobs < 0:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value.")
|
||||
|
||||
if (top_logprobs := data.get("top_logprobs")) is not None:
|
||||
if top_logprobs < 0:
|
||||
raise ValueError("`top_logprobs` must be a positive value.")
|
||||
|
||||
if not data.get("logprobs"):
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
@@ -320,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"When using `tool_choice`, `tools` must be set.")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if "top_logprobs" in data and data["top_logprobs"] is not None:
|
||||
if "logprobs" not in data or data["logprobs"] is False:
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
)
|
||||
elif data["top_logprobs"] < 0:
|
||||
raise ValueError(
|
||||
"`top_logprobs` must be a value a positive value.")
|
||||
return data
|
||||
|
||||
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
@@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
def to_sampling_params(
|
||||
self, tokenizer: PreTrainedTokenizer,
|
||||
self, tokenizer: AnyTokenizer,
|
||||
guided_decode_logits_processor: Optional[LogitsProcessor],
|
||||
default_max_tokens: int) -> SamplingParams:
|
||||
max_tokens = self.max_tokens
|
||||
if max_tokens is None:
|
||||
max_tokens = default_max_tokens
|
||||
|
||||
prompt_logprobs = self.prompt_logprobs
|
||||
if prompt_logprobs is None and self.echo:
|
||||
prompt_logprobs = self.logprobs
|
||||
|
||||
echo_without_generation = self.echo and self.max_tokens == 0
|
||||
|
||||
logits_processors = get_logits_processors(
|
||||
@@ -439,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
if guided_decode_logits_processor:
|
||||
logits_processors.append(guided_decode_logits_processor)
|
||||
|
||||
return SamplingParams(
|
||||
return SamplingParams.from_optional(
|
||||
n=self.n,
|
||||
best_of=self.best_of,
|
||||
presence_penalty=self.presence_penalty,
|
||||
@@ -458,8 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
min_tokens=self.min_tokens,
|
||||
use_beam_search=self.use_beam_search,
|
||||
early_stopping=self.early_stopping,
|
||||
prompt_logprobs=self.prompt_logprobs
|
||||
if self.prompt_logprobs else self.logprobs if self.echo else None,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
skip_special_tokens=self.skip_special_tokens,
|
||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||
@@ -485,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if "logprobs" in data and data[
|
||||
"logprobs"] is not None and not data["logprobs"] >= 0:
|
||||
raise ValueError("if passed, `logprobs` must be a positive value.")
|
||||
if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
|
||||
if data.get("stream") and prompt_logprobs > 0:
|
||||
raise ValueError(
|
||||
"`prompt_logprobs` are not available when `stream=True`.")
|
||||
|
||||
if prompt_logprobs < 0:
|
||||
raise ValueError("`prompt_logprobs` must be a positive value.")
|
||||
|
||||
if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
|
||||
raise ValueError("`logprobs` must be a positive value.")
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@@ -495,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
def validate_stream_options(cls, data):
|
||||
if data.get("stream_options") and not data.get("stream"):
|
||||
raise ValueError(
|
||||
"Stream options can only be defined when stream is true.")
|
||||
"Stream options can only be defined when `stream=True`.")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/embeddings
|
||||
model: str
|
||||
input: Union[List[int], List[List[int]], str, List[str]]
|
||||
encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
|
||||
encoding_format: Literal["float", "base64"] = "float"
|
||||
dimensions: Optional[int] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
@@ -23,8 +23,8 @@ class AsyncEngineRPCServer:
|
||||
def __init__(self, async_engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, rpc_path: str):
|
||||
# Initialize engine first.
|
||||
self.engine = AsyncLLMEngine.from_engine_args(async_engine_args,
|
||||
usage_context)
|
||||
self.engine = AsyncLLMEngine.from_engine_args(
|
||||
async_engine_args, usage_context=usage_context)
|
||||
|
||||
# Initialize context.
|
||||
self.context = zmq.asyncio.Context()
|
||||
@@ -39,7 +39,7 @@ class AsyncEngineRPCServer:
|
||||
self.context.destroy()
|
||||
self.engine.shutdown_background_loop()
|
||||
# Clear the engine reference so that it can be GC'ed.
|
||||
self.engine = None
|
||||
del self.engine
|
||||
|
||||
async def get_model_config(self, identity):
|
||||
"""Send the ModelConfig"""
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
from fastapi import Request
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
@@ -24,13 +23,14 @@ from vllm.entrypoints.openai.protocol import (
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath)
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import iterate_with_cancellation, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -67,9 +67,9 @@ class OpenAIServingChat(OpenAIServing):
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[ErrorResponse, AsyncGenerator[str, None],
|
||||
ChatCompletionResponse]:
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||
ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/chat/create
|
||||
@@ -83,16 +83,6 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
if request.prompt_logprobs is not None:
|
||||
if request.stream and request.prompt_logprobs > 0:
|
||||
return self.create_error_response(
|
||||
"Prompt_logprobs are not available when stream is enabled")
|
||||
|
||||
if request.prompt_logprobs < 0:
|
||||
return self.create_error_response(
|
||||
f"Prompt_logprobs set to invalid "
|
||||
f"negative value: {request.prompt_logprobs}")
|
||||
|
||||
try:
|
||||
(
|
||||
lora_request,
|
||||
@@ -160,9 +150,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
engine_inputs: PromptInputs = {
|
||||
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
|
||||
}
|
||||
engine_inputs = TokensPrompt(
|
||||
prompt_token_ids=prompt_inputs["prompt_token_ids"])
|
||||
if mm_data is not None:
|
||||
engine_inputs["multi_modal_data"] = mm_data
|
||||
|
||||
@@ -214,11 +203,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str,
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
model_name = self.served_model_names[0]
|
||||
created_time = int(time.time())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
chunk_object_type: Final = "chat.completion.chunk"
|
||||
first_iteration = True
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
@@ -438,7 +427,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str,
|
||||
conversation: List[ConversationMessage],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = self.served_model_names[0]
|
||||
@@ -523,7 +512,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
def _get_top_logprobs(
|
||||
self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
|
||||
tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
|
||||
tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]:
|
||||
return [
|
||||
ChatCompletionLogProb(token=(token := self._get_decoded_token(
|
||||
p[1],
|
||||
@@ -541,12 +530,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
) -> ChatCompletionLogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
|
||||
logprobs_content = []
|
||||
logprobs_content: List[ChatCompletionLogProbsContent] = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
@@ -554,23 +542,32 @@ class OpenAIServingChat(OpenAIServing):
|
||||
token = tokenizer.decode(token_id)
|
||||
if self.return_tokens_as_token_ids:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=token,
|
||||
bytes=list(token.encode("utf-8", errors="replace"))))
|
||||
bytes=list(token.encode("utf-8", errors="replace")),
|
||||
))
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
step_decoded = step_token.decoded_token
|
||||
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=self._get_decoded_token(
|
||||
step_top_logprobs[token_id], token_id, tokenizer,
|
||||
self.return_tokens_as_token_ids),
|
||||
logprob=max(step_top_logprobs[token_id].logprob,
|
||||
-9999.0),
|
||||
bytes=list(
|
||||
step_top_logprobs[token_id].decoded_token.encode(
|
||||
"utf-8", errors="replace")),
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
self.return_tokens_as_token_ids,
|
||||
),
|
||||
logprob=max(step_token.logprob, -9999.0),
|
||||
bytes=None if step_decoded is None else list(
|
||||
step_decoded.encode("utf-8", errors="replace")),
|
||||
top_logprobs=self._get_top_logprobs(
|
||||
step_top_logprobs, num_output_top_logprobs,
|
||||
tokenizer)))
|
||||
step_top_logprobs,
|
||||
num_output_top_logprobs,
|
||||
tokenizer,
|
||||
),
|
||||
))
|
||||
|
||||
return ChatCompletionLogProbs(content=logprobs_content)
|
||||
|
||||
@@ -3,10 +3,9 @@ import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple, cast
|
||||
from typing import Tuple, Union, cast
|
||||
|
||||
from fastapi import Request
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
@@ -19,7 +18,7 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
UsageInfo)
|
||||
ErrorResponse, UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing,
|
||||
@@ -29,6 +28,7 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
|
||||
log_tracing_disabled_warning)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -60,8 +60,11 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=return_tokens_as_token_ids)
|
||||
|
||||
async def create_completion(self, request: CompletionRequest,
|
||||
raw_request: Request):
|
||||
async def create_completion(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
raw_request: Request,
|
||||
) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/completions/create
|
||||
@@ -84,15 +87,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
created_time = int(time.time())
|
||||
|
||||
if request.prompt_logprobs is not None:
|
||||
if request.stream and request.prompt_logprobs > 0:
|
||||
return self.create_error_response(
|
||||
"Prompt_logprobs are not available when stream is enabled")
|
||||
elif request.prompt_logprobs < 0:
|
||||
return self.create_error_response(
|
||||
f"Prompt_logprobs set to invalid negative "
|
||||
f"value: {request.prompt_logprobs}")
|
||||
|
||||
# Schedule the request and get the result generator.
|
||||
generators: List[AsyncGenerator[RequestOutput, None]] = []
|
||||
try:
|
||||
@@ -153,9 +147,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator: AsyncIterator[Tuple[
|
||||
int, RequestOutput]] = merge_async_iterators(
|
||||
*generators, is_cancelled=raw_request.is_disconnected)
|
||||
result_generator = merge_async_iterators(
|
||||
*generators, is_cancelled=raw_request.is_disconnected)
|
||||
|
||||
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||
# results. In addition, we do not stream the results when use
|
||||
@@ -227,7 +220,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
num_prompts: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_texts = [""] * num_choices * num_prompts
|
||||
@@ -236,6 +229,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
try:
|
||||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
prompt_text = res.prompt
|
||||
|
||||
delta_token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[
|
||||
int, Logprob]]]]
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index + prompt_idx * num_choices
|
||||
@@ -244,19 +244,25 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_text is not None
|
||||
# only return the prompt
|
||||
delta_text = res.prompt
|
||||
delta_token_ids = res.prompt_token_ids
|
||||
out_logprobs = res.prompt_logprobs
|
||||
delta_text = prompt_text
|
||||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
elif (request.echo and request.max_tokens > 0
|
||||
and not has_echoed[i]):
|
||||
assert prompt_text is not None
|
||||
assert prompt_logprobs is not None
|
||||
# echo the prompt and first token
|
||||
delta_text = res.prompt + output.text
|
||||
delta_token_ids = (res.prompt_token_ids +
|
||||
output.token_ids)
|
||||
out_logprobs = res.prompt_logprobs + (output.logprobs
|
||||
or [])
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids, *output.token_ids
|
||||
]
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
has_echoed[i] = True
|
||||
else:
|
||||
# return just the delta
|
||||
@@ -301,7 +307,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
and request.stream_options.include_usage):
|
||||
if (request.stream_options.continuous_usage_stats
|
||||
or output.finish_reason is not None):
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
prompt_tokens = len(prompt_token_ids)
|
||||
completion_tokens = len(output.token_ids)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
@@ -342,7 +348,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
model_name: str,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
) -> CompletionResponse:
|
||||
choices: List[CompletionResponseChoice] = []
|
||||
num_prompt_tokens = 0
|
||||
@@ -353,16 +359,31 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
prompt_logprobs = final_res.prompt_logprobs
|
||||
prompt_text = final_res.prompt
|
||||
|
||||
token_ids: GenericSequence[int]
|
||||
out_logprobs: Optional[GenericSequence[Optional[Dict[int,
|
||||
Logprob]]]]
|
||||
|
||||
for output in final_res.outputs:
|
||||
assert request.max_tokens is not None
|
||||
if request.echo and request.max_tokens == 0:
|
||||
assert prompt_text is not None
|
||||
token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
output_text = prompt_text
|
||||
elif request.echo and request.max_tokens > 0:
|
||||
token_ids = prompt_token_ids + list(output.token_ids)
|
||||
out_logprobs = (prompt_logprobs + output.logprobs
|
||||
if request.logprobs is not None else None)
|
||||
assert prompt_text is not None
|
||||
token_ids = [*prompt_token_ids, *output.token_ids]
|
||||
|
||||
if request.logprobs is None:
|
||||
out_logprobs = None
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
assert output.logprobs is not None
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*output.logprobs,
|
||||
]
|
||||
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
@@ -413,7 +434,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
tokenizer: AnyTokenizer,
|
||||
initial_text_offset: int = 0,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
@@ -430,17 +451,21 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
token = tokenizer.decode(token_id)
|
||||
if self.return_tokens_as_token_ids:
|
||||
token = f"token_id:{token_id}"
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
step_token = step_top_logprobs[token_id]
|
||||
|
||||
token = self._get_decoded_token(
|
||||
step_top_logprobs[token_id],
|
||||
step_token,
|
||||
token_id,
|
||||
tokenizer,
|
||||
return_as_token_id=self.return_tokens_as_token_ids)
|
||||
token_logprob = max(step_top_logprobs[token_id].logprob,
|
||||
-9999.0)
|
||||
return_as_token_id=self.return_tokens_as_token_ids,
|
||||
)
|
||||
token_logprob = max(step_token.logprob, -9999.0)
|
||||
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple,
|
||||
Union, cast)
|
||||
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
|
||||
|
||||
import numpy as np
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
@@ -16,7 +16,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
|
||||
ErrorResponse, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import EmbeddingRequestOutput
|
||||
from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -24,18 +24,28 @@ logger = init_logger(__name__)
|
||||
TypeTokenIDs = List[int]
|
||||
|
||||
|
||||
def _get_embedding(
|
||||
output: EmbeddingOutput,
|
||||
encoding_format: Literal["float", "base64"],
|
||||
) -> Union[List[float], str]:
|
||||
if encoding_format == "float":
|
||||
return output.embedding
|
||||
elif encoding_format == "base64":
|
||||
embedding_bytes = np.array(output.embedding).tobytes()
|
||||
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
|
||||
assert_never(encoding_format)
|
||||
|
||||
|
||||
def request_output_to_embedding_response(
|
||||
final_res_batch: List[EmbeddingRequestOutput], request_id: str,
|
||||
created_time: int, model_name: str,
|
||||
encoding_format: str) -> EmbeddingResponse:
|
||||
encoding_format: Literal["float", "base64"]) -> EmbeddingResponse:
|
||||
data: List[EmbeddingResponseData] = []
|
||||
num_prompt_tokens = 0
|
||||
for idx, final_res in enumerate(final_res_batch):
|
||||
prompt_token_ids = final_res.prompt_token_ids
|
||||
embedding = final_res.outputs.embedding
|
||||
if encoding_format == "base64":
|
||||
embedding_bytes = np.array(embedding).tobytes()
|
||||
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
|
||||
embedding = _get_embedding(final_res.outputs, encoding_format)
|
||||
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
|
||||
data.append(embedding_data)
|
||||
|
||||
@@ -76,8 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
async def create_embedding(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
raw_request: Optional[Request] = None
|
||||
) -> Union[ErrorResponse, EmbeddingResponse]:
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[EmbeddingResponse, ErrorResponse]:
|
||||
"""Completion API similar to OpenAI's API.
|
||||
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
@@ -89,8 +99,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = (request.encoding_format
|
||||
if request.encoding_format else "float")
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
@@ -145,11 +154,10 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator: AsyncIterator[Tuple[
|
||||
int, EmbeddingRequestOutput]] = merge_async_iterators(
|
||||
*generators,
|
||||
is_cancelled=raw_request.is_disconnected
|
||||
if raw_request else None)
|
||||
result_generator = merge_async_iterators(
|
||||
*generators,
|
||||
is_cancelled=raw_request.is_disconnected if raw_request else None,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
final_res_batch: List[Optional[EmbeddingRequestOutput]]
|
||||
@@ -175,7 +183,7 @@ class OpenAIServingEmbedding(OpenAIServing):
|
||||
|
||||
return response
|
||||
|
||||
def _check_embedding_mode(self, embedding_mode: bool):
|
||||
def _check_embedding_mode(self, embedding_mode: bool) -> bool:
|
||||
if not embedding_mode:
|
||||
logger.warning(
|
||||
"embedding_mode is False. Embedding API will not work.")
|
||||
|
||||
@@ -31,7 +31,7 @@ from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import LogitsProcessor, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer_group import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user