[Quality] Add CI for formatting (#343)
This commit is contained in:
@@ -2,16 +2,19 @@
|
||||
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
from http import HTTPStatus
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import fastapi
|
||||
from fastapi import BackgroundTasks, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastchat.conversation import (Conversation, SeparatorStyle,
|
||||
get_conv_template)
|
||||
import uvicorn
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
@@ -19,11 +22,10 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest, CompletionResponse, CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice, CompletionStreamResponse,
|
||||
ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
|
||||
ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
|
||||
ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
|
||||
return prompt
|
||||
|
||||
|
||||
async def check_length(request, prompt, engine):
|
||||
if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
|
||||
context_len = engine.engine.model_config.hf_config.max_sequence_length
|
||||
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
||||
context_len = engine.engine.model_config.hf_config.seq_length
|
||||
elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = engine.engine.model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
|
||||
context_len = engine.engine.model_config.hf_config.seq_length
|
||||
async def check_length(request, prompt, model_config):
|
||||
if hasattr(model_config.hf_config, "max_sequence_length"):
|
||||
context_len = model_config.hf_config.max_sequence_length
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
elif hasattr(model_config.hf_config, "max_position_embeddings"):
|
||||
context_len = model_config.hf_config.max_position_embeddings
|
||||
elif hasattr(model_config.hf_config, "seq_length"):
|
||||
context_len = model_config.hf_config.seq_length
|
||||
else:
|
||||
context_len = 2048
|
||||
|
||||
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
|
||||
"logit_bias is not currently supported")
|
||||
|
||||
prompt = await get_gen_prompt(request)
|
||||
error_check_ret = await check_length(request, prompt, engine)
|
||||
error_check_ret = await check_length(request, prompt, engine_model_config)
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
|
||||
except ValueError as e:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
|
||||
|
||||
result_generator = engine.generate(prompt, sampling_params,
|
||||
request_id)
|
||||
result_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
async def abort_request() -> None:
|
||||
await engine.abort(request_id)
|
||||
|
||||
def create_stream_response_json(index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None) -> str:
|
||||
def create_stream_response_json(
|
||||
index: int,
|
||||
text: str,
|
||||
finish_reason: Optional[str] = None,
|
||||
) -> str:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=text),
|
||||
@@ -238,10 +241,11 @@ async def create_chat_completion(raw_request: Request):
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id, choices=[choice_data], model=model_name
|
||||
)
|
||||
yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
@@ -295,8 +299,8 @@ async def create_chat_completion(raw_request: Request):
|
||||
choices.append(choice_data)
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(len(output.token_ids)
|
||||
for output in final_res.outputs)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
@@ -314,9 +318,11 @@ async def create_chat_completion(raw_request: Request):
|
||||
# When user requests streaming but we don't stream, we still need to
|
||||
# return a streaming response with a single event.
|
||||
response_json = response.json(ensure_ascii=False)
|
||||
|
||||
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
||||
yield f"data: {response_json}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(fake_stream_generator(),
|
||||
media_type="text/event-stream")
|
||||
|
||||
@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"please provide at least one prompt")
|
||||
if len(request.prompt) > 1:
|
||||
return create_error_response(HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not "
|
||||
"currently supported")
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"multiple prompts in a batch is not currently supported")
|
||||
prompt = request.prompt[0]
|
||||
else:
|
||||
prompt = request.prompt
|
||||
@@ -571,6 +577,7 @@ if __name__ == "__main__":
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
engine_model_config = asyncio.run(engine.get_model_config())
|
||||
|
||||
# A separate tokenizer to map token IDs to strings.
|
||||
tokenizer = get_tokenizer(engine_args.tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user