Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,8 +4,7 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
|
||||
load_chat_template)
|
||||
from vllm.entrypoints.chat_utils import apply_hf_chat_template, load_chat_template
|
||||
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@@ -17,48 +16,54 @@ assert chatml_jinja_path.exists()
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT = [
|
||||
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
True,
|
||||
False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
|
||||
""",
|
||||
),
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
False,
|
||||
False,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
|
||||
What is the capital of""",
|
||||
),
|
||||
(
|
||||
"facebook/opt-125m",
|
||||
chatml_jinja_path,
|
||||
False,
|
||||
True,
|
||||
"""<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The capital of"""),
|
||||
The capital of""",
|
||||
),
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'Hello'
|
||||
},
|
||||
{
|
||||
'role': 'assistant',
|
||||
'content': 'Hi there!'
|
||||
},
|
||||
{
|
||||
'role': 'user',
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
{"role": "user", "content": "What is the capital of"},
|
||||
]
|
||||
ASSISTANT_MESSAGE_TO_CONTINUE = {
|
||||
'role': 'assistant',
|
||||
'content': 'The capital of'
|
||||
}
|
||||
ASSISTANT_MESSAGE_TO_CONTINUE = {"role": "assistant", "content": "The capital of"}
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
@@ -68,8 +73,11 @@ def test_load_chat_template():
|
||||
# Test assertions
|
||||
assert template_content is not None
|
||||
# Hard coded value for template_chatml.jinja
|
||||
assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
|
||||
assert (
|
||||
template_content
|
||||
== """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}"""
|
||||
) # noqa: E501
|
||||
|
||||
|
||||
def test_no_load_chat_template_filelike():
|
||||
@@ -91,9 +99,11 @@ def test_no_load_chat_template_literallike():
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,template,add_generation_prompt,continue_final_message,expected_output",
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT)
|
||||
def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
continue_final_message, expected_output):
|
||||
MODEL_TEMPLATE_GENERATION_OUTPUT,
|
||||
)
|
||||
def test_get_gen_prompt(
|
||||
model, template, add_generation_prompt, continue_final_message, expected_output
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
|
||||
@@ -106,7 +116,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
dtype=model_info.dtype,
|
||||
)
|
||||
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(
|
||||
@@ -119,7 +130,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
mock_request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
|
||||
if continue_final_message else TEST_MESSAGES,
|
||||
if continue_final_message
|
||||
else TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
)
|
||||
@@ -138,4 +150,5 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
# Test assertion
|
||||
assert result == expected_output, (
|
||||
f"The generated prompt does not match the expected output for "
|
||||
f"model {model} and template {template}")
|
||||
f"model {model} and template {template}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user