[Test] Consolidate tool parser unit tests to tests/tool_parsers (#37834)
Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
@@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def default_tokenizer() -> TokenizerLike:
    """Load a new GPT-2 tokenizer instance for every test function."""
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return gpt2_tokenizer
|
||||
@@ -1,18 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
)
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
@@ -38,137 +29,6 @@ def server():
|
||||
yield server
|
||||
|
||||
|
||||
def create_complex_input(create_string_args: bool):
    """Build the three reference tool calls used by the complex parser test.

    Args:
        create_string_args: When true, the ``find_bbox`` arguments are
            JSON-encoded into a string (to exercise Granite-style string
            arguments); otherwise they are left as a plain dict.

    Returns:
        A list of ``{"name", "arguments"}`` dicts: ``find_bbox``,
        ``get_stock_price``, then ``find_bbox`` again with the same arguments.
    """
    bbox_args: dict | str = {
        "coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
        "coordinate_type": "latlong",
    }
    # test granite behavior
    bbox_args = json.dumps(bbox_args) if create_string_args else bbox_args

    stock_call = {
        "name": "get_stock_price",
        "arguments": {
            "symbol": "AAPL",
            "start_date": "2021-01-01",
            "end_date": "2021-12-31",
        },
    }
    # Both bbox calls deliberately share the very same arguments object.
    return [
        {"name": "find_bbox", "arguments": bbox_args},
        stock_call,
        {"name": "find_bbox", "arguments": bbox_args},
    ]
|
||||
|
||||
|
||||
def random_chunks(s: str, min_len: int, max_len: int):
    """Split ``s`` into consecutive pieces of random length.

    Each piece's length is drawn with ``random.randint(min_len, max_len)``;
    the last piece may be shorter when the string runs out. Joining the
    returned pieces reproduces ``s``.

    Args:
        s: The text to split.
        min_len: Smallest chunk size to draw; must be at least 1.
        max_len: Largest chunk size to draw; must be at least ``min_len``.

    Returns:
        The list of chunks in order (empty when ``s`` is empty).

    Raises:
        ValueError: If ``min_len < 1`` or ``max_len < min_len``.
    """
    # A zero or negative size would never advance the cursor (or would move
    # it backwards), so reject such bounds instead of looping forever.
    if min_len < 1 or max_len < min_len:
        raise ValueError(
            f"chunk bounds must satisfy 1 <= min_len <= max_len, "
            f"got min_len={min_len}, max_len={max_len}"
        )

    chunks = []
    i = 0
    n = len(s)

    while i < n:
        size = random.randint(min_len, max_len)
        chunks.append(s[i : i + size])
        i += size

    return chunks
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def tokenizer():
    """Load the tokenizer for ``MODEL`` once per test module."""
    model_tokenizer = AutoTokenizer.from_pretrained(MODEL)
    return model_tokenizer
|
||||
|
||||
|
||||
# create a variety of input chunk sizes
|
||||
@pytest.mark.parametrize(
    "min_chunk, max_chunk",
    [
        (1, 1),
        (1, 2),
        (5, 7),
        (6, 20),
    ],
)
def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
    """Stream text interleaved with three tool calls through the Granite 4 parser.

    The input is fed in randomly sized chunks (bounded by the parametrized
    sizes). The deltas are then stitched back together and must reproduce
    both the plain text and the tool calls with dict arguments.
    """
    text_messages = [
        "Here goes the bbox call: \n",
        " Now the stock price call: \n ",
        " Now another bbox call: \n ",
        " See? I'm a helpful assistant.",
    ]
    formatted_tcs = [
        f"<tool_call> {json.dumps(call)} </tool_call>"
        for call in create_complex_input(True)
    ]
    # Alternate text and tool calls: text0, call0, text1, call1, text2, call2, text3.
    test_input = text_messages[0] + "".join(
        call_text + trailing_text
        for call_text, trailing_text in zip(formatted_tcs, text_messages[1:])
    )

    any_chat_request = ChatCompletionRequest(seed=42, model=MODEL, messages=[])
    parser = Granite4ToolParser(tokenizer=tokenizer)

    delta_messages: list[DeltaMessage] = []
    for chunk in random_chunks(test_input, min_chunk, max_chunk):
        # Only delta_text carries data; the other streaming inputs stay empty.
        delta = parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=chunk,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        if delta is not None:
            delta_messages.append(delta)

    content = ""
    tool_calls: list[dict[str, Any]] = []

    # A change of function name marks the boundary between two tool calls.
    current_name = "__start__"
    current_args = ""

    for msg in delta_messages:
        if msg.content:
            content += msg.content
        for tool_call in msg.tool_calls:
            delta_func = tool_call.function
            if not delta_func:
                continue

            new_name = delta_func.name
            if new_name is not None:
                if current_name == "__start__":
                    current_name = new_name
                elif new_name != current_name:
                    # Previous call is complete: flush it and start the next.
                    tool_calls.append(
                        {
                            "name": current_name,
                            "arguments": json.loads(current_args),
                        }
                    )
                    current_name = new_name
                    current_args = ""

            if delta_func.arguments:
                current_args += delta_func.arguments

    # Flush the last call, if any call was seen at all.
    if current_name != "__start__":
        tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})

    assert content == "".join(text_messages)
    assert tool_calls == create_complex_input(False)
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
|
||||
@@ -9,8 +9,6 @@ import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
@@ -325,202 +323,3 @@ async def test_streaming_product_tool_call(
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
def qwen_tokenizer() -> TokenizerLike:
    """Return the tokenizer for ``Qwen/Qwen3-32B``."""
    # Imported lazily, inside the fixture, as in the original code.
    from vllm.tokenizers import get_tokenizer

    qwen = get_tokenizer("Qwen/Qwen3-32B")
    return qwen
|
||||
|
||||
|
||||
@pytest.fixture(params=CONFIGS.keys())
def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
    """Instantiate each parser class listed in ``CONFIGS`` with the Qwen tokenizer."""
    parser_cls = CONFIGS[request.param]["tool_parser"]
    return parser_cls(qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Return a minimal chat request (fixed seed, no messages) for the parsers."""
    request_fields = {
        "seed": 42,
        "model": "Qwen/Qwen3-32B",
        "messages": [],
    }
    return ChatCompletionRequest(**request_fields)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text streamed token by token must come back unchanged as content.

    Every token has to yield a non-``None`` delta without tool calls, and the
    concatenated delta contents must equal the original text.
    """
    text = "This is some prior text that has nothing to do with tool calling."

    streamed_so_far = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        piece = qwen_tokenizer.decode([token])
        extended = streamed_so_far + piece
        delta_messages.append(
            hermes_parser.extract_tool_calls_streaming(
                previous_text=streamed_so_far,
                current_text=extended,
                delta_text=piece,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=any_chat_request,
            )
        )
        streamed_so_far = extended

    contents = []
    for delta in delta_messages:
        assert delta is not None
        assert not delta.tool_calls
        contents.append(delta.content)

    print(delta_messages)
    assert "".join(contents) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a ``final_answer`` tool call token by token (case named after bug 19056).

    The first emitted delta must carry the function name, and the argument
    fragments of all deltas must concatenate to ``{"trigger": true}``.
    """
    text = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )

    previous_text = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is None:
            continue
        delta_messages.append(delta)

    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "final_answer"

    arg_fragments = [
        message.tool_calls[0].function.arguments or "" for message in delta_messages
    ]
    assert "".join(arg_fragments) == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a ``get_current_temperature`` tool call token by token.

    The first emitted delta must carry the function name, and the joined
    argument fragments must decode to the expected location/unit dict.
    """
    # Adjacent literals produce the same single-line string as the original
    # backslash-continued literal.
    text = (
        "<tool_call>"
        '{"name": "get_current_temperature",'
        '"arguments": {"location":'
        '"San Francisco, California, United States", "unit": "celsius"}}'
        "</tool_call>"
    )

    previous_text = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is None:
            continue
        delta_messages.append(delta)

    print(delta_messages)
    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "get_current_temperature"

    # load to normalize whitespace
    arg_fragments = [
        message.tool_calls[0].function.arguments or "" for message in delta_messages
    ]
    tool_call_args = json.loads("".join(arg_fragments))
    assert tool_call_args == {
        "location": "San Francisco, California, United States",
        "unit": "celsius",
    }
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text without tool-call tags must not be reported as a tool call."""
    extracted = hermes_parser.extract_tool_calls(
        model_output="This is not a tool call.",
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call wrapped in ``<tool_call>`` ... ``</tool_call>`` is extracted."""
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    assert extracted.tool_calls[0].function.name == "final_answer"
    assert extracted.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call with no closing ``</tool_call>`` tag is still extracted.

    Skipped for the Granite 4 parser, which requires a complete response.
    """
    if isinstance(hermes_parser, Granite4ToolParser):
        pytest.skip(reason="The Granite4 tool parser enforces a complete response")

    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}'
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    assert extracted.tool_calls[0].function.name == "final_answer"
    assert extracted.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON inside a tool call must not be reported as a tool call."""
    # Missing closing brace to trigger exception
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}'
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.tool_parsers.utils import (
|
||||
run_tool_extraction,
|
||||
@@ -13,6 +14,14 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def default_tokenizer() -> TokenizerLike:
    """Function-scoped replacement for the module-scoped ``default_tokenizer``.

    The gigachat tests mutate the tokenizer via ``add_tokens``, so every test
    gets its own GPT-2 instance instead of a shared one.
    """
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return gpt2_tokenizer
|
||||
|
||||
|
||||
MSG_SEP_TOKEN = "<|message_sep|>\n\n"
|
||||
ROLE_SEP_TOKEN = "<|role_sep|>\n"
|
||||
EOS_TOKEN = "</s>"
|
||||
147
tests/tool_parsers/test_granite4_tool_parser.py
Normal file
147
tests/tool_parsers/test_granite4_tool_parser.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
)
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
|
||||
# Model id passed to AutoTokenizer.from_pretrained and used as the chat request's model field.
MODEL = "ibm-granite/granite-4.0-h-tiny"
|
||||
|
||||
|
||||
def create_complex_input(create_string_args: bool):
    """Build the three reference tool calls used by the complex parser test.

    Args:
        create_string_args: When true, the ``find_bbox`` arguments are
            JSON-encoded into a string (to exercise Granite-style string
            arguments); otherwise they are left as a plain dict.

    Returns:
        A list of ``{"name", "arguments"}`` dicts: ``find_bbox``,
        ``get_stock_price``, then ``find_bbox`` again with the same arguments.
    """
    bbox_args: dict | str = {
        "coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
        "coordinate_type": "latlong",
    }
    # test granite behavior
    bbox_args = json.dumps(bbox_args) if create_string_args else bbox_args

    stock_call = {
        "name": "get_stock_price",
        "arguments": {
            "symbol": "AAPL",
            "start_date": "2021-01-01",
            "end_date": "2021-12-31",
        },
    }
    # Both bbox calls deliberately share the very same arguments object.
    return [
        {"name": "find_bbox", "arguments": bbox_args},
        stock_call,
        {"name": "find_bbox", "arguments": bbox_args},
    ]
|
||||
|
||||
|
||||
def random_chunks(s: str, min_len: int, max_len: int):
    """Split ``s`` into consecutive pieces of random length.

    Each piece's length is drawn with ``random.randint(min_len, max_len)``;
    the last piece may be shorter when the string runs out. Joining the
    returned pieces reproduces ``s``.

    Args:
        s: The text to split.
        min_len: Smallest chunk size to draw; must be at least 1.
        max_len: Largest chunk size to draw; must be at least ``min_len``.

    Returns:
        The list of chunks in order (empty when ``s`` is empty).

    Raises:
        ValueError: If ``min_len < 1`` or ``max_len < min_len``.
    """
    # A zero or negative size would never advance the cursor (or would move
    # it backwards), so reject such bounds instead of looping forever.
    if min_len < 1 or max_len < min_len:
        raise ValueError(
            f"chunk bounds must satisfy 1 <= min_len <= max_len, "
            f"got min_len={min_len}, max_len={max_len}"
        )

    chunks = []
    i = 0
    n = len(s)

    while i < n:
        size = random.randint(min_len, max_len)
        chunks.append(s[i : i + size])
        i += size

    return chunks
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def tokenizer():
    """Load the tokenizer for ``MODEL`` once per test module."""
    model_tokenizer = AutoTokenizer.from_pretrained(MODEL)
    return model_tokenizer
|
||||
|
||||
|
||||
# create a variety of input chunk sizes
|
||||
@pytest.mark.parametrize(
    "min_chunk, max_chunk",
    [
        (1, 1),
        (1, 2),
        (5, 7),
        (6, 20),
    ],
)
def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
    """Stream text interleaved with three tool calls through the Granite 4 parser.

    The input is fed in randomly sized chunks (bounded by the parametrized
    sizes). The deltas are then stitched back together and must reproduce
    both the plain text and the tool calls with dict arguments.
    """
    text_messages = [
        "Here goes the bbox call: \n",
        " Now the stock price call: \n ",
        " Now another bbox call: \n ",
        " See? I'm a helpful assistant.",
    ]
    formatted_tcs = [
        f"<tool_call> {json.dumps(call)} </tool_call>"
        for call in create_complex_input(True)
    ]
    # Alternate text and tool calls: text0, call0, text1, call1, text2, call2, text3.
    test_input = text_messages[0] + "".join(
        call_text + trailing_text
        for call_text, trailing_text in zip(formatted_tcs, text_messages[1:])
    )

    any_chat_request = ChatCompletionRequest(seed=42, model=MODEL, messages=[])
    parser = Granite4ToolParser(tokenizer=tokenizer)

    delta_messages: list[DeltaMessage] = []
    for chunk in random_chunks(test_input, min_chunk, max_chunk):
        # Only delta_text carries data; the other streaming inputs stay empty.
        delta = parser.extract_tool_calls_streaming(
            previous_text="",
            current_text="",
            delta_text=chunk,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        if delta is not None:
            delta_messages.append(delta)

    content = ""
    tool_calls: list[dict[str, Any]] = []

    # A change of function name marks the boundary between two tool calls.
    current_name = "__start__"
    current_args = ""

    for msg in delta_messages:
        if msg.content:
            content += msg.content
        for tool_call in msg.tool_calls:
            delta_func = tool_call.function
            if not delta_func:
                continue

            new_name = delta_func.name
            if new_name is not None:
                if current_name == "__start__":
                    current_name = new_name
                elif new_name != current_name:
                    # Previous call is complete: flush it and start the next.
                    tool_calls.append(
                        {
                            "name": current_name,
                            "arguments": json.loads(current_args),
                        }
                    )
                    current_name = new_name
                    current_args = ""

            if delta_func.arguments:
                current_args += delta_func.arguments

    # Flush the last call, if any call was seen at all.
    if current_name != "__start__":
        tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})

    assert content == "".join(text_messages)
    assert tool_calls == create_complex_input(False)
|
||||
220
tests/tool_parsers/test_hermes_tool_parser.py
Normal file
220
tests/tool_parsers/test_hermes_tool_parser.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
|
||||
# Maps a pytest parameter id to the tool parser class under test.
# Key order is the order in which the ``hermes_parser`` fixture is parametrized.
CONFIGS = {
    "llama": {"tool_parser": Hermes2ProToolParser},
    "granite4": {"tool_parser": Granite4ToolParser},
}
|
||||
|
||||
|
||||
@pytest.fixture
def qwen_tokenizer() -> TokenizerLike:
    """Return the tokenizer for ``Qwen/Qwen3-32B``."""
    # Imported lazily, inside the fixture, as in the original code.
    from vllm.tokenizers import get_tokenizer

    qwen = get_tokenizer("Qwen/Qwen3-32B")
    return qwen
|
||||
|
||||
|
||||
@pytest.fixture(params=CONFIGS.keys())
def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
    """Instantiate each parser class listed in ``CONFIGS`` with the Qwen tokenizer."""
    parser_cls = CONFIGS[request.param]["tool_parser"]
    return parser_cls(qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
def any_chat_request() -> ChatCompletionRequest:
    """Return a minimal chat request (fixed seed, no messages) for the parsers."""
    request_fields = {
        "seed": 42,
        "model": "Qwen/Qwen3-32B",
        "messages": [],
    }
    return ChatCompletionRequest(**request_fields)
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text streamed token by token must come back unchanged as content.

    Every token has to yield a non-``None`` delta without tool calls, and the
    concatenated delta contents must equal the original text.
    """
    text = "This is some prior text that has nothing to do with tool calling."

    streamed_so_far = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        piece = qwen_tokenizer.decode([token])
        extended = streamed_so_far + piece
        delta_messages.append(
            hermes_parser.extract_tool_calls_streaming(
                previous_text=streamed_so_far,
                current_text=extended,
                delta_text=piece,
                previous_token_ids=[],
                current_token_ids=[],
                delta_token_ids=[],
                request=any_chat_request,
            )
        )
        streamed_so_far = extended

    contents = []
    for delta in delta_messages:
        assert delta is not None
        assert not delta.tool_calls
        contents.append(delta.content)

    print(delta_messages)
    assert "".join(contents) == text
|
||||
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a ``final_answer`` tool call token by token (case named after bug 19056).

    The first emitted delta must carry the function name, and the argument
    fragments of all deltas must concatenate to ``{"trigger": true}``.
    """
    text = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )

    previous_text = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is None:
            continue
        delta_messages.append(delta)

    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "final_answer"

    arg_fragments = [
        message.tool_calls[0].function.arguments or "" for message in delta_messages
    ]
    assert "".join(arg_fragments) == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_streaming(
    qwen_tokenizer: TokenizerLike,
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Stream a ``get_current_temperature`` tool call token by token.

    The first emitted delta must carry the function name, and the joined
    argument fragments must decode to the expected location/unit dict.
    """
    # Adjacent literals produce the same single-line string as the original
    # backslash-continued literal.
    text = (
        "<tool_call>"
        '{"name": "get_current_temperature",'
        '"arguments": {"location":'
        '"San Francisco, California, United States", "unit": "celsius"}}'
        "</tool_call>"
    )

    previous_text = ""
    delta_messages = []
    for token in qwen_tokenizer.encode(text):
        delta_text = qwen_tokenizer.decode([token])
        current_text = previous_text + delta_text
        delta = hermes_parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=[],
            current_token_ids=[],
            delta_token_ids=[],
            request=any_chat_request,
        )
        previous_text = current_text
        if delta is None:
            continue
        delta_messages.append(delta)

    print(delta_messages)
    first_call = delta_messages[0].tool_calls[0]
    assert first_call.function.name == "get_current_temperature"

    # load to normalize whitespace
    arg_fragments = [
        message.tool_calls[0].function.arguments or "" for message in delta_messages
    ]
    tool_call_args = json.loads("".join(arg_fragments))
    assert tool_call_args == {
        "location": "San Francisco, California, United States",
        "unit": "celsius",
    }
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Plain text without tool-call tags must not be reported as a tool call."""
    extracted = hermes_parser.extract_tool_calls(
        model_output="This is not a tool call.",
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call wrapped in ``<tool_call>`` ... ``</tool_call>`` is extracted."""
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}\n'
        "</tool_call>"
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    assert extracted.tool_calls[0].function.name == "final_answer"
    assert extracted.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """A call with no closing ``</tool_call>`` tag is still extracted.

    Skipped for the Granite 4 parser, which requires a complete response.
    """
    if isinstance(hermes_parser, Granite4ToolParser):
        pytest.skip(reason="The Granite4 tool parser enforces a complete response")

    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}}'
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert extracted.tools_called
    assert extracted.tool_calls[0].function.name == "final_answer"
    assert extracted.tool_calls[0].function.arguments == '{"trigger": true}'
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
    hermes_parser: ToolParser,
    any_chat_request: ChatCompletionRequest,
) -> None:
    """Malformed JSON inside a tool call must not be reported as a tool call."""
    # Missing closing brace to trigger exception
    model_output = (
        "<tool_call>\n"
        '{"name": "final_answer", "arguments": {"trigger": true}'
    )
    extracted = hermes_parser.extract_tool_calls(
        model_output=model_output,
        request=any_chat_request,
    )

    assert extracted is not None
    assert not extracted.tools_called
|
||||
Reference in New Issue
Block a user