[Bugfix] Fix chat template loading (#15143)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2025-03-24 21:50:09 +08:00
committed by GitHub
parent 038de04d7b
commit cbcdf2c609
7 changed files with 196 additions and 56 deletions

View File

@@ -4,10 +4,13 @@ import warnings
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format)
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
@pytest.fixture(scope="function")
@@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_result = apply_hf_chat_template(
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
add_generation_prompt=True,
)
assert hf_result == vllm_result
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
"model",
[
QWEN2VL_MODEL_ID, # tokenizer.chat_template is of type str
HERMES_MODEL_ID, # tokenizer.chat_template is of type dict
])
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
@@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
tokenizer = tokenizer_group.tokenizer
chat_template = tokenizer.chat_template
tools = [{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema
}
}] if use_tools else None
# Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(QWEN25VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = _resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
resolved_format = resolve_chat_template_content_format(
chat_template,
None,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format