[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-08-07 17:12:05 +08:00
committed by GitHub
parent 7b261092de
commit 66d617e343
9 changed files with 125 additions and 69 deletions

View File

@@ -1,8 +1,9 @@
import codecs
from dataclasses import dataclass
from functools import lru_cache
from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
final)
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
cast, final)
# yapf conflicts with isort for this block
# yapf: disable
@@ -22,6 +23,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import async_get_and_parse_image
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
mm_futures: List[Awaitable[MultiModalDataDict]]
def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
mm_futures.extend(parse_result.mm_futures)
return conversation, mm_futures
def apply_chat_template(
tokenizer: AnyTokenizer,
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
prompt = tokenizer.apply_chat_template(
conversation=conversation,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
assert isinstance(prompt, str)
return prompt