[Frontend] Added support for HF's new continue_final_message parameter (#8942)
This commit is contained in:
@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()
|
||||
|
||||
# Define models, templates, and their corresponding expected outputs
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT = [
|
||||
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
|
||||
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
|
||||
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of""")
|
||||
What is the capital of"""),
|
||||
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
|
||||
Hello<|im_end|>
|
||||
<|im_start|>assistant
|
||||
Hi there!<|im_end|>
|
||||
<|im_start|>user
|
||||
What is the capital of<|im_end|>
|
||||
<|im_start|>assistant
|
||||
The capital of"""),
|
||||
]
|
||||
|
||||
TEST_MESSAGES = [
|
||||
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
|
||||
'content': 'What is the capital of'
|
||||
},
|
||||
]
|
||||
ASSISTANT_MESSAGE_TO_CONTINUE = {
|
||||
'role': 'assistant',
|
||||
'content': 'The capital of'
|
||||
}
|
||||
|
||||
|
||||
def test_load_chat_template():
|
||||
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,template,add_generation_prompt,expected_output",
|
||||
"model,template,add_generation_prompt,continue_final_message,expected_output",
|
||||
MODEL_TEMPLATE_GENERATON_OUTPUT)
|
||||
def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
expected_output):
|
||||
continue_final_message, expected_output):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_name=model)
|
||||
template_content = load_chat_template(chat_template=template)
|
||||
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
# Create a mock request object using keyword arguments
|
||||
mock_request = ChatCompletionRequest(
|
||||
model=model,
|
||||
messages=TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt)
|
||||
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
|
||||
if continue_final_message else TEST_MESSAGES,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
continue_final_message=continue_final_message,
|
||||
)
|
||||
|
||||
# Call the function and get the result
|
||||
result = apply_hf_chat_template(
|
||||
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
|
||||
conversation=mock_request.messages,
|
||||
chat_template=mock_request.chat_template or template_content,
|
||||
add_generation_prompt=mock_request.add_generation_prompt,
|
||||
continue_final_message=mock_request.continue_final_message,
|
||||
)
|
||||
|
||||
# Test assertion
|
||||
|
||||
@@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str,
|
||||
"role": "user",
|
||||
"content": "Can I ask a question? vllm1"
|
||||
}]
|
||||
for continue_final in [False, True]:
|
||||
if add_generation and continue_final:
|
||||
continue
|
||||
if continue_final:
|
||||
conversation.append({
|
||||
"role": "assistant",
|
||||
"content": "Sure,"
|
||||
})
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
add_generation_prompt=add_generation,
|
||||
conversation=conversation,
|
||||
tokenize=False)
|
||||
tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
add_generation_prompt=add_generation,
|
||||
continue_final_message=continue_final,
|
||||
conversation=conversation,
|
||||
tokenize=False)
|
||||
tokens = tokenizer.encode(prompt,
|
||||
add_special_tokens=add_special)
|
||||
|
||||
response = requests.post(base_url + "/tokenize",
|
||||
json={
|
||||
"add_generation_prompt":
|
||||
add_generation,
|
||||
"add_special_tokens": add_special,
|
||||
"messages": conversation,
|
||||
"model": model_name
|
||||
})
|
||||
response.raise_for_status()
|
||||
response = requests.post(base_url + "/tokenize",
|
||||
json={
|
||||
"add_generation_prompt":
|
||||
add_generation,
|
||||
"continue_final_message":
|
||||
continue_final,
|
||||
"add_special_tokens": add_special,
|
||||
"messages": conversation,
|
||||
"model": model_name
|
||||
})
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {
|
||||
"tokens": tokens,
|
||||
"count": len(tokens),
|
||||
"max_model_len": 8192
|
||||
}
|
||||
assert response.json() == {
|
||||
"tokens": tokens,
|
||||
"count": len(tokens),
|
||||
"max_model_len": 8192
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user