Add simple granite4 tool parser (#36827)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@@ -3,29 +3,22 @@
 
 import json
 
+import openai
 import pytest
+import pytest_asyncio
+from huggingface_hub import snapshot_download
+from typing_extensions import TypedDict
 
 from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
 from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
 from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
 
 from ....utils import RemoteOpenAIServer
 
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
 LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
 
-SERVER_ARGS = [
-    "--enforce-eager",
-    "--enable-auto-tool-choice",
-    "--tool-call-parser",
-    "hermes",
-    "--enable-lora",
-    "--lora-modules",
-    f"{LORA_MODEL}={LORA_MODEL}",
-    "--tokenizer",
-    f"{LORA_MODEL}",
-]
-
 TOOLS = [
     {
         "type": "function",
@@ -50,6 +43,75 @@ TOOLS = [
     }
 ]
 
+
+class ServerConfig(TypedDict, total=False):
+    model: str
+    arguments: list[str]
+    model_arg: str
+    tool_parser: ToolParser
+
+
+CONFIGS: dict[str, ServerConfig] = {
+    "llama": {
+        "model": "meta-llama/Llama-3.2-1B-Instruct",
+        "arguments": [
+            "--enforce-eager",
+            "--enable-auto-tool-choice",
+            "--tool-call-parser",
+            "hermes",
+            "--enable-lora",
+            "--lora-modules",
+            f"{LORA_MODEL}={LORA_MODEL}",
+            "--tokenizer",
+            f"{LORA_MODEL}",
+        ],
+        "model_arg": LORA_MODEL,
+        "tool_parser": Hermes2ProToolParser,
+    },
+    "granite4": {
+        "model": "ibm-granite/granite-4.0-h-tiny",
+        "arguments": [
+            "--enforce-eager",
+            "--enable-auto-tool-choice",
+            "--tool-call-parser",
+            "granite4",
+            "--tokenizer",
+            "ibm-granite/granite-4.0-h-tiny",
+            "--max-model-len",
+            "4096",
+            "--max-num-seqs",
+            "2",
+        ],
+        "model_arg": "ibm-granite/granite-4.0-h-tiny",
+        "tool_parser": Granite4ToolParser,
+    },
+}
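Note: both CONFIGS entries share one shape, so covering an additional parser is a matter of appending one more entry; the parametrized fixtures below pick it up via params=CONFIGS.keys(). A minimal sketch (the "my-org/my-model" repo, "my-parser" name, and MyToolParser class are hypothetical, not part of this change):

    CONFIGS["my-model"] = {
        "model": "my-org/my-model",  # HF repo to snapshot and serve
        "arguments": [
            "--enforce-eager",
            "--enable-auto-tool-choice",
            "--tool-call-parser", "my-parser",  # hypothetical parser name
        ],
        "model_arg": "my-org/my-model",  # value sent as `model` in requests
        "tool_parser": MyToolParser,  # hypothetical class for the offline tests
    }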
+
+
+# for each server config, download the model and return the config
+@pytest.fixture(scope="session", params=CONFIGS.keys())
+def server_config(request):
+    config = CONFIGS[request.param]
+
+    # download model and tokenizer using huggingface_hub
+    snapshot_download(config["model"])
+    yield CONFIGS[request.param]
+
+
+@pytest.fixture(scope="module")
+def server(request, server_config: ServerConfig):
+    model = server_config["model"]
+    args_for_model = server_config["arguments"]
+    with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server:
+        yield server
+
+
+@pytest_asyncio.fixture
+async def client(server: RemoteOpenAIServer):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
 PRODUCT_TOOLS = [
     {
         "type": "function",
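Note: the fixtures form a chain, the session-scoped server_config (parametrized over CONFIGS) feeds the module-scoped server, which feeds the per-test client, so each test body only needs the client and the active config. Rough shape of a consuming test (a sketch mirroring the tests below, not part of this change):

    @pytest.mark.asyncio
    async def test_example(client: openai.AsyncOpenAI, server_config: ServerConfig):
        # The active CONFIGS entry decides which model name the request targets.
        response = await client.chat.completions.create(
            model=server_config["model_arg"], messages=MESSAGES, tools=TOOLS
        )
        assert response.choices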
@@ -87,186 +149,182 @@ PRODUCT_MESSAGES = [
 
 
 @pytest.mark.asyncio
-async def test_non_streaming_tool_call():
+async def test_non_streaming_tool_call(
+    client: openai.AsyncOpenAI, server_config: ServerConfig
+):
     """Test tool call in non-streaming mode."""
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
-        client = server.get_async_client()
-
-        response = await client.chat.completions.create(
-            model=LORA_MODEL,
-            messages=MESSAGES,
-            tools=TOOLS,
-            tool_choice="auto",
-            temperature=0.0,
-        )
+    response = await client.chat.completions.create(
+        model=server_config["model_arg"],
+        messages=MESSAGES,
+        tools=TOOLS,
+        tool_choice="auto",
+        temperature=0.0,
+    )
 
-        assert response.choices
-        choice = response.choices[0]
-        message = choice.message
+    assert response.choices
+    choice = response.choices[0]
+    message = choice.message
 
-        assert choice.finish_reason == "tool_calls"
-        assert message.tool_calls is not None
+    assert choice.finish_reason == "tool_calls"
+    assert message.tool_calls is not None
 
-        tool_call = message.tool_calls[0]
-        assert tool_call.type == "function"
-        assert tool_call.function.name == "get_current_weather"
+    tool_call = message.tool_calls[0]
+    assert tool_call.type == "function"
+    assert tool_call.function.name == "get_current_weather"
 
-        arguments = json.loads(tool_call.function.arguments)
-        assert "location" in arguments
-        assert "Boston" in arguments["location"]
-        print("\n[Non-Streaming Test Passed]")
-        print(f"Tool Call: {tool_call.function.name}")
-        print(f"Arguments: {arguments}")
+    arguments = json.loads(tool_call.function.arguments)
+    assert "location" in arguments
+    assert "Boston" in arguments["location"]
+    print("\n[Non-Streaming Test Passed]")
+    print(f"Tool Call: {tool_call.function.name}")
+    print(f"Arguments: {arguments}")
 
 
 @pytest.mark.asyncio
-async def test_streaming_tool_call():
+async def test_streaming_tool_call(
+    client: openai.AsyncOpenAI, server_config: ServerConfig
+):
     """Test tool call in streaming mode."""
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
-        client = server.get_async_client()
-
-        stream = await client.chat.completions.create(
-            model=LORA_MODEL,
-            messages=MESSAGES,
-            tools=TOOLS,
-            tool_choice="auto",
-            temperature=0.0,
-            stream=True,
-        )
+    stream = await client.chat.completions.create(
+        model=server_config["model_arg"],
+        messages=MESSAGES,
+        tools=TOOLS,
+        tool_choice="auto",
+        temperature=0.0,
+        stream=True,
+    )
 
-        tool_call_chunks = {}
-        async for chunk in stream:
-            if not chunk.choices:
-                continue
+    tool_call_chunks = {}
+    async for chunk in stream:
+        if not chunk.choices:
+            continue
 
-            delta = chunk.choices[0].delta
-            if not delta or not delta.tool_calls:
-                continue
+        delta = chunk.choices[0].delta
+        if not delta or not delta.tool_calls:
+            continue
 
-            for tool_chunk in delta.tool_calls:
-                index = tool_chunk.index
-                if index not in tool_call_chunks:
-                    tool_call_chunks[index] = {"name": "", "arguments": ""}
+        for tool_chunk in delta.tool_calls:
+            index = tool_chunk.index
+            if index not in tool_call_chunks:
+                tool_call_chunks[index] = {"name": "", "arguments": ""}
 
-                if tool_chunk.function.name:
-                    tool_call_chunks[index]["name"] += tool_chunk.function.name
-                if tool_chunk.function.arguments:
-                    tool_call_chunks[index]["arguments"] += (
-                        tool_chunk.function.arguments
-                    )
+            if tool_chunk.function.name:
+                tool_call_chunks[index]["name"] += tool_chunk.function.name
+            if tool_chunk.function.arguments:
+                tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
 
-        assert len(tool_call_chunks) == 1
-        reconstructed_tool_call = tool_call_chunks[0]
+    assert len(tool_call_chunks) == 1
+    reconstructed_tool_call = tool_call_chunks[0]
 
-        assert reconstructed_tool_call["name"] == "get_current_weather"
+    assert reconstructed_tool_call["name"] == "get_current_weather"
 
-        arguments = json.loads(reconstructed_tool_call["arguments"])
-        assert "location" in arguments
-        assert "Boston" in arguments["location"]
-        print("\n[Streaming Test Passed]")
-        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
-        print(f"Reconstructed Arguments: {arguments}")
+    arguments = json.loads(reconstructed_tool_call["arguments"])
+    assert "location" in arguments
+    assert "Boston" in arguments["location"]
+    print("\n[Streaming Test Passed]")
+    print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
+    print(f"Reconstructed Arguments: {arguments}")
 
 
 @pytest.mark.asyncio
-async def test_non_streaming_product_tool_call():
+async def test_non_streaming_product_tool_call(
+    client: openai.AsyncOpenAI, server_config: ServerConfig
+):
     """Test tool call integer and boolean parameters in non-streaming mode."""
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
-        client = server.get_async_client()
-
-        response = await client.chat.completions.create(
-            model=LORA_MODEL,
-            messages=PRODUCT_MESSAGES,
-            tools=PRODUCT_TOOLS,
-            tool_choice="auto",
-            temperature=0.66,
-        )
+    response = await client.chat.completions.create(
+        model=server_config["model_arg"],
+        messages=PRODUCT_MESSAGES,
+        tools=PRODUCT_TOOLS,
+        tool_choice="auto",
+        temperature=0.66,
+    )
 
-        assert response.choices
-        choice = response.choices[0]
-        message = choice.message
+    assert response.choices
+    choice = response.choices[0]
+    message = choice.message
 
-        assert choice.finish_reason == "tool_calls"
-        assert message.tool_calls is not None
+    assert choice.finish_reason == "tool_calls"
+    assert message.tool_calls is not None
 
-        tool_call = message.tool_calls[0]
-        assert tool_call.type == "function"
-        assert tool_call.function.name == "get_product_info"
+    tool_call = message.tool_calls[0]
+    assert tool_call.type == "function"
+    assert tool_call.function.name == "get_product_info"
 
-        arguments = json.loads(tool_call.function.arguments)
-        assert "product_id" in arguments
-        assert "inserted" in arguments
+    arguments = json.loads(tool_call.function.arguments)
+    assert "product_id" in arguments
+    assert "inserted" in arguments
 
-        product_id = arguments.get("product_id")
-        inserted = arguments.get("inserted")
+    product_id = arguments.get("product_id")
+    inserted = arguments.get("inserted")
 
-        assert isinstance(product_id, int)
-        assert product_id == 7355608
-        assert isinstance(inserted, bool)
-        assert inserted is True
+    assert isinstance(product_id, int)
+    assert product_id == 7355608
+    assert isinstance(inserted, bool)
+    assert inserted is True
 
-        print("\n[Non-Streaming Product Test Passed]")
-        print(f"Tool Call: {tool_call.function.name}")
-        print(f"Arguments: {arguments}")
+    print("\n[Non-Streaming Product Test Passed]")
+    print(f"Tool Call: {tool_call.function.name}")
+    print(f"Arguments: {arguments}")
 
 
 @pytest.mark.asyncio
-async def test_streaming_product_tool_call():
+async def test_streaming_product_tool_call(
+    client: openai.AsyncOpenAI, server_config: ServerConfig
+):
     """Test tool call integer and boolean parameters in streaming mode."""
-    with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
-        client = server.get_async_client()
-
-        stream = await client.chat.completions.create(
-            model=LORA_MODEL,
-            messages=PRODUCT_MESSAGES,
-            tools=PRODUCT_TOOLS,
-            tool_choice="auto",
-            temperature=0.66,
-            stream=True,
-        )
+    stream = await client.chat.completions.create(
+        model=server_config["model_arg"],
+        messages=PRODUCT_MESSAGES,
+        tools=PRODUCT_TOOLS,
+        tool_choice="auto",
+        temperature=0.66,
+        stream=True,
+    )
 
-        tool_call_chunks = {}
-        async for chunk in stream:
-            if not chunk.choices:
-                continue
+    tool_call_chunks = {}
+    async for chunk in stream:
+        if not chunk.choices:
+            continue
 
-            delta = chunk.choices[0].delta
-            if not delta or not delta.tool_calls:
-                continue
+        delta = chunk.choices[0].delta
+        if not delta or not delta.tool_calls:
+            continue
 
-            for tool_chunk in delta.tool_calls:
-                index = tool_chunk.index
-                if index not in tool_call_chunks:
-                    tool_call_chunks[index] = {"name": "", "arguments": ""}
+        for tool_chunk in delta.tool_calls:
+            index = tool_chunk.index
+            if index not in tool_call_chunks:
+                tool_call_chunks[index] = {"name": "", "arguments": ""}
 
-                if tool_chunk.function.name:
-                    tool_call_chunks[index]["name"] += tool_chunk.function.name
-                if tool_chunk.function.arguments:
-                    tool_call_chunks[index]["arguments"] += (
-                        tool_chunk.function.arguments
-                    )
+            if tool_chunk.function.name:
+                tool_call_chunks[index]["name"] += tool_chunk.function.name
+            if tool_chunk.function.arguments:
+                tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
 
-        assert len(tool_call_chunks) == 1
-        reconstructed_tool_call = tool_call_chunks[0]
+    assert len(tool_call_chunks) == 1
+    reconstructed_tool_call = tool_call_chunks[0]
 
-        assert reconstructed_tool_call["name"] == "get_product_info"
+    assert reconstructed_tool_call["name"] == "get_product_info"
 
-        arguments = json.loads(reconstructed_tool_call["arguments"])
-        assert "product_id" in arguments
-        assert "inserted" in arguments
+    arguments = json.loads(reconstructed_tool_call["arguments"])
+    assert "product_id" in arguments
+    assert "inserted" in arguments
 
-        # Handle type coercion for streaming test as well
-        product_id = arguments.get("product_id")
-        inserted = arguments.get("inserted")
+    # Handle type coercion for streaming test as well
+    product_id = arguments.get("product_id")
+    inserted = arguments.get("inserted")
 
-        assert isinstance(product_id, int)
-        assert product_id == 7355608
-        assert isinstance(inserted, bool)
-        assert inserted is True
+    assert isinstance(product_id, int)
+    assert product_id == 7355608
+    assert isinstance(inserted, bool)
+    assert inserted is True
 
-        print("\n[Streaming Product Test Passed]")
-        print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
-        print(f"Reconstructed Arguments: {arguments}")
+    print("\n[Streaming Product Test Passed]")
+    print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
+    print(f"Reconstructed Arguments: {arguments}")
 
 
 @pytest.fixture
@@ -276,9 +334,10 @@ def qwen_tokenizer() -> TokenizerLike:
     return get_tokenizer("Qwen/Qwen3-32B")
 
 
-@pytest.fixture
-def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
-    return Hermes2ProToolParser(qwen_tokenizer)
+@pytest.fixture(params=CONFIGS.keys())
+def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
+    config = CONFIGS[request.param]
+    return config["tool_parser"](qwen_tokenizer)
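Note: the fixture keeps its hermes_parser name but now instantiates whichever parser class the active config names, so each offline parser test below runs once per CONFIGS entry. Direct use looks roughly like this sketch (assuming vLLM's extract_tool_calls(model_output, request) signature and its tools_called/tool_calls result fields):

    parser = Granite4ToolParser(qwen_tokenizer)  # or Hermes2ProToolParser(...)
    result = parser.extract_tool_calls(model_output, any_chat_request)
    if result.tools_called:
        print(result.tool_calls[0].function.name)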
 
 
 @pytest.fixture
@@ -292,7 +351,7 @@ def any_chat_request() -> ChatCompletionRequest:
 
 def test_hermes_parser_streaming_just_forward_text(
     qwen_tokenizer: TokenizerLike,
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     text = """This is some prior text that has nothing to do with tool calling."""
@@ -324,7 +383,7 @@ def test_hermes_parser_streaming_just_forward_text(
 
 def test_hermes_parser_streaming_failure_case_bug_19056(
     qwen_tokenizer: TokenizerLike,
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     text = """<tool_call>
@@ -358,7 +417,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
 
 def test_hermes_parser_streaming(
     qwen_tokenizer: TokenizerLike,
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     text = '<tool_call>\
@@ -387,16 +446,20 @@ def test_hermes_parser_streaming(
         delta_messages.append(delta)
     print(delta_messages)
     assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
-    tool_call_args = "".join(
-        delta.tool_calls[0].function.arguments or "" for delta in delta_messages
-    )
-    assert tool_call_args == (
-        '{"location":"San Francisco, California, United States", "unit": "celsius"}'
-    )
+    # load to normalize whitespace
+    tool_call_args = json.loads(
+        "".join(
+            delta.tool_calls[0].function.arguments or "" for delta in delta_messages
+        )
+    )
+    assert tool_call_args == {
+        "location": "San Francisco, California, United States",
+        "unit": "celsius",
+    }
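Note: parsing both sides with json.loads makes the assertion insensitive to incidental whitespace in the emitted JSON, which is what the "normalize whitespace" comment above refers to. For example:

    # Both spellings decode to the same dict, so either serialization passes:
    assert json.loads('{"unit": "celsius"}') == json.loads('{"unit":"celsius"}')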
 
 
 def test_hermes_parser_non_streaming_no_tool_call(
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     text = """This is not a tool call."""
@@ -410,7 +473,7 @@ def test_hermes_parser_non_streaming_no_tool_call(
 
 
 def test_hermes_parser_non_streaming_tool_call_between_tags(
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     text = """<tool_call>
@@ -428,9 +491,12 @@ def test_hermes_parser_non_streaming_tool_call_between_tags(
 
 
 def test_hermes_parser_non_streaming_tool_call_until_eos(
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
+    if isinstance(hermes_parser, Granite4ToolParser):
+        pytest.skip(reason="The Granite4 tool parser enforces a complete response")
+
     text = """<tool_call>
 {"name": "final_answer", "arguments": {"trigger": true}}"""
     tool_call = hermes_parser.extract_tool_calls(
@@ -445,7 +511,7 @@ def test_hermes_parser_non_streaming_tool_call_until_eos(
 
 
 def test_hermes_parser_non_streaming_tool_call_invalid_json(
-    hermes_parser: Hermes2ProToolParser,
+    hermes_parser: ToolParser,
     any_chat_request: ChatCompletionRequest,
 ) -> None:
     # Missing closing brace to trigger exception