[Frontend] Chat template fallbacks for multimodal models (#17805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-05-08 14:05:54 +08:00
committed by GitHub
parent 843b222723
commit 96722aa81d
18 changed files with 219 additions and 52 deletions

View File

@@ -4,8 +4,6 @@ import warnings
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@@ -793,8 +793,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
)
vllm_result = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
)
assert resolved_format == expected_format
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("microsoft/Florence-2-base", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string")],
)
# yapf: enable
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format
@@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format):
("template_path", "expected_format"),
[("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chameleon.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_deepseek_vl2.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_florence2.jinja", "string"),
("template_fuyu.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_paligemma.jinja", "string"),
("template_teleflm.jinja", "string"),
("template_qwen_vl.jinja", "string"),
("template_qwen_vl_chat.jinja", "string"),
("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
@@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
# yapf: enable
def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID,
PHI3V_MODEL_ID, # Dummy
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None
@@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
chat_template,
None,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format