Add simple granite4 tool parser (#36827)
Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -219,7 +219,7 @@ Supported models:
|
||||
|
||||
* `ibm-granite/granite-4.0-h-small` and other Granite 4.0 models
|
||||
|
||||
Recommended flags: `--tool-call-parser hermes`
|
||||
Recommended flags: `--tool-call-parser granite4`
|
||||
|
||||
* `ibm-granite/granite-3.0-8b-instruct`
|
||||
|
||||
|
||||
@@ -0,0 +1,360 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
import random
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaMessage,
|
||||
)
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL = "ibm-granite/granite-4.0-h-tiny"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
model = MODEL
|
||||
args_for_model = [
|
||||
"--enforce-eager",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"granite4",
|
||||
"--tokenizer",
|
||||
"ibm-granite/granite-4.0-h-tiny",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
]
|
||||
with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server:
|
||||
yield server
|
||||
|
||||
|
||||
def create_complex_input(create_string_args: bool):
|
||||
coord_arg: dict | str = {
|
||||
"coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
|
||||
"coordinate_type": "latlong",
|
||||
}
|
||||
if create_string_args:
|
||||
# test granite behavior
|
||||
coord_arg = json.dumps(coord_arg)
|
||||
return [
|
||||
{"name": "find_bbox", "arguments": coord_arg},
|
||||
{
|
||||
"name": "get_stock_price",
|
||||
"arguments": {
|
||||
"symbol": "AAPL",
|
||||
"start_date": "2021-01-01",
|
||||
"end_date": "2021-12-31",
|
||||
},
|
||||
},
|
||||
{"name": "find_bbox", "arguments": coord_arg},
|
||||
]
|
||||
|
||||
|
||||
def random_chunks(s: str, min_len: int, max_len: int):
|
||||
chunks = []
|
||||
i = 0
|
||||
n = len(s)
|
||||
|
||||
while i < n:
|
||||
size = random.randint(min_len, max_len)
|
||||
chunks.append(s[i : i + size])
|
||||
i += size
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def tokenizer():
|
||||
return AutoTokenizer.from_pretrained(MODEL)
|
||||
|
||||
|
||||
# create a variety of input chunk sizes
|
||||
@pytest.mark.parametrize(
|
||||
"min_chunk, max_chunk",
|
||||
[
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(5, 7),
|
||||
(6, 20),
|
||||
],
|
||||
)
|
||||
def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
|
||||
input_dicts = create_complex_input(True)
|
||||
|
||||
formatted_tcs = [
|
||||
"<tool_call> " + json.dumps(call) + " </tool_call>" for call in input_dicts
|
||||
]
|
||||
|
||||
text_messages = [
|
||||
"Here goes the bbox call: \n",
|
||||
" Now the stock price call: \n ",
|
||||
" Now another bbox call: \n ",
|
||||
" See? I'm a helpful assistant.",
|
||||
]
|
||||
|
||||
test_input = (
|
||||
text_messages[0]
|
||||
+ formatted_tcs[0]
|
||||
+ text_messages[1]
|
||||
+ formatted_tcs[1]
|
||||
+ text_messages[2]
|
||||
+ formatted_tcs[2]
|
||||
+ text_messages[3]
|
||||
)
|
||||
|
||||
any_chat_request = ChatCompletionRequest(
|
||||
seed=42,
|
||||
model=MODEL,
|
||||
messages=[],
|
||||
)
|
||||
|
||||
parser = Granite4ToolParser(tokenizer=tokenizer)
|
||||
|
||||
delta_messages = list[DeltaMessage]()
|
||||
for text in random_chunks(test_input, min_chunk, max_chunk):
|
||||
delta = parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="",
|
||||
delta_text=text,
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=any_chat_request,
|
||||
)
|
||||
if delta is not None:
|
||||
delta_messages.append(delta)
|
||||
|
||||
content = ""
|
||||
tool_calls = list[dict[str, Any]]()
|
||||
|
||||
current_name = "__start__"
|
||||
current_args = ""
|
||||
|
||||
for msg in delta_messages:
|
||||
if msg.content:
|
||||
content += msg.content
|
||||
for tool_call in msg.tool_calls:
|
||||
if delta_func := tool_call.function:
|
||||
if delta_func.name is not None:
|
||||
if current_name == "__start__":
|
||||
current_name = delta_func.name
|
||||
|
||||
if delta_func.name != current_name:
|
||||
tool_calls.append(
|
||||
{
|
||||
"name": current_name,
|
||||
"arguments": json.loads(current_args),
|
||||
}
|
||||
)
|
||||
current_name = delta_func.name
|
||||
current_args = ""
|
||||
|
||||
if delta_func.arguments:
|
||||
current_args += delta_func.arguments
|
||||
|
||||
if current_name != "__start__":
|
||||
tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})
|
||||
|
||||
assert content == "".join(text_messages)
|
||||
assert tool_calls == create_complex_input(False)
|
||||
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_acme_region_name_for_transaction_id",
|
||||
"description": "Returns ACME transaction/transaction ID information"
|
||||
" including ACME regions\n\nArgs:\n start_time "
|
||||
"(str): Start date and time in datetime format "
|
||||
'"%Y-%m-%dT%H:%M:%S.%f"\n end_time (str): End '
|
||||
"date and time in datetime format "
|
||||
'"%Y-%m-%dT%H:%M:%S.%f"\n size (int, optional): '
|
||||
"Number of ACME Transaction IDs to return\n "
|
||||
"order (str, optional): Sort by most run "
|
||||
"transaction IDs. The value can be 'asc' for "
|
||||
"ascending or 'desc' for descending\n "
|
||||
"transaction_id (str, optional): ACME Transaction "
|
||||
"ID to filter on\n acme_region (str, optional): "
|
||||
"ACME Region to filter on\nReturns:\n - A "
|
||||
"dictionary containing a list of ACME transaction "
|
||||
"ids and the ACME regions they run in:\n {\n"
|
||||
' "Number of transaction IDs" : int,\n'
|
||||
' "Total transaction IDs available": int'
|
||||
',\n "ACME Transaction IDs": [\n '
|
||||
' {\n "Transaction ID": '
|
||||
'str,\n "Number of runs": int,\n'
|
||||
' "ACME Regions": [str],\n '
|
||||
" },\n ...\n ],"
|
||||
'\n "Start time" : datetime,\n '
|
||||
' "End time" : datetime,\n '
|
||||
' "Order" : str\n }\n '
|
||||
" - If no ACME region found for transaction id, "
|
||||
'returns:\n {"Success": "No ACME region '
|
||||
'found for transaction id."}\n - If an error '
|
||||
'occurs, returns:\n {"Error": "{exception'
|
||||
' message}"}',
|
||||
"parameters": {
|
||||
"properties": {
|
||||
"start_time": {},
|
||||
"end_time": {},
|
||||
"size": {"default": 500},
|
||||
"order": {"default": "desc"},
|
||||
"transaction_id": {"default": None},
|
||||
"acme_region": {"default": None},
|
||||
},
|
||||
"required": ["start_time", "end_time"],
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
tools2 = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_stock_price",
|
||||
"description": "Retrieves the current stock price for a given "
|
||||
"ticker symbol. The ticker symbol must be a valid "
|
||||
"symbol for a publicly traded company on a major US"
|
||||
" stock exchange like NYSE or NASDAQ. The tool will"
|
||||
" return the latest trade price in USD. It should "
|
||||
"be used when the user asks about the current or "
|
||||
"most recent price of a specific stock. It will not"
|
||||
" provide any other information about the stock or"
|
||||
" company.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"ticker": {
|
||||
"description": "The stock ticker symbol, e.g."
|
||||
" AAPL for Apple Inc.",
|
||||
"type": "string",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"content": "\n\nSystem: You are a helpful, precise, and methodical AI"
|
||||
" assistant that uses tool outputs provided inline.\nAlways"
|
||||
" assume the current datetime is 2026-01-29T13:59:09.238901"
|
||||
"+00:00.\n\nIf you receive a ToolMessage with `tool_call_id"
|
||||
'` equal to "get_time_range" (or "time_range_tool"), you '
|
||||
"MUST:\n 1. Parse that JSON and use the values `start` and"
|
||||
" `end` directly when calling other tools.\n 2. Do not "
|
||||
"re-call or re-compute the time range.\n 3. Pass resolved "
|
||||
"values (ISO strings) as arguments to any subsequent tool "
|
||||
"(do not pass function metadata or placeholders).\n 4. If "
|
||||
"a tool requires datetime objects rather than strings, "
|
||||
"convert the ISO strings into language-native datetime "
|
||||
"objects before invoking.\n\nAlways return fully resolved "
|
||||
"arguments in correct types (e.g., ISO datetime strings or"
|
||||
" datetime objects) and never include placeholders like "
|
||||
'"<start>".\n\n',
|
||||
"role": "system",
|
||||
},
|
||||
{
|
||||
"content": "What are the transaction IDs that ran in the"
|
||||
" ACME region A9345 over the last two months?",
|
||||
"role": "user",
|
||||
},
|
||||
{
|
||||
"content": '["2026-01-26T09: 51: 55.467722Z", "2026-01-27T09: 51: 55.467722Z"]',
|
||||
"role": "tool",
|
||||
"tool_call_id": "time_range_tool",
|
||||
},
|
||||
]
|
||||
messages2 = [{"role": "user", "content": "What's stock price for IBM?"}]
|
||||
|
||||
messages3 = [{"role": "user", "content": "What's the current weather in New York?"}]
|
||||
|
||||
|
||||
def get_args(client: openai.OpenAI, _tools, _messages, _stop):
|
||||
response = client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=_messages,
|
||||
temperature=0,
|
||||
tools=_tools,
|
||||
max_tokens=200,
|
||||
stop=_stop,
|
||||
tool_choice="auto",
|
||||
)
|
||||
|
||||
return response.choices[0].message.tool_calls[0].function.arguments
|
||||
|
||||
|
||||
async def get_args_streaming(
|
||||
async_client: openai.AsyncOpenAI, _tools, _messages, _stop
|
||||
):
|
||||
stream = await async_client.chat.completions.create(
|
||||
model=MODEL,
|
||||
messages=_messages,
|
||||
temperature=0,
|
||||
tools=_tools,
|
||||
max_tokens=200,
|
||||
stop=_stop,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
full_call = []
|
||||
async for chunk in stream:
|
||||
tc = chunk.choices[0].delta.tool_calls
|
||||
if tc and tc[0].function.arguments:
|
||||
full_call.append(tc[0].function.arguments)
|
||||
return "".join(full_call)
|
||||
|
||||
|
||||
async def run_scenario(server: RemoteOpenAIServer, _tools, _messages, _stop):
|
||||
non_streaming = get_args(server.get_client(), _tools, _messages, _stop)
|
||||
json.loads(non_streaming) # verify that it is json loadable
|
||||
streaming = await get_args_streaming(
|
||||
server.get_async_client(), _tools, _messages, _stop
|
||||
)
|
||||
json.loads(streaming)
|
||||
assert non_streaming == streaming, f"{non_streaming=}, {streaming=}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_sequence_interference(server: RemoteOpenAIServer):
|
||||
print("Testing scenario 1")
|
||||
await run_scenario(server, tools, messages, "veroniqueprattyushveroniqueprattyush")
|
||||
|
||||
print("Testing scenario 2")
|
||||
await run_scenario(
|
||||
server, tools2, messages2, "veroniqueprattyushveroniqueprattyush"
|
||||
)
|
||||
|
||||
print("Testing scenario 3")
|
||||
await run_scenario(server, tools2, messages3, "prattyush")
|
||||
@@ -3,29 +3,22 @@
|
||||
|
||||
import json
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
|
||||
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
|
||||
|
||||
from ....utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
|
||||
|
||||
SERVER_ARGS = [
|
||||
"--enforce-eager",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"{LORA_MODEL}={LORA_MODEL}",
|
||||
"--tokenizer",
|
||||
f"{LORA_MODEL}",
|
||||
]
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -50,6 +43,75 @@ TOOLS = [
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
class ServerConfig(TypedDict, total=False):
|
||||
model: str
|
||||
arguments: list[str]
|
||||
model_arg: str
|
||||
tool_parser: ToolParser
|
||||
|
||||
|
||||
CONFIGS: dict[str, ServerConfig] = {
|
||||
"llama": {
|
||||
"model": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"arguments": [
|
||||
"--enforce-eager",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"hermes",
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"{LORA_MODEL}={LORA_MODEL}",
|
||||
"--tokenizer",
|
||||
f"{LORA_MODEL}",
|
||||
],
|
||||
"model_arg": LORA_MODEL,
|
||||
"tool_parser": Hermes2ProToolParser,
|
||||
},
|
||||
"granite4": {
|
||||
"model": "ibm-granite/granite-4.0-h-tiny",
|
||||
"arguments": [
|
||||
"--enforce-eager",
|
||||
"--enable-auto-tool-choice",
|
||||
"--tool-call-parser",
|
||||
"granite4",
|
||||
"--tokenizer",
|
||||
"ibm-granite/granite-4.0-h-tiny",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--max-num-seqs",
|
||||
"2",
|
||||
],
|
||||
"model_arg": "ibm-granite/granite-4.0-h-tiny",
|
||||
"tool_parser": Granite4ToolParser,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# for each server config, download the model and return the config
|
||||
@pytest.fixture(scope="session", params=CONFIGS.keys())
|
||||
def server_config(request):
|
||||
config = CONFIGS[request.param]
|
||||
|
||||
# download model and tokenizer using transformers
|
||||
snapshot_download(config["model"])
|
||||
yield CONFIGS[request.param]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server(request, server_config: ServerConfig):
|
||||
model = server_config["model"]
|
||||
args_for_model = server_config["arguments"]
|
||||
with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server:
|
||||
yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server: RemoteOpenAIServer):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
PRODUCT_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
@@ -87,186 +149,182 @@ PRODUCT_MESSAGES = [
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_tool_call():
|
||||
async def test_non_streaming_tool_call(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
"""Test tool call in non-streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=MESSAGES,
|
||||
tools=TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model=server_config["model_arg"],
|
||||
messages=MESSAGES,
|
||||
tools=TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response.choices
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
assert response.choices
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
assert message.tool_calls is not None
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
assert message.tool_calls is not None
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_current_weather"
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_current_weather"
|
||||
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
assert "location" in arguments
|
||||
assert "Boston" in arguments["location"]
|
||||
print("\n[Non-Streaming Test Passed]")
|
||||
print(f"Tool Call: {tool_call.function.name}")
|
||||
print(f"Arguments: {arguments}")
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
assert "location" in arguments
|
||||
assert "Boston" in arguments["location"]
|
||||
print("\n[Non-Streaming Test Passed]")
|
||||
print(f"Tool Call: {tool_call.function.name}")
|
||||
print(f"Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_tool_call():
|
||||
async def test_streaming_tool_call(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
"""Test tool call in streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=MESSAGES,
|
||||
tools=TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
stream = await client.chat.completions.create(
|
||||
model=server_config["model_arg"],
|
||||
messages=MESSAGES,
|
||||
tools=TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tool_call_chunks = {}
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
tool_call_chunks = {}
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta or not delta.tool_calls:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta or not delta.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_chunk in delta.tool_calls:
|
||||
index = tool_chunk.index
|
||||
if index not in tool_call_chunks:
|
||||
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||
for tool_chunk in delta.tool_calls:
|
||||
index = tool_chunk.index
|
||||
if index not in tool_call_chunks:
|
||||
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
|
||||
assert reconstructed_tool_call["name"] == "get_current_weather"
|
||||
assert reconstructed_tool_call["name"] == "get_current_weather"
|
||||
|
||||
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||
assert "location" in arguments
|
||||
assert "Boston" in arguments["location"]
|
||||
print("\n[Streaming Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||
assert "location" in arguments
|
||||
assert "Boston" in arguments["location"]
|
||||
print("\n[Streaming Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_product_tool_call():
|
||||
async def test_non_streaming_product_tool_call(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
"""Test tool call integer and boolean parameters in non-streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model=server_config["model_arg"],
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
)
|
||||
|
||||
assert response.choices
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
assert response.choices
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
assert message.tool_calls is not None
|
||||
assert choice.finish_reason == "tool_calls"
|
||||
assert message.tool_calls is not None
|
||||
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_product_info"
|
||||
tool_call = message.tool_calls[0]
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_product_info"
|
||||
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
|
||||
print("\n[Non-Streaming Product Test Passed]")
|
||||
print(f"Tool Call: {tool_call.function.name}")
|
||||
print(f"Arguments: {arguments}")
|
||||
print("\n[Non-Streaming Product Test Passed]")
|
||||
print(f"Tool Call: {tool_call.function.name}")
|
||||
print(f"Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_product_tool_call():
|
||||
async def test_streaming_product_tool_call(
|
||||
client: openai.AsyncOpenAI, server_config: ServerConfig
|
||||
):
|
||||
"""Test tool call integer and boolean parameters in streaming mode."""
|
||||
with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
|
||||
client = server.get_async_client()
|
||||
|
||||
stream = await client.chat.completions.create(
|
||||
model=LORA_MODEL,
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
stream=True,
|
||||
)
|
||||
stream = await client.chat.completions.create(
|
||||
model=server_config["model_arg"],
|
||||
messages=PRODUCT_MESSAGES,
|
||||
tools=PRODUCT_TOOLS,
|
||||
tool_choice="auto",
|
||||
temperature=0.66,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tool_call_chunks = {}
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
tool_call_chunks = {}
|
||||
async for chunk in stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta or not delta.tool_calls:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
if not delta or not delta.tool_calls:
|
||||
continue
|
||||
|
||||
for tool_chunk in delta.tool_calls:
|
||||
index = tool_chunk.index
|
||||
if index not in tool_call_chunks:
|
||||
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||
for tool_chunk in delta.tool_calls:
|
||||
index = tool_chunk.index
|
||||
if index not in tool_call_chunks:
|
||||
tool_call_chunks[index] = {"name": "", "arguments": ""}
|
||||
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += (
|
||||
tool_chunk.function.arguments
|
||||
)
|
||||
if tool_chunk.function.name:
|
||||
tool_call_chunks[index]["name"] += tool_chunk.function.name
|
||||
if tool_chunk.function.arguments:
|
||||
tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
|
||||
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
assert len(tool_call_chunks) == 1
|
||||
reconstructed_tool_call = tool_call_chunks[0]
|
||||
|
||||
assert reconstructed_tool_call["name"] == "get_product_info"
|
||||
assert reconstructed_tool_call["name"] == "get_product_info"
|
||||
|
||||
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
arguments = json.loads(reconstructed_tool_call["arguments"])
|
||||
assert "product_id" in arguments
|
||||
assert "inserted" in arguments
|
||||
|
||||
# Handle type coercion for streaming test as well
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
# Handle type coercion for streaming test as well
|
||||
product_id = arguments.get("product_id")
|
||||
inserted = arguments.get("inserted")
|
||||
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
assert isinstance(product_id, int)
|
||||
assert product_id == 7355608
|
||||
assert isinstance(inserted, bool)
|
||||
assert inserted is True
|
||||
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
print("\n[Streaming Product Test Passed]")
|
||||
print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
|
||||
print(f"Reconstructed Arguments: {arguments}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -276,9 +334,10 @@ def qwen_tokenizer() -> TokenizerLike:
|
||||
return get_tokenizer("Qwen/Qwen3-32B")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
|
||||
return Hermes2ProToolParser(qwen_tokenizer)
|
||||
@pytest.fixture(params=CONFIGS.keys())
|
||||
def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
|
||||
config = CONFIGS[request.param]
|
||||
return config["tool_parser"](qwen_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -292,7 +351,7 @@ def any_chat_request() -> ChatCompletionRequest:
|
||||
|
||||
def test_hermes_parser_streaming_just_forward_text(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is some prior text that has nothing to do with tool calling."""
|
||||
@@ -324,7 +383,7 @@ def test_hermes_parser_streaming_just_forward_text(
|
||||
|
||||
def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
@@ -358,7 +417,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
|
||||
|
||||
def test_hermes_parser_streaming(
|
||||
qwen_tokenizer: TokenizerLike,
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = '<tool_call>\
|
||||
@@ -387,16 +446,20 @@ def test_hermes_parser_streaming(
|
||||
delta_messages.append(delta)
|
||||
print(delta_messages)
|
||||
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
|
||||
tool_call_args = "".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
assert tool_call_args == (
|
||||
'{"location":"San Francisco, California, United States", "unit": "celsius"}'
|
||||
# load to normalize whitespace
|
||||
tool_call_args = json.loads(
|
||||
"".join(
|
||||
delta.tool_calls[0].function.arguments or "" for delta in delta_messages
|
||||
)
|
||||
)
|
||||
assert tool_call_args == {
|
||||
"location": "San Francisco, California, United States",
|
||||
"unit": "celsius",
|
||||
}
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_no_tool_call(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """This is not a tool call."""
|
||||
@@ -410,7 +473,7 @@ def test_hermes_parser_non_streaming_no_tool_call(
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_between_tags(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
text = """<tool_call>
|
||||
@@ -428,9 +491,12 @@ def test_hermes_parser_non_streaming_tool_call_between_tags(
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_until_eos(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
if isinstance(hermes_parser, Granite4ToolParser):
|
||||
pytest.skip(reason="The Granite4 tool parser enforces a complete response")
|
||||
|
||||
text = """<tool_call>
|
||||
{"name": "final_answer", "arguments": {"trigger": true}}"""
|
||||
tool_call = hermes_parser.extract_tool_calls(
|
||||
@@ -445,7 +511,7 @@ def test_hermes_parser_non_streaming_tool_call_until_eos(
|
||||
|
||||
|
||||
def test_hermes_parser_non_streaming_tool_call_invalid_json(
|
||||
hermes_parser: Hermes2ProToolParser,
|
||||
hermes_parser: ToolParser,
|
||||
any_chat_request: ChatCompletionRequest,
|
||||
) -> None:
|
||||
# Missing closing brace to trigger exception
|
||||
|
||||
@@ -54,6 +54,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"granite_tool_parser",
|
||||
"GraniteToolParser",
|
||||
),
|
||||
"granite4": (
|
||||
"granite4_tool_parser",
|
||||
"Granite4ToolParser",
|
||||
),
|
||||
"hermes": (
|
||||
"hermes_tool_parser",
|
||||
"Hermes2ProToolParser",
|
||||
|
||||
252
vllm/tool_parsers/granite4_tool_parser.py
Normal file
252
vllm/tool_parsers/granite4_tool_parser.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Protocol, TypeVar
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def dump_args(args: None | dict[str, Any] | str) -> str | None:
|
||||
if args is None or isinstance(args, str):
|
||||
return args
|
||||
else:
|
||||
return json.dumps(args, ensure_ascii=False)
|
||||
|
||||
|
||||
class _FunctionCallCtor(Protocol):
|
||||
def __init__(self, *, name: str, arguments: str | None): ...
|
||||
|
||||
|
||||
FuncT = TypeVar("FuncT", bound=_FunctionCallCtor)
|
||||
|
||||
|
||||
class Granite4ToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: TokenizerLike):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.streamed_args_for_tool = list[str]()
|
||||
|
||||
self.look_ahead = ""
|
||||
self.in_tc = False
|
||||
|
||||
self.tc_start = "<tool_call>"
|
||||
self.tc_end = "</tool_call>"
|
||||
self.start_regex = re.compile(self.tc_start)
|
||||
self.end_regex = re.compile(self.tc_end)
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
# do not skip special tokens because the tool_call tokens are
|
||||
# marked "special" in some models. Since they are skipped
|
||||
# prior to the call to the tool parser, it breaks tool calling.
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def _collect_results(
|
||||
self, text_segments: list[str], tc_segments: list[str], cls: type[FuncT]
|
||||
) -> tuple[str, list[FuncT]]:
|
||||
tool_calls_json: list[dict[str, Any]] = [
|
||||
json.loads(tc_text) for tc_text in tc_segments
|
||||
]
|
||||
tool_calls = []
|
||||
for tc in tool_calls_json:
|
||||
assert isinstance(tc, dict)
|
||||
self.prev_tool_call_arr.append(tc)
|
||||
tool_calls.append(
|
||||
cls(
|
||||
name=tc["name"],
|
||||
arguments=dump_args(tc["arguments"]),
|
||||
)
|
||||
)
|
||||
return "".join(text_segments), tool_calls
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
msg = ExtractedToolCallInformation(
|
||||
tools_called=False, tool_calls=[], content=model_output
|
||||
)
|
||||
try:
|
||||
delimiters = [("TC_START", self.tc_start), ("TC_END", self.tc_end)]
|
||||
pattern = "|".join(f"(?P<{name}>{pattern})" for name, pattern in delimiters)
|
||||
regex = re.compile(pattern)
|
||||
|
||||
text_segments = list[str]()
|
||||
tc_segments = list[str]()
|
||||
last_cut_loc = 0
|
||||
|
||||
for match in regex.finditer(model_output):
|
||||
match_type = match.lastgroup
|
||||
if match_type == "TC_START":
|
||||
assert not self.in_tc, "Two tool call start tokens found in a row"
|
||||
if preceding_text := model_output[last_cut_loc : match.start()]:
|
||||
text_segments.append(preceding_text)
|
||||
self.in_tc = True
|
||||
elif match_type == "TC_END":
|
||||
assert self.in_tc, (
|
||||
"Tool call end token found without corresponding start token"
|
||||
)
|
||||
tool_text = model_output[last_cut_loc : match.start()]
|
||||
assert tool_text, (
|
||||
"Expected the model to generate text between tool call tokens"
|
||||
)
|
||||
tc_segments.append(tool_text)
|
||||
self.in_tc = False
|
||||
else:
|
||||
raise ValueError("Unexpected match")
|
||||
last_cut_loc = match.end()
|
||||
assert not self.in_tc, "The model generated an incomplete tool call"
|
||||
if final_text := model_output[last_cut_loc:]:
|
||||
text_segments.append(final_text)
|
||||
|
||||
content, tool_call_funcs = self._collect_results(
|
||||
text_segments, tc_segments, FunctionCall
|
||||
)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=func,
|
||||
)
|
||||
for func in tool_call_funcs
|
||||
]
|
||||
msg.tools_called = bool(tool_calls)
|
||||
msg.tool_calls = tool_calls
|
||||
msg.content = content or None
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return msg
|
||||
|
||||
def _tool_extraction_step(
|
||||
self,
|
||||
delta_text: str,
|
||||
) -> tuple[bool, str, str]:
|
||||
start_token_pos = start_token_end = end_token_pos = end_token_end = -1
|
||||
|
||||
if start_match := self.start_regex.search(delta_text, partial=True):
|
||||
if not start_match.partial:
|
||||
start_token_pos, start_token_end = start_match.span()
|
||||
elif start_match.end() > start_match.start():
|
||||
start_token_pos = -2
|
||||
|
||||
if end_match := self.end_regex.search(delta_text):
|
||||
end_token_pos, end_token_end = end_match.span()
|
||||
|
||||
# Done means that we've exhausted the current buffer
|
||||
# and need more output from the model
|
||||
done = True
|
||||
content = tc_text = ""
|
||||
|
||||
if start_token_pos < 0:
|
||||
# just streaming text so far
|
||||
if start_token_pos == -2:
|
||||
# There is a partial match
|
||||
content = delta_text[: start_match.start()]
|
||||
self.look_ahead = delta_text[start_match.start() :]
|
||||
else:
|
||||
content = delta_text
|
||||
|
||||
elif not self.in_tc:
|
||||
# we're entering a new tool call
|
||||
self.in_tc = True
|
||||
|
||||
content = delta_text[:start_token_pos]
|
||||
if end_token_pos > 0:
|
||||
self.start_in_tc = False
|
||||
tc_text = delta_text[start_token_end:end_token_pos]
|
||||
self.look_ahead = delta_text[end_token_end:]
|
||||
done = False # There could be more content already buffered
|
||||
else:
|
||||
self.look_ahead = delta_text[start_token_pos:]
|
||||
|
||||
elif end_token_pos < 0:
|
||||
# we're in between the start and the end token
|
||||
assert self.in_tc
|
||||
self.look_ahead = delta_text
|
||||
else:
|
||||
# We have found the end
|
||||
assert self.in_tc
|
||||
tc_text = delta_text[start_token_end:end_token_pos]
|
||||
self.in_tc = False
|
||||
self.look_ahead = delta_text[end_token_end:]
|
||||
done = False # There could be more content already buffered
|
||||
return done, content, tc_text
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
try:
|
||||
done = False
|
||||
text_segments = list[str]()
|
||||
tc_segments = list[str]()
|
||||
|
||||
while not done:
|
||||
delta_text = self.look_ahead + delta_text
|
||||
self.look_ahead = ""
|
||||
done, content, tc_text = self._tool_extraction_step(delta_text)
|
||||
if content:
|
||||
text_segments.append(content)
|
||||
if tc_text:
|
||||
tc_segments.append(tc_text)
|
||||
delta_text = ""
|
||||
|
||||
content, tool_call_funcs = self._collect_results(
|
||||
text_segments, tc_segments, DeltaFunctionCall
|
||||
)
|
||||
|
||||
delta_tool_calls = list[DeltaToolCall]()
|
||||
for function in tool_call_funcs:
|
||||
self.current_tool_id += 1
|
||||
delta_tool_calls.append(
|
||||
DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
index=self.current_tool_id,
|
||||
function=function.model_dump(exclude_none=True),
|
||||
)
|
||||
)
|
||||
self.streamed_args_for_tool.append(function.arguments or "")
|
||||
|
||||
assert self.current_tool_id + 1 == len(self.prev_tool_call_arr)
|
||||
assert self.current_tool_id + 1 == len(self.streamed_args_for_tool)
|
||||
|
||||
msg = DeltaMessage(content=content or None, tool_calls=delta_tool_calls)
|
||||
if msg.content or msg.tool_calls:
|
||||
return msg
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None
|
||||
Reference in New Issue
Block a user