[Bugfix] Fix chat template loading (#15143)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: chaunceyjiang <chaunceyjiang@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
@@ -4,10 +4,13 @@ import warnings
 from typing import Optional
 
 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
+                                         _try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
@@ -23,8 +26,10 @@ EXAMPLES_DIR = VLLM_PATH / "examples"
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
 
 
 @pytest.fixture(scope="function")
@@ -703,25 +708,27 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
 
     vllm_result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
         tools=None,
         add_generation_prompt=True,
     )
 
     assert hf_result == vllm_result
 
 # yapf: disable
 @pytest.mark.parametrize(
-    ("model", "expected_format"),
-    [(PHI3V_MODEL_ID, "string"),
-     (QWEN2VL_MODEL_ID, "openai"),
-     (ULTRAVOX_MODEL_ID, "string"),
-     (MLLAMA_MODEL_ID, "openai"),
-     (LLAMA_GUARD_MODEL_ID, "openai")],
-)
-# yapf: enable
-def test_resolve_content_format_hf_defined(model, expected_format):
+    "model",
+    [
+        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
+        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
+    ])
+@pytest.mark.parametrize("use_tools", [True, False])
+def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
+    """checks that chat_template is a dict type for HF models."""
+
+    # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
@@ -730,7 +737,56 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     )
     tokenizer = tokenizer_group.tokenizer
 
-    chat_template = tokenizer.chat_template
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "dummy_function_name",
+            "description": "This is a dummy function",
+            "parameters": sample_json_schema
+        }
+    }] if use_tools else None
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=tools,
+        trust_remote_code=True,
+    )
+    assert isinstance(chat_template, str)
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("model", "expected_format"),
+    [(PHI3V_MODEL_ID, "string"),
+     (QWEN2VL_MODEL_ID, "openai"),
+     (QWEN25VL_MODEL_ID, "openai"),
+     (ULTRAVOX_MODEL_ID, "string"),
+     (MLLAMA_MODEL_ID, "openai"),
+     (LLAMA_GUARD_MODEL_ID, "openai")],
+)
+# yapf: enable
+def test_resolve_content_format_hf_defined(model, expected_format):
+    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
+            "4.49.0"):
+        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=None,
+        trust_remote_code=True,
+    )
     assert isinstance(chat_template, str)
 
     print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         None,  # Test detecting the tokenizer's chat_template
+        None,
         "auto",
         tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         chat_template,
+        None,
         "auto",
         dummy_tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
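
For readers who want to exercise the fixed code paths outside pytest, below is a minimal sketch of the flow the new tests cover, using only the calls that appear in this diff. The TokenizerGroup import path is an assumption (it comes from an unchanged part of the test file, not shown here), and the model ID is illustrative:

from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
                                         resolve_chat_template_content_format)
# Assumed import path; the test module imports TokenizerGroup outside this diff.
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"  # illustrative; any model ID tested above

tokenizer_group = TokenizerGroup(
    MODEL_ID,
    enable_lora=False,
    max_num_seqs=5,
    max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

# With chat_template=None, the template is resolved from the tokenizer;
# trust_remote_code is now threaded through, which is the fix under test.
chat_template = _resolve_hf_chat_template(
    tokenizer,
    chat_template=None,
    tools=None,
    trust_remote_code=True,
)
assert isinstance(chat_template, str)

# Detect whether the template consumes "string" or "openai" content parts.
resolved_format = resolve_chat_template_content_format(
    None,  # chat_template: fall back to the tokenizer's template
    None,  # tools
    "auto",
    tokenizer,
    trust_remote_code=True,
)
print(resolved_format)  # e.g. "openai" for Qwen2-VL, per the test table above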