|
|
|
|
@@ -2,16 +2,19 @@
|
|
|
|
|
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
import asyncio
|
|
|
|
|
from http import HTTPStatus
|
|
|
|
|
import json
|
|
|
|
|
import time
|
|
|
|
|
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
|
|
|
|
from typing import AsyncGenerator, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
import fastapi
|
|
|
|
|
from fastapi import BackgroundTasks, Request
|
|
|
|
|
from fastapi.exceptions import RequestValidationError
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
|
from fastchat.conversation import (Conversation, SeparatorStyle,
|
|
|
|
|
get_conv_template)
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
|
|
|
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
|
|
|
@@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|
|
|
|
from vllm.entrypoints.openai.protocol import (
|
|
|
|
|
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
|
|
|
|
CompletionResponseStreamChoice, CompletionStreamResponse,
|
|
|
|
|
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
|
|
|
|
|
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
|
|
|
|
|
ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
|
|
|
|
|
ModelCard, ModelList, ModelPermission, UsageInfo)
|
|
|
|
|
from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
|
|
|
|
|
ChatCompletionRequest, ChatCompletionResponse,
|
|
|
|
|
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
|
|
|
|
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
|
|
|
|
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
|
from vllm.outputs import RequestOutput
|
|
|
|
|
from vllm.sampling_params import SamplingParams
|
|
|
|
|
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
|
|
|
|
|
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def check_length(request, prompt, engine):
|
|
|
|
|
if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
|
|
|
|
|
context_len = engine.engine.model_config.hf_config.max_sequence_length
|
|
|
|
|
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
|
|
|
|
context_len = engine.engine.model_config.hf_config.seq_length
|
|
|
|
|
elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
|
|
|
|
|
context_len = engine.engine.model_config.hf_config.max_position_embeddings
|
|
|
|
|
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
|
|
|
|
context_len = engine.engine.model_config.hf_config.seq_length
|
|
|
|
|
async def check_length(request, prompt, model_config):
|
|
|
|
|
if hasattr(model_config.hf_config, "max_sequence_length"):
|
|
|
|
|
context_len = model_config.hf_config.max_sequence_length
|
|
|
|
|
elif hasattr(model_config.hf_config, "seq_length"):
|
|
|
|
|
context_len = model_config.hf_config.seq_length
|
|
|
|
|
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
|
|
|
|
context_len = model_config.hf_config.max_position_embeddings
|
|
|
|
|
elif hasattr(model_config.hf_config, "seq_length"):
|
|
|
|
|
context_len = model_config.hf_config.seq_length
|
|
|
|
|
else:
|
|
|
|
|
context_len = 2048
|
|
|
|
|
|
|
|
|
|
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
|
|
|
|
|
"logit_bias is not currently supported")
|
|
|
|
|
|
|
|
|
|
prompt = await get_gen_prompt(request)
|
|
|
|
|
error_check_ret = await check_length(request, prompt, engine)
|
|
|
|
|
error_check_ret = await check_length(request, prompt, engine_model_config)
|
|
|
|
|
if error_check_ret is not None:
|
|
|
|
|
return error_check_ret
|
|
|
|
|
|
|
|
|
|
@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
|
|
|
|
|
|
|
|
|
result_generator = engine.generate(prompt, sampling_params,
|
|
|
|
|
request_id)
|
|
|
|
|
result_generator = engine.generate(prompt, sampling_params, request_id)
|
|
|
|
|
|
|
|
|
|
async def abort_request() -> None:
|
|
|
|
|
await engine.abort(request_id)
|
|
|
|
|
|
|
|
|
|
def create_stream_response_json(index: int,
|
|
|
|
|
text: str,
|
|
|
|
|
finish_reason: Optional[str] = None) -> str:
|
|
|
|
|
def create_stream_response_json(
|
|
|
|
|
index: int,
|
|
|
|
|
text: str,
|
|
|
|
|
finish_reason: Optional[str] = None,
|
|
|
|
|
) -> str:
|
|
|
|
|
choice_data = ChatCompletionResponseStreamChoice(
|
|
|
|
|
index=index,
|
|
|
|
|
delta=DeltaMessage(content=text),
|
|
|
|
|
@@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
|
|
|
|
|
delta=DeltaMessage(role="assistant"),
|
|
|
|
|
finish_reason=None,
|
|
|
|
|
)
|
|
|
|
|
chunk = ChatCompletionStreamResponse(
|
|
|
|
|
id=request_id, choices=[choice_data], model=model_name
|
|
|
|
|
)
|
|
|
|
|
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
|
|
|
|
chunk = ChatCompletionStreamResponse(id=request_id,
|
|
|
|
|
choices=[choice_data],
|
|
|
|
|
model=model_name)
|
|
|
|
|
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
|
|
|
|
yield f"data: {data}\n\n"
|
|
|
|
|
|
|
|
|
|
previous_texts = [""] * request.n
|
|
|
|
|
previous_num_tokens = [0] * request.n
|
|
|
|
|
@@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
|
|
|
|
|
choices.append(choice_data)
|
|
|
|
|
|
|
|
|
|
num_prompt_tokens = len(final_res.prompt_token_ids)
|
|
|
|
|
num_generated_tokens = sum(len(output.token_ids)
|
|
|
|
|
for output in final_res.outputs)
|
|
|
|
|
num_generated_tokens = sum(
|
|
|
|
|
len(output.token_ids) for output in final_res.outputs)
|
|
|
|
|
usage = UsageInfo(
|
|
|
|
|
prompt_tokens=num_prompt_tokens,
|
|
|
|
|
completion_tokens=num_generated_tokens,
|
|
|
|
|
@@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):
|
|
|
|
|
# When user requests streaming but we don't stream, we still need to
|
|
|
|
|
# return a streaming response with a single event.
|
|
|
|
|
response_json = response.json(ensure_ascii=False)
|
|
|
|
|
|
|
|
|
|
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
|
|
|
|
yield f"data: {response_json}\n\n"
|
|
|
|
|
yield "data: [DONE]\n\n"
|
|
|
|
|
|
|
|
|
|
return StreamingResponse(fake_stream_generator(),
|
|
|
|
|
media_type="text/event-stream")
|
|
|
|
|
|
|
|
|
|
@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
|
|
|
|
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
|
|
|
|
"please provide at least one prompt")
|
|
|
|
|
if len(request.prompt) > 1:
|
|
|
|
|
return create_error_response(HTTPStatus.BAD_REQUEST,
|
|
|
|
|
"multiple prompts in a batch is not "
|
|
|
|
|
"currently supported")
|
|
|
|
|
return create_error_response(
|
|
|
|
|
HTTPStatus.BAD_REQUEST,
|
|
|
|
|
"multiple prompts in a batch is not currently supported")
|
|
|
|
|
prompt = request.prompt[0]
|
|
|
|
|
else:
|
|
|
|
|
prompt = request.prompt
|
|
|
|
|
@@ -571,6 +577,7 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
engine_args = AsyncEngineArgs.from_cli_args(args)
|
|
|
|
|
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
|
|
|
|
engine_model_config = asyncio.run(engine.get_model_config())
|
|
|
|
|
|
|
|
|
|
# A separate tokenizer to map token IDs to strings.
|
|
|
|
|
tokenizer = get_tokenizer(engine_args.tokenizer,
|
|
|
|
|
|