[Frontend] Automatic detection of chat content format from AST (#9919)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-11-16 13:35:40 +08:00 (committed by GitHub)
Parent: 4f168f69a3
Commit: 32e46e000f
16 changed files with 788 additions and 350 deletions

View File

@@ -29,6 +29,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -529,6 +530,9 @@ def init_app_state(
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@@ -537,7 +541,8 @@ def init_app_state(
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
@@ -557,7 +562,8 @@ def init_app_state(
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embedding" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
@@ -565,7 +571,8 @@ def init_app_state(
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)

View File

@@ -7,10 +7,11 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument(
'--chat-template-content-format',
type=str,
default="auto",
choices=get_args(ChatTemplateContentFormatOption),
help='The format to render message content within a chat template.'
'\n\n'
'* "string" will render the content as a string. '
'Example: "Hello World"\n'
'* "openai" will render the content as a list of dictionaries, '
'similar to OpenAI schema. '
'Example: [{"type": "text", "text": "Hello world!"}]')
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",

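For illustration, a minimal sketch of what each setting means for the same user message; the renderings below paraphrase the help text above, and the exact output still depends on the model's chat template.

message = {
    "role": "user",
    "content": [{"type": "text", "text": "Hello world!"}],
}
# --chat-template-content-format string -> parts are flattened to plain text:
#     {"role": "user", "content": "Hello world!"}
# --chat-template-content-format openai -> the list of parts is kept, for
#     templates that iterate over message["content"] themselves:
#     {"role": "user", "content": [{"type": "text", "text": "Hello world!"}]}
# --chat-template-content-format auto   -> inspect the template's Jinja AST and
#     pick whichever of the two the template actually expects (the new default)
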
View File

@@ -5,9 +5,8 @@ from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
@@ -35,26 +34,6 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
tool_calls: Optional[List[dict]]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
@@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False)
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
@model_validator(mode="before")
@classmethod

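A hedged example of a request body exercising the newly documented TokenizeChatRequest fields; the /tokenize endpoint path and the model name are placeholders, not part of this diff.

# Field names match the TokenizeChatRequest definition above; values are
# illustrative only.
tokenize_request = {
    "model": "my-model",
    "messages": [{
        "role": "user",
        "content": [{"type": "text", "text": "Hello world!"}],
    }],
    "add_generation_prompt": True,    # append the template's generation prompt
    "continue_final_message": False,  # cannot be combined with the flag above
    "add_special_tokens": False,      # the chat template usually adds BOS itself
    "chat_template": None,            # optional per-request Jinja override
    "chat_template_kwargs": {"illustrative_flag": True},
}
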
View File

@@ -222,6 +222,7 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
@@ -230,6 +231,7 @@ async def main(args):
base_model_paths,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embedding" else None
tracker = BatchProgressTracker()

View File

@@ -10,7 +10,8 @@ from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -38,20 +39,23 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@@ -61,8 +65,8 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
self.use_tool_use_model_template = False
self.chat_template = load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
@@ -120,6 +124,7 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser
# validation for OpenAI tools
@@ -157,6 +162,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts,

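For context, this is the kind of client request the chat endpoint must now format correctly whichever way the template consumes content; the client snippet is an assumption (the openai package, server URL, and model name are not part of this diff).

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
client.chat.completions.create(
    model="my-model",  # placeholder
    messages=[{
        "role": "user",
        # OpenAI-style content parts: with content format "auto" the server now
        # decides whether to flatten this to a string or pass the list through
        # to the chat template.
        "content": [{"type": "text", "text": "Hello world!"}],
    }],
)
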
View File

@@ -1,7 +1,7 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
@@ -9,7 +9,7 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingRequest,
@@ -77,7 +77,8 @@ class OpenAIServingEmbedding(OpenAIServing):
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@@ -85,7 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapters=None,
request_logger=request_logger)
self.chat_template = load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_embedding(
self,
@@ -144,6 +146,8 @@ class OpenAIServingEmbedding(OpenAIServing):
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
truncate_prompt_tokens=truncate_prompt_tokens,

View File

@@ -11,14 +11,16 @@ from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures)
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
@@ -426,7 +428,8 @@ class OpenAIServing:
request: ChatLikeRequest,
tokenizer: AnyTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tool_dicts: Optional[List[Dict[str, Any]]] = None,
@@ -437,10 +440,16 @@ class OpenAIServing:
add_special_tokens: bool = False,
) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
List[TokensPrompt]]:
resolved_content_format = resolve_chat_template_content_format(
chat_template,
chat_template_content_format,
tokenizer,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
self.model_config,
tokenizer,
content_format=resolved_content_format,
)
_chat_template_kwargs: Dict[str, Any] = dict(

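A minimal sketch of the resolution step added here, following the call shape in the hunk above (three positional arguments: template, requested format, tokenizer); the model name is an arbitrary example and the exact fallback behaviour is an assumption.

from transformers import AutoTokenizer

from vllm.entrypoints.chat_utils import resolve_chat_template_content_format

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
resolved = resolve_chat_template_content_format(
    None,     # chat_template: no override, use the tokenizer's own template
    "auto",   # the ChatTemplateContentFormatOption requested via server/CLI
    tokenizer,
)
# `resolved` is then passed to parse_chat_messages_futures(...,
# content_format=resolved), so the format is decided once per request.
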
View File

@@ -1,8 +1,8 @@
from typing import List, Optional, Union
from typing import Final, List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@@ -33,7 +33,8 @@ class OpenAIServingTokenization(OpenAIServing):
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@@ -41,12 +42,8 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
# the list of commonly-used chat template names for HF named templates
hf_chat_templates: List[str] = ['default', 'tool_use']
self.chat_template = chat_template \
if chat_template in hf_chat_templates \
else load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_tokenize(
self,
@@ -75,9 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
request,
tokenizer,
request.messages,
chat_template=self.chat_template,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else: