[Test] Consolidate tool parser unit tests to tests/tool_parsers (#37834)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
@@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def default_tokenizer() -> TokenizerLike:
|
||||
return AutoTokenizer.from_pretrained("gpt2")
|
||||
@@ -1,343 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
MSG_SEP_TOKEN = "<|message_sep|>\n\n"
|
||||
ROLE_SEP_TOKEN = "<|role_sep|>\n"
|
||||
EOS_TOKEN = "</s>"
|
||||
TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}"
|
||||
TOOL_HEADER_GIGACHAT31 = "<|function_call|>"
|
||||
|
||||
|
||||
SIMPLE_ARGS_DICT = {
|
||||
"action": "create",
|
||||
"id": "preferences",
|
||||
}
|
||||
SIMPLE_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": SIMPLE_ARGS_DICT,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}"
|
||||
)
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
PARAMETERLESS_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": {},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}"
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = (
|
||||
f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}"
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps({}, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
COMPLEX_ARGS_DICT = {
|
||||
"action": "create",
|
||||
"id": "preferences",
|
||||
"content": {
|
||||
"short_answers": True,
|
||||
"hate_emojis": True,
|
||||
"english_ui": False,
|
||||
"russian_math_explanations": True,
|
||||
},
|
||||
}
|
||||
COMPLEX_FUNCTION_JSON = json.dumps(
|
||||
{
|
||||
"name": "manage_user_memory",
|
||||
"arguments": COMPLEX_ARGS_DICT,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}"
|
||||
)
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}"
|
||||
COMPLEX_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
CONTENT_TEXT = "I'll check that for you."
|
||||
MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}"
|
||||
MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}"
|
||||
|
||||
|
||||
@pytest.fixture(name="gigachat_tokenizer")
|
||||
def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike):
|
||||
default_tokenizer.add_tokens(
|
||||
[
|
||||
MSG_SEP_TOKEN,
|
||||
ROLE_SEP_TOKEN,
|
||||
TOOL_HEADER_GIGACHAT31,
|
||||
EOS_TOKEN,
|
||||
]
|
||||
)
|
||||
return default_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
gigachat_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_with_eos_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_with_eos_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_with_eos_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_with_eos_gigachat31",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"streaming, model_output, expected_tool_calls, expected_content", TEST_CASES
|
||||
)
|
||||
def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
expected_content: str | None,
|
||||
gigachat_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
gigachat_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
if content == "":
|
||||
content = None
|
||||
assert content == expected_content
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function.name == expected.name
|
||||
actual_args = json.loads(actual.function.arguments)
|
||||
expected_args = json.loads(expected.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_output_deltas",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
CONTENT_TEXT[:3],
|
||||
CONTENT_TEXT[3:5],
|
||||
CONTENT_TEXT[5:],
|
||||
MSG_SEP_TOKEN,
|
||||
TOOL_HEADER_GIGACHAT3,
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:-1],
|
||||
COMPLEX_FUNCTION_JSON[-1],
|
||||
],
|
||||
id="gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
CONTENT_TEXT[:3],
|
||||
CONTENT_TEXT[3:5],
|
||||
CONTENT_TEXT[5:],
|
||||
TOOL_HEADER_GIGACHAT31,
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:-1],
|
||||
COMPLEX_FUNCTION_JSON[-1],
|
||||
],
|
||||
id="gigachat31",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_call_with_large_steps(
|
||||
model_output_deltas: list[str],
|
||||
gigachat_tokenizer: TokenizerLike,
|
||||
):
|
||||
"""
|
||||
Test that the closing braces are streamed correctly.
|
||||
"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
gigachat_tokenizer
|
||||
)
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser,
|
||||
model_output_deltas,
|
||||
assert_one_tool_per_delta=False,
|
||||
)
|
||||
assert len(reconstructor.tool_calls) == 1
|
||||
call = reconstructor.tool_calls[0]
|
||||
assert call.type == "function"
|
||||
assert call.function.name == "manage_user_memory"
|
||||
args_dict = json.loads(call.function.arguments)
|
||||
assert args_dict == COMPLEX_ARGS_DICT
|
||||
@@ -1,18 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
)
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@@ -38,137 +29,6 @@ def server():
|
||||
yield server
|
||||
|
||||
|
||||
def create_complex_input(create_string_args: bool):
|
||||
coord_arg: dict | str = {
|
||||
"coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
|
||||
"coordinate_type": "latlong",
|
||||
}
|
||||
if create_string_args:
|
||||
# test granite behavior
|
||||
coord_arg = json.dumps(coord_arg)
|
||||
return [
|
||||
{"name": "find_bbox", "arguments": coord_arg},
|
||||
{
|
||||
"name": "get_stock_price",
|
||||
"arguments": {
|
||||
"symbol": "AAPL",
|
||||
"start_date": "2021-01-01",
|
||||
"end_date": "2021-12-31",
|
||||
},
|
||||
},
|
||||
{"name": "find_bbox", "arguments": coord_arg},
|
||||
]
|
||||
|
||||
|
||||
def random_chunks(s: str, min_len: int, max_len: int):
|
||||
chunks = []
|
||||
i = 0
|
||||
n = len(s)
|
||||
|
||||
while i < n:
|
||||
size = random.randint(min_len, max_len)
|
||||
chunks.append(s[i : i + size])
|
||||
i += size
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained(MODEL)
|
||||
|
||||
|
||||
# create a variety of input chunk sizes
|
||||
@pytest.mark.parametrize(
|
||||
"min_chunk, max_chunk",
|
||||
[
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(5, 7),
|
||||
(6, 20),
|
||||
],
|
||||
)
|
||||
def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
|
||||
input_dicts = create_complex_input(True)
|
||||
|
||||
formatted_tcs = [
|
||||
"<tool_call> " + json.dumps(call) + " </tool_call>" for call in input_dicts
|
||||
]
|
||||
|
||||
text_messages = [
|
||||
"Here goes the bbox call: \n",
|
||||
" Now the stock price call: \n ",
|
||||
" Now another bbox call: \n ",
|
||||
" See? I'm a helpful assistant.",
|
||||
]
|
||||
|
||||
test_input = (
|
||||
text_messages[0]
|
||||
+ formatted_tcs[0]
|
||||
+ text_messages[1]
|
||||
+ formatted_tcs[1]
|
||||
+ text_messages[2]
|
||||
+ formatted_tcs[2]
|
||||
+ text_messages[3]
|
||||
)
|
||||
|
||||
any_chat_request = ChatCompletionRequest(
|
||||
seed=42,
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
)
|
||||
|
||||
parser = Granite4ToolParser(tokenizer=tokenizer)
|
||||
|
||||
delta_messages = list[DeltaMessage]()
|
||||
for text in random_chunks(test_input, min_chunk, max_chunk):
|
||||
delta = parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
|
||||
content = ""
|
||||
tool_calls = list[dict[str, Any]]()
|
||||
|
||||
current_name = "__start__"
|
||||
current_args = ""
|
||||
|
||||
for msg in delta_messages:
|
||||
if msg.content:
|
||||
content += msg.content
|
||||
for tool_call in msg.tool_calls:
|
||||
if delta_func := tool_call.function:
|
||||
if delta_func.name is not None:
|
||||
if current_name == "__start__":
|
||||
current_name = delta_func.name
|
||||
|
||||
if delta_func.name != current_name:
|
||||
tool_calls.append(
|
||||
{
|
||||
"name": current_name,
|
||||
"arguments": json.loads(current_args),
|
||||
}
|
||||
)
|
||||
current_name = delta_func.name
|
||||
current_args = ""
|
||||
|
||||
if delta_func.arguments:
|
||||
current_args += delta_func.arguments
|
||||
|
||||
if current_name != "__start__":
|
||||
tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})
|
||||
|
||||
assert content == "".join(text_messages)
|
||||
assert tool_calls == create_complex_input(False)
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
|
||||
@@ -9,8 +9,6 @@ import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
@@ -325,202 +323,3 @@ async def test_streaming_product_tool_call(
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qwen_tokenizer() -> TokenizerLike:
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture(params=CONFIGS.keys())
|
||||
def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
|
||||
config = CONFIGS[request.param]
|
||||
return config["tool_parser"](qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def any_chat_request() -> ChatCompletionRequest:
|
||||
return ChatCompletionRequest(
|
||||
seed=42,
|
||||
model="Qwen/Qwen3-32B",
|
||||
messages=[],
|
||||
)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is some prior text that has nothing to do with tool calling."""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
delta_text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + delta_text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
delta_messages.append(delta)
|
||||
|
||||
for delta in delta_messages:
|
||||
assert delta is not None
|
||||
assert not delta.tool_calls
|
||||
|
||||
print(delta_messages)
|
||||
assert "".join([delta.content for delta in delta_messages]) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
|
||||
assert delta_messages[0].tool_calls[0].function.name == "final_answer"
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = '<tool_call>\
|
||||
{"name": "get_current_temperature",\
|
||||
"arguments": {"location":\
|
||||
"San Francisco, California, United States", "unit": "celsius"}}\
|
||||
</tool_call>'
|
||||
|
||||
tokens = qwen_tokenizer.encode(text)
|
||||
previous_text = ""
|
||||
delta_messages = []
|
||||
for token in tokens:
|
||||
text = qwen_tokenizer.decode([token])
|
||||
current_text = previous_text + text
|
||||
delta = hermes_parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
previous_text = current_text
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
|
||||
# load to normalize whitespace
|
||||
tool_call_args = json.loads(
|
||||
"".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
)
|
||||
assert tool_call_args == {
|
||||
"location": "San Francisco, California, United States",
|
||||
"unit": "celsius",
|
||||
}
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is not a tool call."""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}
|
||||
</tool_call>"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
if isinstance(hermes_parser, Granite4ToolParser):
|
||||
pytest.skip(reason="The Granite4 tool parser enforces a complete response")
|
||||
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert tool_call.tools_called
|
||||
assert tool_call.tool_calls[0].function.name == "final_answer"
|
||||
assert tool_call.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
# Missing closing brace to trigger exception
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
model_output=text,
|
||||
request=any_chat_request,
|
||||
)
|
||||
|
||||
assert tool_call is not None
|
||||
assert not tool_call.tools_called
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
|
||||
def make_tool_call(name, arguments):
|
||||
return ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(name=name, arguments=json.dumps(arguments)),
|
||||
)
|
||||
|
||||
|
||||
# TODO: add reason prefix and suffix.
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_output,expected_tool_calls,expected_content",
|
||||
[
|
||||
# No tool call
|
||||
("How can I help you today?", [], "How can I help you today?"),
|
||||
# Single tool call, no content
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Multiple tool calls
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "San Francisco", "metric": "celsius"}}, {"name": "register_user", "arguments": {"name": "John Doe", "age": 37, "address": {"city": "San Francisco", "state": "CA"}, "role": null, "passed_test": true, "aliases": ["John", "Johnny"]}}]</tool_calls>', # noqa: E501
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
),
|
||||
make_tool_call(
|
||||
"register_user",
|
||||
{
|
||||
"name": "John Doe",
|
||||
"age": 37,
|
||||
"address": {"city": "San Francisco", "state": "CA"},
|
||||
"role": None,
|
||||
"passed_test": True,
|
||||
"aliases": ["John", "Johnny"],
|
||||
},
|
||||
),
|
||||
],
|
||||
None,
|
||||
),
|
||||
# Content before tool call
|
||||
(
|
||||
'I will call the tool now. <tool_calls>[{"name": "get_weather", "arguments": {"city": "Boston"}}]</tool_calls>', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
"I will call the tool now. ",
|
||||
),
|
||||
# Content after tool call (should be stripped)
|
||||
(
|
||||
'<tool_calls>[{"name": "get_weather", "arguments": {"city": "Seattle"}}]</tool_calls>\nThank you!', # noqa: E501
|
||||
[make_tool_call("get_weather", {"city": "Seattle"})],
|
||||
None,
|
||||
),
|
||||
(
|
||||
'<tool_calls>[{"name": "complex_tool", "arguments": {"level1": {"level2": {"level3": {"value": 123}}}}}]</tool_calls>',
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
None,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_extract(
|
||||
model_output, expected_tool_calls, expected_content
|
||||
):
|
||||
mock_tokenizer = MagicMock()
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(tool_calls)):
|
||||
tool_calls[idx].id = expected_tool_calls[idx].id
|
||||
assert tool_calls == expected_tool_calls
|
||||
assert content == expected_content
|
||||
|
||||
|
||||
# Streaming test: simulate incremental output
|
||||
@pytest.mark.parametrize(
|
||||
"model_deltas,expected_tool_calls",
|
||||
[
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name": "get_weather", ',
|
||||
'"arguments": {"city": "San Francisco", ',
|
||||
'"metric": "celsius"}}]',
|
||||
"</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"get_weather", {"city": "San Francisco", "metric": "celsius"}
|
||||
)
|
||||
],
|
||||
),
|
||||
(
|
||||
[
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
(
|
||||
[
|
||||
"",
|
||||
'<tool_calls>[{"name":',
|
||||
' "get_weather",',
|
||||
' "arguments":',
|
||||
' {"city": "Boston"}',
|
||||
"}]",
|
||||
"</tool_calls>",
|
||||
"\n</answer>",
|
||||
],
|
||||
[make_tool_call("get_weather", {"city": "Boston"})],
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
'<tool_calls>[{"name": "complex_tool",',
|
||||
' "arguments": ',
|
||||
' {"level1": {"level2": ',
|
||||
'{"level3": {"value": 123}}}}}',
|
||||
"]</tool_calls>",
|
||||
],
|
||||
[
|
||||
make_tool_call(
|
||||
"complex_tool", {"level1": {"level2": {"level3": {"value": 123}}}}
|
||||
)
|
||||
],
|
||||
marks=pytest.mark.xfail(
|
||||
reason="stream parsing not support nested json yet."
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
|
||||
mock_tokenizer = MagicMock()
|
||||
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("hunyuan_a13b")(
|
||||
mock_tokenizer
|
||||
)
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
# align the random id.
|
||||
for idx in range(len(reconstructor.tool_calls)):
|
||||
reconstructor.tool_calls[idx].id = expected_tool_calls[idx].id
|
||||
|
||||
assert reconstructor.tool_calls == expected_tool_calls
|
||||
@@ -1,262 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import ExtractedToolCallInformation
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.llama_tool_parser import Llama3JsonToolParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(default_tokenizer: TokenizerLike):
|
||||
return Llama3JsonToolParser(default_tokenizer)
|
||||
|
||||
|
||||
def test_extract_tool_calls_simple(parser):
|
||||
# Test with a simple tool call
|
||||
model_output = (
|
||||
'Here is the result: {"name": "getOpenIncidentsTool", '
|
||||
'"parameters": {}} Would you like to know more?'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert isinstance(result, ExtractedToolCallInformation)
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].type == "function"
|
||||
assert result.tool_calls[0].function.name == "getOpenIncidentsTool"
|
||||
assert result.tool_calls[0].function.arguments == "{}"
|
||||
assert result.content is None
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_arguments(parser):
|
||||
# Test with a tool call that has arguments
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test query", "limit": 10}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
assert '"query": "test query"' in result.tool_calls[0].function.arguments
|
||||
assert '"limit": 10' in result.tool_calls[0].function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_json(parser):
|
||||
# Test with text that doesn't contain a JSON object
|
||||
model_output = "This is just some text without any tool calls"
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert len(result.tool_calls) == 0
|
||||
assert result.content == model_output
|
||||
|
||||
|
||||
def test_extract_tool_calls_invalid_json(parser):
|
||||
# Test with invalid JSON
|
||||
model_output = '{"name": "invalidTool", "parameters": {invalid json}'
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert len(result.tool_calls) == 0
|
||||
assert result.content == model_output
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_arguments_key(parser):
|
||||
# Test with a tool call that uses "arguments" instead of "parameters"
|
||||
model_output = '{"name": "searchTool", "arguments": {"query": "test"}}'
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
assert '"query": "test"' in result.tool_calls[0].function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json(parser):
|
||||
# Test with multiple JSONs separated by semicolons
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 3
|
||||
|
||||
# Check first tool call
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
assert '"query": "test1"' in result.tool_calls[0].function.arguments
|
||||
|
||||
# Check second tool call
|
||||
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
|
||||
assert result.tool_calls[1].function.arguments == "{}"
|
||||
|
||||
# Check third tool call
|
||||
assert result.tool_calls[2].function.name == "searchTool"
|
||||
assert '"query": "test2"' in result.tool_calls[2].function.arguments
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json_with_whitespace(parser):
|
||||
# Test with multiple JSONs separated by semicolons and extra whitespace
|
||||
model_output = (
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}} ; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}} ; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 3
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
|
||||
assert result.tool_calls[2].function.name == "searchTool"
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_json_with_surrounding_text(parser):
|
||||
# Test with multiple JSONs and surrounding text
|
||||
model_output = (
|
||||
"Here are the results: "
|
||||
'{"name": "searchTool", "parameters": {"query": "test1"}}; '
|
||||
'{"name": "getOpenIncidentsTool", "parameters": {}}; '
|
||||
'{"name": "searchTool", "parameters": {"query": "test2"}} '
|
||||
"Would you like to know more?"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 3
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
assert result.tool_calls[1].function.name == "getOpenIncidentsTool"
|
||||
assert result.tool_calls[2].function.name == "searchTool"
|
||||
|
||||
|
||||
def test_extract_tool_calls_deeply_nested_json(parser):
|
||||
# Test with deeply nested JSON parameters (5 levels)
|
||||
model_output = (
|
||||
'{"name": "complexTool", '
|
||||
'"parameters": {'
|
||||
'"level1": {'
|
||||
'"level2": {'
|
||||
'"level3": {'
|
||||
'"level4": {'
|
||||
'"value": "deep"'
|
||||
"}}}}}}"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "complexTool"
|
||||
# Verify the nested structure is preserved in the arguments
|
||||
import json
|
||||
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args["level1"]["level2"]["level3"]["level4"]["value"] == "deep"
|
||||
|
||||
|
||||
def test_extract_tool_calls_multiple_with_deep_nesting(parser):
|
||||
# Test with multiple tool calls where some have deeply nested parameters
|
||||
model_output = (
|
||||
'{"name": "simpleTool", "parameters": {"value": "test"}}; '
|
||||
'{"name": "complexTool", "parameters": '
|
||||
'{"config": {"database": {"connection": {"pool": {"size": 10}}}}}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 2
|
||||
|
||||
# Check first tool call
|
||||
assert result.tool_calls[0].function.name == "simpleTool"
|
||||
import json
|
||||
|
||||
args0 = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args0["value"] == "test"
|
||||
|
||||
# Check second tool call with deep nesting
|
||||
assert result.tool_calls[1].function.name == "complexTool"
|
||||
args1 = json.loads(result.tool_calls[1].function.arguments)
|
||||
assert args1["config"]["database"]["connection"]["pool"]["size"] == 10
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_quotes_and_brackets_in_string(parser):
|
||||
# Test with quotes and brackets inside quoted string values
|
||||
model_output = (
|
||||
'{"name": "searchTool", '
|
||||
'"parameters": {'
|
||||
'"query": "test {value} [complex]",'
|
||||
'"nested": {"inner": "more {brackets}"}'
|
||||
"}}"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "searchTool"
|
||||
# Verify the string values are preserved including brackets and quotes
|
||||
import json
|
||||
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args["query"] == "test {value} [complex]"
|
||||
assert args["nested"]["inner"] == "more {brackets}"
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_escaped_quotes_in_nested_json(parser):
|
||||
# Test with escaped quotes in deeply nested JSON
|
||||
model_output = (
|
||||
'{"name": "parserTool", "parameters": {"text": "He said \\"Hello {world}\\""}}'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "parserTool"
|
||||
# Verify escaped quotes are preserved
|
||||
import json
|
||||
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args["text"] == 'He said "Hello {world}"'
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_name_key(parser):
|
||||
# Test that missing "name" key returns content
|
||||
model_output = '{"parameters": {}}'
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert len(result.tool_calls) == 0
|
||||
assert result.content == model_output
|
||||
|
||||
|
||||
def test_extract_tool_calls_missing_parameters_and_arguments_key(parser):
|
||||
# Test that missing both "parameters" and "arguments" keys returns content
|
||||
model_output = '{"name": "toolWithoutParams"}'
|
||||
result = parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert len(result.tool_calls) == 0
|
||||
assert result.content == model_output
|
||||
|
||||
|
||||
def test_regex_timeout_handling(parser):
|
||||
"""Test regex timeout is handled gracefully"""
|
||||
fake_problematic_input = "{hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.finditer.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(parser, "tool_call_start_regex", mock_regex):
|
||||
result = parser.extract_tool_calls(fake_problematic_input, None)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert result.content == fake_problematic_input
|
||||
assert result.tools_called is False
|
||||
assert len(result.tool_calls) == 0
|
||||
mock_regex.finditer.assert_called_once()
|
||||
@@ -1,269 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# Test cases similar to pythonic parser but with Llama4 specific format
|
||||
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "LA", "metric": "C"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"[register_user(name='Doe', "
|
||||
"age=9, "
|
||||
"address={'city': 'LA', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])]"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "Doe", '
|
||||
'"age": 9, '
|
||||
'"address": {"city": "LA", "state": "CA"}, '
|
||||
'"role": null, '
|
||||
'"passed_test": true, '
|
||||
'"aliases": ["John", "Johnny"]}',
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"additional_data": {}}',
|
||||
)
|
||||
EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]"
|
||||
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
PYTHON_TAG_FUNCTION_OUTPUT = (
|
||||
"<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
test_str = "<|python_start|>"
|
||||
test_str += "[get_weather(city='LA', metric='C'),"
|
||||
test_str += "register_user(name='Doe', age=9)]"
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False, SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], id="simple_nonstreaming"
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MORE_TYPES_FUNCTION_OUTPUT,
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_DICT_FUNCTION_OUTPUT,
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY_LIST_FUNCTION_OUTPUT,
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT,
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PYTHON_TAG_FUNCTION_OUTPUT,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="python_tag_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
test_str,
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
"register_user(name='Doe', age=9)]",
|
||||
[
|
||||
SIMPLE_FUNCTION_CALL,
|
||||
FunctionCall(name="register_user", arguments='{"name": "Doe", "age": 9}'),
|
||||
],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"<|python_start|>[get_weather(city='LA', metric='C'), "
|
||||
"get_weather(), "
|
||||
"do_something_cool(steps=[])]<|python_end|>",
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("llama4_pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
@@ -1,251 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "San Francisco", "metric": "celsius"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=null, "
|
||||
"passed_test=true, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
'"age": 37, '
|
||||
'"address": {"city": "San Francisco", "state": "CA"}, '
|
||||
'"role": null, '
|
||||
'"passed_test": true, '
|
||||
'"aliases": ["John", "Johnny"]}',
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"additional_data": {}}',
|
||||
)
|
||||
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
|
||||
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{MORE_TYPES_FUNCTION_OUTPUT_JSON_LITERALS}</function_calls>",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming_json_literals",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{PARAMETERLESS_FUNCTION_OUTPUT}</function_calls>",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_DICT_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{ESCAPED_STRING_FUNCTION_OUTPUT}</function_calls>",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"<function_calls>{SIMPLE_FUNCTION_OUTPUT}\n{MORE_TYPES_FUNCTION_OUTPUT}</function_calls>",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"<function_calls>get_weather(city='San",
|
||||
" Francisco', metric='celsius')\n"
|
||||
f"{PARAMETERLESS_FUNCTION_OUTPUT}\n"
|
||||
f"{EMPTY_LIST_FUNCTION_OUTPUT}</function_calls>",
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("olmo3")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
@@ -1,231 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
run_tool_extraction_streaming,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
# https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#model-response-format-1
|
||||
SIMPLE_FUNCTION_OUTPUT = "get_weather(city='San Francisco', metric='celsius')"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "San Francisco", "metric": "celsius"}',
|
||||
)
|
||||
MORE_TYPES_FUNCTION_OUTPUT = (
|
||||
"register_user(name='John Doe', "
|
||||
"age=37, "
|
||||
"address={'city': 'San Francisco', 'state': 'CA'}, "
|
||||
"role=None, "
|
||||
"passed_test=True, "
|
||||
"aliases=['John', 'Johnny'])"
|
||||
)
|
||||
MORE_TYPES_FUNCTION_CALL = FunctionCall(
|
||||
name="register_user",
|
||||
arguments='{"name": "John Doe", '
|
||||
'"age": 37, '
|
||||
'"address": {"city": "San Francisco", "state": "CA"}, '
|
||||
'"role": null, '
|
||||
'"passed_test": true, '
|
||||
'"aliases": ["John", "Johnny"]}',
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "get_weather()"
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments="{}",
|
||||
)
|
||||
EMPTY_DICT_FUNCTION_OUTPUT = "do_something_cool(additional_data={})"
|
||||
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"additional_data": {}}',
|
||||
)
|
||||
EMPTY_LIST_FUNCTION_OUTPUT = "do_something_cool(steps=[])"
|
||||
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
|
||||
name="do_something_cool",
|
||||
arguments='{"steps": []}',
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_OUTPUT = (
|
||||
r"get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')"
|
||||
)
|
||||
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content == model_output
|
||||
assert len(tool_calls) == 0
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
id="simple_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[MORE_TYPES_FUNCTION_CALL],
|
||||
id="more_types_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{PARAMETERLESS_FUNCTION_OUTPUT}]",
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
id="parameterless_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_DICT_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_DICT_FUNCTION_CALL],
|
||||
id="empty_dict_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
[EMPTY_LIST_FUNCTION_CALL],
|
||||
id="empty_list_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{ESCAPED_STRING_FUNCTION_OUTPUT}]",
|
||||
[ESCAPED_STRING_FUNCTION_CALL],
|
||||
id="escaped_string_nonstreaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]",
|
||||
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL],
|
||||
id="parallel_calls_nonstreaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", TEST_CASES)
|
||||
def test_tool_call(
|
||||
streaming: bool,
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
default_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
|
||||
assert content is None
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
assert actual.type == "function"
|
||||
assert actual.function == expected
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"[get_weather(city='San",
|
||||
" Francisco', metric='celsius'), "
|
||||
f"{PARAMETERLESS_FUNCTION_OUTPUT}, "
|
||||
f"{EMPTY_LIST_FUNCTION_OUTPUT}]",
|
||||
]
|
||||
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser, model_output_deltas, assert_one_tool_per_delta=False
|
||||
)
|
||||
|
||||
assert reconstructor.other_content == ""
|
||||
assert len(reconstructor.tool_calls) == 3
|
||||
assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
|
||||
assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [False])
|
||||
def test_regex_timeout_handling(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
"""test regex timeout is handled gracefully"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("pythonic")(
|
||||
default_tokenizer
|
||||
)
|
||||
|
||||
fake_problematic_input = "hello world[A(A=" + "\t)A(A=,\t" * 2
|
||||
|
||||
# create a mock regex that raises TimeoutError
|
||||
mock_regex = MagicMock()
|
||||
mock_regex.match.side_effect = TimeoutError("Regex timeout")
|
||||
|
||||
with patch.object(tool_parser, "TOOL_CALL_REGEX", mock_regex):
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, fake_problematic_input, streaming=streaming
|
||||
)
|
||||
|
||||
# should treat as regular text when regex times out
|
||||
assert content == fake_problematic_input
|
||||
assert len(tool_calls) == 0
|
||||
mock_regex.match.assert_called_once()
|
||||
Reference in New Issue
Block a user