[Frontend] Chat template fallbacks for multimodal models (#17805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-08 14:05:54 +08:00
committed by GitHub
parent 843b222723
commit 96722aa81d
18 changed files with 219 additions and 52 deletions

View File

@@ -38,6 +38,10 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
get_chat_template_fallback_path)
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -325,11 +329,10 @@ def resolve_mistral_chat_template(
return None
def resolve_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
@@ -342,7 +345,7 @@ def resolve_hf_chat_template(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
trust_remote_code=model_config.trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
@@ -358,22 +361,34 @@ def resolve_hf_chat_template(
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.", tokenizer.name_or_path)
chat_template = load_chat_template(path)
else:
logger.debug("There is no chat template fallback for %s",
tokenizer.name_or_path)
return chat_template
def _resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
tools=tools,
)
else:
@@ -413,19 +428,18 @@ def _log_chat_template_content_format(
def resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
model_config,
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
@@ -1177,20 +1191,20 @@ def parse_chat_messages_futures(
def apply_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
tools=tools,
trust_remote_code=trust_remote_code,
)
if hf_chat_template is None: