[Frontend] Support for chat completions input in the tokenize endpoint (#5923)

sasha0552
2024-07-16 12:18:09 +00:00
committed by GitHub
parent d97011512e
commit 7a3d2a5b95
9 changed files with 386 additions and 244 deletions

vllm/entrypoints/openai/api_server.py

@@ -33,6 +33,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser
@@ -46,6 +48,7 @@ engine_args: AsyncEngineArgs
 openai_serving_chat: OpenAIServingChat
 openai_serving_completion: OpenAIServingCompletion
 openai_serving_embedding: OpenAIServingEmbedding
+openai_serving_tokenization: OpenAIServingTokenization

 logger = init_logger('vllm.entrypoints.openai.api_server')
@@ -86,7 +89,7 @@ async def health() -> Response:
 @router.post("/tokenize")
 async def tokenize(request: TokenizeRequest):
-    generator = await openai_serving_completion.create_tokenize(request)
+    generator = await openai_serving_tokenization.create_tokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
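
With this change, /tokenize accepts chat-completions-style messages in addition to a plain prompt. A minimal sketch of the new request shape, assuming a server on localhost:8000 and a placeholder model name (the response fields shown are illustrative, not guaranteed):

    import requests

    # Tokenize a chat conversation; the "messages" field is what this
    # commit adds alongside the existing plain-text "prompt" field.
    resp = requests.post(
        "http://localhost:8000/tokenize",
        json={
            "model": "my-model",  # placeholder served model name
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    )
    print(resp.json())  # e.g. {"count": ..., "max_model_len": ..., "tokens": [...]}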
@@ -97,7 +100,7 @@ async def tokenize(request: TokenizeRequest):
 @router.post("/detokenize")
 async def detokenize(request: DetokenizeRequest):
-    generator = await openai_serving_completion.create_detokenize(request)
+    generator = await openai_serving_tokenization.create_detokenize(request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
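
The /detokenize route keeps its request shape; only the serving object behind it changes. For completeness, the same sketch (token ids are placeholders):

    # Round-trip token ids back through /detokenize.
    resp = requests.post(
        "http://localhost:8000/detokenize",
        json={"model": "my-model", "tokens": [9906, 0]},  # placeholder ids
    )
    print(resp.json())  # e.g. {"prompt": "Hello!"}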
@@ -241,6 +244,7 @@ def run_server(args, llm_engine=None):
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
+    global openai_serving_tokenization

     openai_serving_chat = OpenAIServingChat(engine, model_config,
                                             served_model_names,
@@ -252,6 +256,8 @@ def run_server(args, llm_engine=None):
         args.prompt_adapters)
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
+    openai_serving_tokenization = OpenAIServingTokenization(
+        engine, model_config, served_model_names, args.chat_template)
     app.root_path = args.root_path

     logger.info("Available routes are:")
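
args.chat_template is threaded through to OpenAIServingTokenization because tokenizing chat input means rendering the messages with the model's chat template before encoding. Conceptually (a simplified sketch using Hugging Face's apply_chat_template, not the exact vLLM code path):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("my-model")  # placeholder model
    messages = [{"role": "user", "content": "Hello!"}]
    # Render the conversation into a single prompt string, then encode it;
    # add_generation_prompt appends the trailing assistant turn header.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    token_ids = tokenizer.encode(prompt)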