[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-08-07 17:12:05 +08:00
committed by GitHub
parent 7b261092de
commit 66d617e343
9 changed files with 125 additions and 69 deletions

View File

@@ -1,22 +1,16 @@
import os
import pathlib
import pytest
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
__file__))).parent.parent / "examples/template_chatml.jinja"
from ..utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", None, True,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", None, False,
"Hello</s>Hi there!</s>What is the capital of</s>"),
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)
# Call the function and get the result
result = tokenizer.apply_chat_template(
result = apply_chat_template(
tokenizer,
conversation=mock_request.messages,
tokenize=False,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
chat_template=mock_request.chat_template or template_content)
)
# Test assertion
assert result == expected_output, (

View File

@@ -1,10 +1,12 @@
import openai # use the official client for correctness check
import pytest
from ..utils import RemoteOpenAIServer
from ..utils import VLLM_PATH, RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()
@pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
"--engine-use-ray",
"--chat-template",
str(chatml_jinja_path),
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
choice = chat_completion.choices[0]
assert choice.finish_reason == "length"
assert chat_completion.usage == openai.types.CompletionUsage(
completion_tokens=10, prompt_tokens=13, total_tokens=23)
completion_tokens=10, prompt_tokens=55, total_tokens=65)
message = choice.message
assert message.content is not None and len(message.content) >= 10