[Frontend] Automatic detection of chat content format from AST (#9919)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
vllm/entrypoints/openai/api_server.py

@@ -29,6 +29,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import load_chat_template
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@@ -529,6 +530,9 @@ def init_app_state(
     state.engine_client = engine_client
     state.log_stats = not args.disable_log_stats

+    resolved_chat_template = load_chat_template(args.chat_template)
+    logger.info("Using supplied chat template:\n%s", resolved_chat_template)
+
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
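
With this change the chat template is resolved once at startup instead of inside each serving handler. As a rough sketch of what load_chat_template does (simplified, not the exact vLLM implementation), it accepts either a path to a template file or an inline Jinja template string:

    from pathlib import Path
    from typing import Optional

    def load_chat_template_sketch(chat_template: Optional[str]) -> Optional[str]:
        # No template supplied: the tokenizer's built-in template is used later.
        if chat_template is None:
            return None
        path = Path(chat_template)
        if path.is_file():
            # --chat-template pointed at a file: read the Jinja source from disk.
            return path.read_text()
        # Otherwise the argument is treated as the template text itself.
        return chat_template
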
@@ -537,7 +541,8 @@ def init_app_state(
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
         return_tokens_as_token_ids=args.return_tokens_as_token_ids,
         enable_auto_tools=args.enable_auto_tool_choice,
         tool_parser=args.tool_call_parser,
@@ -557,7 +562,8 @@ def init_app_state(
         model_config,
         base_model_paths,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
     ) if model_config.task == "embedding" else None
     state.openai_serving_tokenization = OpenAIServingTokenization(
         engine_client,
@@ -565,7 +571,8 @@ def init_app_state(
         base_model_paths,
         lora_modules=args.lora_modules,
         request_logger=request_logger,
-        chat_template=args.chat_template,
+        chat_template=resolved_chat_template,
+        chat_template_content_format=args.chat_template_content_format,
     )
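
All three handlers (chat, embedding, tokenization) now receive the pre-resolved template plus the new content-format option, so the behavior can be selected at launch time; for example (model name is a placeholder):

    python -m vllm.entrypoints.openai.api_server \
        --model my-model \
        --chat-template-content-format openai
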
vllm/entrypoints/openai/cli_args.py

@@ -7,10 +7,11 @@ purposes.
 import argparse
 import json
 import ssl
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, get_args

 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
-from vllm.entrypoints.chat_utils import validate_chat_template
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+                                         validate_chat_template)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                         help="The file path to the chat template, "
                         "or the template in single-line form "
                         "for the specified model")
+    parser.add_argument(
+        '--chat-template-content-format',
+        type=str,
+        default="auto",
+        choices=get_args(ChatTemplateContentFormatOption),
+        help='The format to render message content within a chat template.'
+        '\n\n'
+        '* "string" will render the content as a string. '
+        'Example: "Hello World"\n'
+        '* "openai" will render the content as a list of dictionaries, '
+        'similar to OpenAI schema. '
+        'Example: [{"type": "text", "text": "Hello world!"}]')
     parser.add_argument("--response-role",
                         type=nullable_str,
                         default="assistant",
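
To make the two explicit choices concrete, here is the same user turn in both shapes the flag accepts (contents taken from the help text above):

    # "string": message content is a single string.
    messages_string = [{"role": "user", "content": "Hello world!"}]

    # "openai": content is a list of typed parts, as in the OpenAI schema.
    messages_openai = [{
        "role": "user",
        "content": [{"type": "text", "text": "Hello world!"}],
    }]

With the default "auto", the server inspects the chat template itself to decide which of the two shapes it expects (see the AST sketch further below).
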
vllm/entrypoints/openai/protocol.py

@@ -5,9 +5,8 @@ from argparse import Namespace
 from typing import Any, Dict, List, Literal, Optional, Union

 import torch
-from openai.types.chat import ChatCompletionContentPartParam
 from pydantic import BaseModel, ConfigDict, Field, model_validator
-from typing_extensions import Annotated, Required, TypedDict
+from typing_extensions import Annotated

 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
@@ -35,26 +34,6 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
 assert _LONG_INFO.max == _MOCK_LONG_INFO.max


-class CustomChatCompletionMessageParam(TypedDict, total=False):
-    """Enables custom roles in the Chat Completion API."""
-    role: Required[str]
-    """The role of the message's author."""
-
-    content: Union[str, List[ChatCompletionContentPartParam]]
-    """The contents of the message."""
-
-    name: str
-    """An optional name for the participant.
-
-    Provides the model information to differentiate between participants of the
-    same role.
-    """
-
-    tool_call_id: Optional[str]
-
-    tool_calls: Optional[List[dict]]
-
-
 class OpenAIBaseModel(BaseModel):
     # OpenAI API does not allow extra fields
     model_config = ConfigDict(extra="forbid")
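
The CustomChatCompletionMessageParam definition does not disappear from the project: judging by the import swap above, protocol.py now sources its message types from vllm.entrypoints.chat_utils instead of defining them locally. A custom-role message such as the following remains valid input:

    # A message with a non-standard role, as the relocated TypedDict allows.
    message = {"role": "reviewer", "content": "Looks good", "name": "bot-a"}
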
@@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
     model: str
     prompt: str

-    add_special_tokens: bool = Field(default=True)
+    add_special_tokens: bool = Field(
+        default=True,
+        description=(
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
+    )


 class TokenizeChatRequest(OpenAIBaseModel):
     model: str
     messages: List[ChatCompletionMessageParam]

-    add_generation_prompt: bool = Field(default=True)
-    continue_final_message: bool = Field(default=False)
-    add_special_tokens: bool = Field(default=False)
+    add_generation_prompt: bool = Field(
+        default=True,
+        description=
+        ("If true, the generation prompt will be added to the chat template. "
+         "This is a parameter used by chat template in tokenizer config of the "
+         "model."),
+    )
+    continue_final_message: bool = Field(
+        default=False,
+        description=
+        ("If this is set, the chat will be formatted so that the final "
+         "message in the chat is open-ended, without any EOS tokens. The "
+         "model will continue this message rather than starting a new one. "
+         "This allows you to \"prefill\" part of the model's response for it. "
+         "Cannot be used at the same time as `add_generation_prompt`."),
+    )
+    add_special_tokens: bool = Field(
+        default=False,
+        description=(
+            "If true, special tokens (e.g. BOS) will be added to the prompt "
+            "on top of what is added by the chat template. "
+            "For most models, the chat template takes care of adding the "
+            "special tokens so this should be set to false (as is the "
+            "default)."),
+    )
+    chat_template: Optional[str] = Field(
+        default=None,
+        description=(
+            "A Jinja template to use for this conversion. "
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
+    )
+    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
+        default=None,
+        description=("Additional kwargs to pass to the template renderer. "
+                     "Will be accessible by the chat template."),
+    )

     @model_validator(mode="before")
     @classmethod
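
The new Field descriptions surface directly in the server's generated OpenAPI docs. A minimal request against a running server that exercises the documented fields might look like this (model name and URL are placeholders):

    import requests

    payload = {
        "model": "my-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "add_generation_prompt": True,
        "add_special_tokens": False,
    }
    # /tokenize returns the token ids produced by rendering the chat template.
    resp = requests.post("http://localhost:8000/tokenize", json=payload)
    print(resp.json())
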
vllm/entrypoints/openai/run_batch.py

@@ -222,6 +222,7 @@ async def main(args):
         prompt_adapters=None,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
         enable_prompt_tokens_details=args.enable_prompt_tokens_details,
     ) if model_config.task == "generate" else None
     openai_serving_embedding = OpenAIServingEmbedding(
@@ -230,6 +231,7 @@ async def main(args):
         base_model_paths,
         request_logger=request_logger,
         chat_template=None,
+        chat_template_content_format="auto",
     ) if model_config.task == "embedding" else None

     tracker = BatchProgressTracker()
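
run_batch exposes no CLI flag for the content format, so it pins "auto" and lets the resolver inspect the template per request. Each line of its input file is an OpenAI-style batch entry; a sketch of one such line (model name is a placeholder):

    import json

    line = json.dumps({
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "my-model",
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    })
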
vllm/entrypoints/openai/serving_chat.py

@@ -10,7 +10,8 @@ from fastapi import Request

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
+from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
+                                         ConversationMessage)
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -38,20 +39,23 @@ logger = init_logger(__name__)

 class OpenAIServingChat(OpenAIServing):

-    def __init__(self,
-                 engine_client: EngineClient,
-                 model_config: ModelConfig,
-                 base_model_paths: List[BaseModelPath],
-                 response_role: str,
-                 *,
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 prompt_adapters: Optional[List[PromptAdapterPath]],
-                 request_logger: Optional[RequestLogger],
-                 chat_template: Optional[str],
-                 return_tokens_as_token_ids: bool = False,
-                 enable_auto_tools: bool = False,
-                 tool_parser: Optional[str] = None,
-                 enable_prompt_tokens_details: bool = False):
+    def __init__(
+        self,
+        engine_client: EngineClient,
+        model_config: ModelConfig,
+        base_model_paths: List[BaseModelPath],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
+        return_tokens_as_token_ids: bool = False,
+        enable_auto_tools: bool = False,
+        tool_parser: Optional[str] = None,
+        enable_prompt_tokens_details: bool = False,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -61,8 +65,8 @@ class OpenAIServingChat(OpenAIServing):
                          return_tokens_as_token_ids=return_tokens_as_token_ids)

         self.response_role = response_role
-        self.use_tool_use_model_template = False
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format

         # set up tool use
         self.enable_auto_tools: bool = enable_auto_tools
@@ -120,6 +124,7 @@ class OpenAIServingChat(OpenAIServing):
             ) = self._maybe_get_adapters(request)

             tokenizer = await self.engine_client.get_tokenizer(lora_request)

+            tool_parser = self.tool_parser

             # validation for OpenAI tools
@@ -157,6 +162,7 @@ class OpenAIServingChat(OpenAIServing):
                 tokenizer,
                 request.messages,
                 chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
                 continue_final_message=request.continue_final_message,
                 tool_dicts=tool_dicts,
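
Because the resolved content format is now threaded into _preprocess_chat, a chat request whose content is a list of typed parts renders correctly regardless of how the template consumes it. For example (placeholder model name):

    body = {
        "model": "my-model",
        "messages": [{
            "role": "user",
            "content": [{"type": "text", "text": "What is the capital of France?"}],
        }],
    }
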
vllm/entrypoints/openai/serving_embedding.py

@@ -1,7 +1,7 @@
 import asyncio
 import base64
 import time
-from typing import AsyncGenerator, List, Literal, Optional, Union, cast
+from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast

 import numpy as np
 from fastapi import Request
@@ -9,7 +9,7 @@ from typing_extensions import assert_never

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               EmbeddingRequest,
@@ -77,7 +77,8 @@ class OpenAIServingEmbedding(OpenAIServing):
         *,
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -85,7 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
                          prompt_adapters=None,
                          request_logger=request_logger)

-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format

     async def create_embedding(
         self,
@@ -144,6 +146,8 @@ class OpenAIServingEmbedding(OpenAIServing):
                 tokenizer,
                 request.messages,
                 chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.
+                chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
                 continue_final_message=request.continue_final_message,
                 truncate_prompt_tokens=truncate_prompt_tokens,
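
The embedding handler gets the same treatment, which matters for models whose embedding prompts are built from chat messages. A hypothetical chat-style embedding request (placeholder names):

    body = {
        "model": "my-embedding-model",
        "messages": [{"role": "user", "content": "Represent this sentence"}],
    }
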
vllm/entrypoints/openai/serving_engine.py

@@ -11,14 +11,16 @@ from typing_extensions import Annotated

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
+                                         ChatTemplateContentFormatOption,
                                          ConversationMessage,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
-                                         parse_chat_messages_futures)
+                                         parse_chat_messages_futures,
+                                         resolve_chat_template_content_format)
 from vllm.entrypoints.logger import RequestLogger
-# yapf conflicts with isort for this block
-# yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
                                               DetokenizeRequest,
@@ -426,7 +428,8 @@ class OpenAIServing:
         request: ChatLikeRequest,
         tokenizer: AnyTokenizer,
         messages: List[ChatCompletionMessageParam],
-        chat_template: Optional[str] = None,
+        chat_template: Optional[str],
+        chat_template_content_format: ChatTemplateContentFormatOption,
         add_generation_prompt: bool = True,
         continue_final_message: bool = False,
         tool_dicts: Optional[List[Dict[str, Any]]] = None,
@@ -437,10 +440,16 @@ class OpenAIServing:
         add_special_tokens: bool = False,
     ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
+        resolved_content_format = resolve_chat_template_content_format(
+            chat_template,
+            chat_template_content_format,
+            tokenizer,
+        )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
             self.model_config,
             tokenizer,
+            content_format=resolved_content_format,
         )

         _chat_template_kwargs: Dict[str, Any] = dict(
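
This resolution step is where the "automatic detection from AST" in the PR title happens: with "auto", the Jinja chat template is parsed and its AST inspected to see whether the template treats message content as a plain string or iterates over a list of parts. A simplified sketch of the idea (not vLLM's exact implementation; the helper name is illustrative):

    import jinja2
    from jinja2 import nodes

    def looks_like_openai_format(chat_template: str) -> bool:
        """Heuristic: does the template loop over message content parts?"""
        ast = jinja2.Environment().parse(chat_template)
        for loop in ast.find_all(nodes.For):
            it = loop.iter
            # {% for part in message['content'] %} -> list-of-parts ("openai")
            if (isinstance(it, nodes.Getitem)
                    and isinstance(it.arg, nodes.Const)
                    and it.arg.value == "content"):
                return True
            # {% for part in message.content %} is the attribute-access variant
            if isinstance(it, nodes.Getattr) and it.attr == "content":
                return True
        return False

Templates that never iterate over message content fall back to the "string" format.
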
vllm/entrypoints/openai/serving_tokenization.py

@@ -1,8 +1,8 @@
-from typing import List, Optional, Union
+from typing import Final, List, Optional, Union

 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -33,7 +33,8 @@ class OpenAIServingTokenization(OpenAIServing):
         lora_modules: Optional[List[LoRAModulePath]],
         request_logger: Optional[RequestLogger],
         chat_template: Optional[str],
-    ):
+        chat_template_content_format: ChatTemplateContentFormatOption,
+    ) -> None:
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          base_model_paths=base_model_paths,
@@ -41,12 +42,8 @@ class OpenAIServingTokenization(OpenAIServing):
                          prompt_adapters=None,
                          request_logger=request_logger)

-        # If this is None we use the tokenizer's default chat template
-        # the list of commonly-used chat template names for HF named templates
-        hf_chat_templates: List[str] = ['default', 'tool_use']
-        self.chat_template = chat_template \
-            if chat_template in hf_chat_templates \
-            else load_chat_template(chat_template)
+        self.chat_template = chat_template
+        self.chat_template_content_format: Final = chat_template_content_format

     async def create_tokenize(
         self,
@@ -75,9 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
                 request,
                 tokenizer,
                 request.messages,
-                chat_template=self.chat_template,
+                chat_template=request.chat_template or self.chat_template,
+                chat_template_content_format=self.
+                chat_template_content_format,
                 add_generation_prompt=request.add_generation_prompt,
                 continue_final_message=request.continue_final_message,
+                chat_template_kwargs=request.chat_template_kwargs,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
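
After this change all three endpoints apply the same template precedence: a template supplied in the request wins, then the server-level template, then (if both are None) the tokenizer's built-in template. Restated as a toy helper:

    from typing import Optional

    def pick_template(request_template: Optional[str],
                      server_template: Optional[str]) -> Optional[str]:
        # Mirrors `request.chat_template or self.chat_template` above;
        # returning None means "use the tokenizer's own template".
        return request_template or server_template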