[Tool parsing] Improve / correct mistral tool parsing (#10333)
This commit is contained in:
committed by
GitHub
parent
554af9228d
commit
11cd1ae6ad
@@ -2,9 +2,13 @@
|
||||
|
||||
Run `pytest tests/models/test_mistral.py`.
|
||||
"""
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
||||
MistralToolParser)
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
@@ -58,17 +62,69 @@ TOOLS = [{
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
},
|
||||
}, {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rewrite",
|
||||
"description": "Rewrites text",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The input text to rewrite."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
MSGS = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": ("Can you tell me what the temperate"
|
||||
" will be in Dallas, in fahrenheit?")
|
||||
}]
|
||||
EXPECTED_FUNC_CALL = (
|
||||
'[{"name": "get_current_weather", "arguments": '
|
||||
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
|
||||
MSGS = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"",
|
||||
"tool_calls": [{
|
||||
"id": "bbc5b7ede",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
"rewrite",
|
||||
"arguments":
|
||||
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
|
||||
}
|
||||
}]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content":
|
||||
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
|
||||
"tool_call_id": "bbc5b7ede",
|
||||
"name": "rewrite"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "---\n\nMy English needs improving, maybe I make errors"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content": ("Can you tell me what the temperate"
|
||||
" will be in Dallas, in fahrenheit?")
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral") as vllm_model:
|
||||
outputs = vllm_model.model.chat(MSGS,
|
||||
|
||||
msgs = copy.deepcopy(MSGS)
|
||||
outputs = vllm_model.model.chat(msgs,
|
||||
tools=TOOLS,
|
||||
sampling_params=SAMPLING_PARAMS)
|
||||
|
||||
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
tool_parser = MistralToolParser(tokenizer)
|
||||
|
||||
model_output = outputs[0].outputs[0].text.strip()
|
||||
assert model_output.startswith(tool_parser.bot_token), model_output
|
||||
parsed_message = tool_parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert parsed_message.tools_called
|
||||
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
|
||||
assert parsed_message.tool_calls[
|
||||
0].function.name == "get_current_weather"
|
||||
assert parsed_message.tool_calls[
|
||||
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
||||
assert parsed_message.content is None
|
||||
|
||||
Reference in New Issue
Block a user