[Frontend] Added support for HF's new continue_final_message parameter (#8942)

This commit is contained in:
danieljannai21
2024-09-29 20:59:47 +03:00
committed by GitHub
parent 1fb9c1b0bf
commit 6c9ba48fde
7 changed files with 102 additions and 28 deletions

View File

@@ -12,7 +12,7 @@ assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs
MODEL_TEMPLATE_GENERATON_OUTPUT = [
("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
@@ -20,12 +20,20 @@ Hi there!<|im_end|>
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
What is the capital of"""),
("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
The capital of"""),
]
TEST_MESSAGES = [
@@ -42,6 +50,10 @@ TEST_MESSAGES = [
'content': 'What is the capital of'
},
]
ASSISTANT_MESSAGE_TO_CONTINUE = {
'role': 'assistant',
'content': 'The capital of'
}
def test_load_chat_template():
@@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike():
@pytest.mark.parametrize(
"model,template,add_generation_prompt,expected_output",
"model,template,add_generation_prompt,continue_final_message,expected_output",
MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
continue_final_message, expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
template_content = load_chat_template(chat_template=template)
@@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
model=model,
messages=TEST_MESSAGES,
add_generation_prompt=add_generation_prompt)
messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE]
if continue_final_message else TEST_MESSAGES,
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
)
# Call the function and get the result
result = apply_hf_chat_template(
@@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
add_generation_prompt=mock_request.add_generation_prompt,
continue_final_message=mock_request.continue_final_message,
)
# Test assertion