[Tool parsing] Improve / correct mistral tool parsing (#10333)

This commit is contained in:
Patrick von Platen
2024-11-15 01:42:49 +01:00
committed by GitHub
parent 554af9228d
commit 11cd1ae6ad
5 changed files with 172 additions and 59 deletions

View File

@@ -2,9 +2,13 @@
Run `pytest tests/models/test_mistral.py`.
"""
import copy
import pytest
from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from ...utils import check_logprobs_close
@@ -58,17 +62,69 @@ TOOLS = [{
},
"required": ["city", "state", "unit"]
}
},
}, {
"type": "function",
"function": {
"name": "rewrite",
"description": "Rewrites text",
"parameters": {
"type": "object",
"required": [],
"properties": {
"text": {
"type": "string",
"description": "The input text to rewrite."
}
}
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
MSGS = [
{
"role": "system",
"content": "You are an assistant."
},
{
"role":
"user",
"content":
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
},
{
"role":
"assistant",
"content":
"",
"tool_calls": [{
"id": "bbc5b7ede",
"type": "function",
"function": {
"name":
"rewrite",
"arguments":
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
}
}]
},
{
"role": "tool",
"content":
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
"tool_call_id": "bbc5b7ede",
"name": "rewrite"
},
{
"role": "assistant",
"content": "---\n\nMy English needs improving, maybe I make errors"
},
{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}
]
@pytest.mark.parametrize("model", MODELS)
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,
msgs = copy.deepcopy(MSGS)
outputs = vllm_model.model.chat(msgs,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
tokenizer = vllm_model.model.get_tokenizer()
tool_parser = MistralToolParser(tokenizer)
model_output = outputs[0].outputs[0].text.strip()
assert model_output.startswith(tool_parser.bot_token), model_output
parsed_message = tool_parser.extract_tool_calls(model_output, None)
assert parsed_message.tools_called
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
assert parsed_message.tool_calls[
0].function.name == "get_current_weather"
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None