[Model] Add mistral function calling format to all models loaded with "mistral" format (#8515)

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Patrick von Platen
2024-09-17 19:50:37 +02:00
committed by GitHub
parent 9855b99502
commit a54ed80249
5 changed files with 219 additions and 9 deletions

View File

@@ -4,13 +4,61 @@ Run `pytest tests/models/test_mistral.py`.
"""
import pytest
from vllm import SamplingParams
from ...utils import check_logprobs_close
MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.3",
# Mistral-Nemo is to big for CI, but passes locally
# "mistralai/Mistral-Nemo-Instruct-2407"
]
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
# for function calling
TOOLS = [{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type":
"string",
"description":
"The city to find the weather for, e.g. 'San Francisco'"
},
"state": {
"type":
"string",
"description":
"the two-letter abbreviation for the state that the city is"
" in, e.g. 'CA' which would mean 'California'"
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city", "state", "unit"]
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@@ -81,3 +129,22 @@ def test_mistral_format(
name_0="hf",
name_1="mistral",
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL