diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md
index fe95735b9..b590b33e9 100644
--- a/docs/features/tool_calling.md
+++ b/docs/features/tool_calling.md
@@ -219,7 +219,7 @@ Supported models:
* `ibm-granite/granite-4.0-h-small` and other Granite 4.0 models
- Recommended flags: `--tool-call-parser hermes`
+ Recommended flags: `--tool-call-parser granite4`
* `ibm-granite/granite-3.0-8b-instruct`
diff --git a/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py
new file mode 100644
index 000000000..27e7a8c5d
--- /dev/null
+++ b/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py
@@ -0,0 +1,360 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import json
+import random
+from typing import Any
+
+import openai
+import pytest
+from transformers import AutoTokenizer
+
+from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.engine.protocol import (
+ DeltaMessage,
+)
+from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
+
+from ....utils import RemoteOpenAIServer
+
+MODEL = "ibm-granite/granite-4.0-h-tiny"
+
+
+@pytest.fixture(scope="module")
+def server():
+ model = MODEL
+ args_for_model = [
+ "--enforce-eager",
+ "--enable-auto-tool-choice",
+ "--tool-call-parser",
+ "granite4",
+ "--tokenizer",
+ "ibm-granite/granite-4.0-h-tiny",
+ "--max-model-len",
+ "4096",
+ "--max-num-seqs",
+ "2",
+ ]
+ with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server:
+ yield server
+
+
+def create_complex_input(create_string_args: bool):
+ coord_arg: dict | str = {
+ "coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]],
+ "coordinate_type": "latlong",
+ }
+ if create_string_args:
+ # test granite behavior
+ coord_arg = json.dumps(coord_arg)
+ return [
+ {"name": "find_bbox", "arguments": coord_arg},
+ {
+ "name": "get_stock_price",
+ "arguments": {
+ "symbol": "AAPL",
+ "start_date": "2021-01-01",
+ "end_date": "2021-12-31",
+ },
+ },
+ {"name": "find_bbox", "arguments": coord_arg},
+ ]
+
+
+def random_chunks(s: str, min_len: int, max_len: int):
+ chunks = []
+ i = 0
+ n = len(s)
+
+ while i < n:
+ size = random.randint(min_len, max_len)
+ chunks.append(s[i : i + size])
+ i += size
+
+ return chunks
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+ return AutoTokenizer.from_pretrained(MODEL)
+
+
+# create a variety of input chunk sizes
+@pytest.mark.parametrize(
+ "min_chunk, max_chunk",
+ [
+ (1, 1),
+ (1, 2),
+ (5, 7),
+ (6, 20),
+ ],
+)
+def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer):
+ input_dicts = create_complex_input(True)
+
+ formatted_tcs = [
+ " " + json.dumps(call) + " " for call in input_dicts
+ ]
+
+ text_messages = [
+ "Here goes the bbox call: \n",
+ " Now the stock price call: \n ",
+ " Now another bbox call: \n ",
+ " See? I'm a helpful assistant.",
+ ]
+
+ test_input = (
+ text_messages[0]
+ + formatted_tcs[0]
+ + text_messages[1]
+ + formatted_tcs[1]
+ + text_messages[2]
+ + formatted_tcs[2]
+ + text_messages[3]
+ )
+
+ any_chat_request = ChatCompletionRequest(
+ seed=42,
+ model=MODEL,
+ messages=[],
+ )
+
+ parser = Granite4ToolParser(tokenizer=tokenizer)
+
+ delta_messages = list[DeltaMessage]()
+ for text in random_chunks(test_input, min_chunk, max_chunk):
+ delta = parser.extract_tool_calls_streaming(
+ previous_text="",
+ current_text="",
+ delta_text=text,
+ previous_token_ids=[],
+ current_token_ids=[],
+ delta_token_ids=[],
+ request=any_chat_request,
+ )
+ if delta is not None:
+ delta_messages.append(delta)
+
+ content = ""
+ tool_calls = list[dict[str, Any]]()
+
+ current_name = "__start__"
+ current_args = ""
+
+ for msg in delta_messages:
+ if msg.content:
+ content += msg.content
+ for tool_call in msg.tool_calls:
+ if delta_func := tool_call.function:
+ if delta_func.name is not None:
+ if current_name == "__start__":
+ current_name = delta_func.name
+
+ if delta_func.name != current_name:
+ tool_calls.append(
+ {
+ "name": current_name,
+ "arguments": json.loads(current_args),
+ }
+ )
+ current_name = delta_func.name
+ current_args = ""
+
+ if delta_func.arguments:
+ current_args += delta_func.arguments
+
+ if current_name != "__start__":
+ tool_calls.append({"name": current_name, "arguments": json.loads(current_args)})
+
+ assert content == "".join(text_messages)
+ assert tool_calls == create_complex_input(False)
+
+
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_acme_region_name_for_transaction_id",
+ "description": "Returns ACME transaction/transaction ID information"
+ " including ACME regions\n\nArgs:\n start_time "
+ "(str): Start date and time in datetime format "
+ '"%Y-%m-%dT%H:%M:%S.%f"\n end_time (str): End '
+ "date and time in datetime format "
+ '"%Y-%m-%dT%H:%M:%S.%f"\n size (int, optional): '
+ "Number of ACME Transaction IDs to return\n "
+ "order (str, optional): Sort by most run "
+ "transaction IDs. The value can be 'asc' for "
+ "ascending or 'desc' for descending\n "
+ "transaction_id (str, optional): ACME Transaction "
+ "ID to filter on\n acme_region (str, optional): "
+ "ACME Region to filter on\nReturns:\n - A "
+ "dictionary containing a list of ACME transaction "
+ "ids and the ACME regions they run in:\n {\n"
+ ' "Number of transaction IDs" : int,\n'
+ ' "Total transaction IDs available": int'
+ ',\n "ACME Transaction IDs": [\n '
+ ' {\n "Transaction ID": '
+ 'str,\n "Number of runs": int,\n'
+ ' "ACME Regions": [str],\n '
+ " },\n ...\n ],"
+ '\n "Start time" : datetime,\n '
+ ' "End time" : datetime,\n '
+ ' "Order" : str\n }\n '
+ " - If no ACME region found for transaction id, "
+ 'returns:\n {"Success": "No ACME region '
+ 'found for transaction id."}\n - If an error '
+ 'occurs, returns:\n {"Error": "{exception'
+ ' message}"}',
+ "parameters": {
+ "properties": {
+ "start_time": {},
+ "end_time": {},
+ "size": {"default": 500},
+ "order": {"default": "desc"},
+ "transaction_id": {"default": None},
+ "acme_region": {"default": None},
+ },
+ "required": ["start_time", "end_time"],
+ "type": "object",
+ },
+ },
+ }
+]
+
+tools2 = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_current_weather",
+ "description": "Get the current weather",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {
+ "description": "The city and state, e.g. San Francisco, CA",
+ "type": "string",
+ }
+ },
+ "required": ["location"],
+ },
+ },
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "get_stock_price",
+ "description": "Retrieves the current stock price for a given "
+ "ticker symbol. The ticker symbol must be a valid "
+ "symbol for a publicly traded company on a major US"
+ " stock exchange like NYSE or NASDAQ. The tool will"
+ " return the latest trade price in USD. It should "
+ "be used when the user asks about the current or "
+ "most recent price of a specific stock. It will not"
+ " provide any other information about the stock or"
+ " company.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "ticker": {
+ "description": "The stock ticker symbol, e.g."
+ " AAPL for Apple Inc.",
+ "type": "string",
+ }
+ },
+ },
+ },
+ },
+]
+
+messages = [
+ {
+ "content": "\n\nSystem: You are a helpful, precise, and methodical AI"
+ " assistant that uses tool outputs provided inline.\nAlways"
+ " assume the current datetime is 2026-01-29T13:59:09.238901"
+ "+00:00.\n\nIf you receive a ToolMessage with `tool_call_id"
+ '` equal to "get_time_range" (or "time_range_tool"), you '
+ "MUST:\n 1. Parse that JSON and use the values `start` and"
+ " `end` directly when calling other tools.\n 2. Do not "
+ "re-call or re-compute the time range.\n 3. Pass resolved "
+ "values (ISO strings) as arguments to any subsequent tool "
+ "(do not pass function metadata or placeholders).\n 4. If "
+ "a tool requires datetime objects rather than strings, "
+ "convert the ISO strings into language-native datetime "
+ "objects before invoking.\n\nAlways return fully resolved "
+ "arguments in correct types (e.g., ISO datetime strings or"
+ " datetime objects) and never include placeholders like "
+ '"".\n\n',
+ "role": "system",
+ },
+ {
+ "content": "What are the transaction IDs that ran in the"
+ " ACME region A9345 over the last two months?",
+ "role": "user",
+ },
+ {
+ "content": '["2026-01-26T09: 51: 55.467722Z", "2026-01-27T09: 51: 55.467722Z"]',
+ "role": "tool",
+ "tool_call_id": "time_range_tool",
+ },
+]
+messages2 = [{"role": "user", "content": "What's stock price for IBM?"}]
+
+messages3 = [{"role": "user", "content": "What's the current weather in New York?"}]
+
+
+def get_args(client: openai.OpenAI, _tools, _messages, _stop):
+ response = client.chat.completions.create(
+ model=MODEL,
+ messages=_messages,
+ temperature=0,
+ tools=_tools,
+ max_tokens=200,
+ stop=_stop,
+ tool_choice="auto",
+ )
+
+ return response.choices[0].message.tool_calls[0].function.arguments
+
+
+async def get_args_streaming(
+ async_client: openai.AsyncOpenAI, _tools, _messages, _stop
+):
+ stream = await async_client.chat.completions.create(
+ model=MODEL,
+ messages=_messages,
+ temperature=0,
+ tools=_tools,
+ max_tokens=200,
+ stop=_stop,
+ tool_choice="auto",
+ stream=True,
+ )
+ full_call = []
+ async for chunk in stream:
+ tc = chunk.choices[0].delta.tool_calls
+ if tc and tc[0].function.arguments:
+ full_call.append(tc[0].function.arguments)
+ return "".join(full_call)
+
+
+async def run_scenario(server: RemoteOpenAIServer, _tools, _messages, _stop):
+ non_streaming = get_args(server.get_client(), _tools, _messages, _stop)
+ json.loads(non_streaming) # verify that it is json loadable
+ streaming = await get_args_streaming(
+ server.get_async_client(), _tools, _messages, _stop
+ )
+ json.loads(streaming)
+ assert non_streaming == streaming, f"{non_streaming=}, {streaming=}"
+
+
+@pytest.mark.asyncio
+async def test_stop_sequence_interference(server: RemoteOpenAIServer):
+ print("Testing scenario 1")
+ await run_scenario(server, tools, messages, "veroniqueprattyushveroniqueprattyush")
+
+ print("Testing scenario 2")
+ await run_scenario(
+ server, tools2, messages2, "veroniqueprattyushveroniqueprattyush"
+ )
+
+ print("Testing scenario 3")
+ await run_scenario(server, tools2, messages3, "prattyush")
diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
index 626d845e1..be910fbb1 100644
--- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
+++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py
@@ -3,29 +3,22 @@
import json
+import openai
import pytest
+import pytest_asyncio
+from huggingface_hub import snapshot_download
+from typing_extensions import TypedDict
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from ....utils import RemoteOpenAIServer
-MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci"
-SERVER_ARGS = [
- "--enforce-eager",
- "--enable-auto-tool-choice",
- "--tool-call-parser",
- "hermes",
- "--enable-lora",
- "--lora-modules",
- f"{LORA_MODEL}={LORA_MODEL}",
- "--tokenizer",
- f"{LORA_MODEL}",
-]
-
TOOLS = [
{
"type": "function",
@@ -50,6 +43,75 @@ TOOLS = [
}
]
+
+class ServerConfig(TypedDict, total=False):
+ model: str
+ arguments: list[str]
+ model_arg: str
+ tool_parser: ToolParser
+
+
+CONFIGS: dict[str, ServerConfig] = {
+ "llama": {
+ "model": "meta-llama/Llama-3.2-1B-Instruct",
+ "arguments": [
+ "--enforce-eager",
+ "--enable-auto-tool-choice",
+ "--tool-call-parser",
+ "hermes",
+ "--enable-lora",
+ "--lora-modules",
+ f"{LORA_MODEL}={LORA_MODEL}",
+ "--tokenizer",
+ f"{LORA_MODEL}",
+ ],
+ "model_arg": LORA_MODEL,
+ "tool_parser": Hermes2ProToolParser,
+ },
+ "granite4": {
+ "model": "ibm-granite/granite-4.0-h-tiny",
+ "arguments": [
+ "--enforce-eager",
+ "--enable-auto-tool-choice",
+ "--tool-call-parser",
+ "granite4",
+ "--tokenizer",
+ "ibm-granite/granite-4.0-h-tiny",
+ "--max-model-len",
+ "4096",
+ "--max-num-seqs",
+ "2",
+ ],
+ "model_arg": "ibm-granite/granite-4.0-h-tiny",
+ "tool_parser": Granite4ToolParser,
+ },
+}
+
+
+# for each server config, download the model and return the config
+@pytest.fixture(scope="session", params=CONFIGS.keys())
+def server_config(request):
+ config = CONFIGS[request.param]
+
+ # download model and tokenizer using transformers
+ snapshot_download(config["model"])
+ yield CONFIGS[request.param]
+
+
+@pytest.fixture(scope="module")
+def server(request, server_config: ServerConfig):
+ model = server_config["model"]
+ args_for_model = server_config["arguments"]
+ with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server:
+ yield server
+
+
+@pytest_asyncio.fixture
+async def client(server: RemoteOpenAIServer):
+ async with server.get_async_client() as async_client:
+ yield async_client
+
+
PRODUCT_TOOLS = [
{
"type": "function",
@@ -87,186 +149,182 @@ PRODUCT_MESSAGES = [
@pytest.mark.asyncio
-async def test_non_streaming_tool_call():
+async def test_non_streaming_tool_call(
+ client: openai.AsyncOpenAI, server_config: ServerConfig
+):
"""Test tool call in non-streaming mode."""
- with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
- client = server.get_async_client()
- response = await client.chat.completions.create(
- model=LORA_MODEL,
- messages=MESSAGES,
- tools=TOOLS,
- tool_choice="auto",
- temperature=0.0,
- )
+ response = await client.chat.completions.create(
+ model=server_config["model_arg"],
+ messages=MESSAGES,
+ tools=TOOLS,
+ tool_choice="auto",
+ temperature=0.0,
+ )
- assert response.choices
- choice = response.choices[0]
- message = choice.message
+ assert response.choices
+ choice = response.choices[0]
+ message = choice.message
- assert choice.finish_reason == "tool_calls"
- assert message.tool_calls is not None
+ assert choice.finish_reason == "tool_calls"
+ assert message.tool_calls is not None
- tool_call = message.tool_calls[0]
- assert tool_call.type == "function"
- assert tool_call.function.name == "get_current_weather"
+ tool_call = message.tool_calls[0]
+ assert tool_call.type == "function"
+ assert tool_call.function.name == "get_current_weather"
- arguments = json.loads(tool_call.function.arguments)
- assert "location" in arguments
- assert "Boston" in arguments["location"]
- print("\n[Non-Streaming Test Passed]")
- print(f"Tool Call: {tool_call.function.name}")
- print(f"Arguments: {arguments}")
+ arguments = json.loads(tool_call.function.arguments)
+ assert "location" in arguments
+ assert "Boston" in arguments["location"]
+ print("\n[Non-Streaming Test Passed]")
+ print(f"Tool Call: {tool_call.function.name}")
+ print(f"Arguments: {arguments}")
@pytest.mark.asyncio
-async def test_streaming_tool_call():
+async def test_streaming_tool_call(
+ client: openai.AsyncOpenAI, server_config: ServerConfig
+):
"""Test tool call in streaming mode."""
- with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
- client = server.get_async_client()
- stream = await client.chat.completions.create(
- model=LORA_MODEL,
- messages=MESSAGES,
- tools=TOOLS,
- tool_choice="auto",
- temperature=0.0,
- stream=True,
- )
+ stream = await client.chat.completions.create(
+ model=server_config["model_arg"],
+ messages=MESSAGES,
+ tools=TOOLS,
+ tool_choice="auto",
+ temperature=0.0,
+ stream=True,
+ )
- tool_call_chunks = {}
- async for chunk in stream:
- if not chunk.choices:
- continue
+ tool_call_chunks = {}
+ async for chunk in stream:
+ if not chunk.choices:
+ continue
- delta = chunk.choices[0].delta
- if not delta or not delta.tool_calls:
- continue
+ delta = chunk.choices[0].delta
+ if not delta or not delta.tool_calls:
+ continue
- for tool_chunk in delta.tool_calls:
- index = tool_chunk.index
- if index not in tool_call_chunks:
- tool_call_chunks[index] = {"name": "", "arguments": ""}
+ for tool_chunk in delta.tool_calls:
+ index = tool_chunk.index
+ if index not in tool_call_chunks:
+ tool_call_chunks[index] = {"name": "", "arguments": ""}
- if tool_chunk.function.name:
- tool_call_chunks[index]["name"] += tool_chunk.function.name
- if tool_chunk.function.arguments:
- tool_call_chunks[index]["arguments"] += (
- tool_chunk.function.arguments
- )
+ if tool_chunk.function.name:
+ tool_call_chunks[index]["name"] += tool_chunk.function.name
+ if tool_chunk.function.arguments:
+ tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
- assert len(tool_call_chunks) == 1
- reconstructed_tool_call = tool_call_chunks[0]
+ assert len(tool_call_chunks) == 1
+ reconstructed_tool_call = tool_call_chunks[0]
- assert reconstructed_tool_call["name"] == "get_current_weather"
+ assert reconstructed_tool_call["name"] == "get_current_weather"
- arguments = json.loads(reconstructed_tool_call["arguments"])
- assert "location" in arguments
- assert "Boston" in arguments["location"]
- print("\n[Streaming Test Passed]")
- print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
- print(f"Reconstructed Arguments: {arguments}")
+ arguments = json.loads(reconstructed_tool_call["arguments"])
+ assert "location" in arguments
+ assert "Boston" in arguments["location"]
+ print("\n[Streaming Test Passed]")
+ print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
+ print(f"Reconstructed Arguments: {arguments}")
@pytest.mark.asyncio
-async def test_non_streaming_product_tool_call():
+async def test_non_streaming_product_tool_call(
+ client: openai.AsyncOpenAI, server_config: ServerConfig
+):
"""Test tool call integer and boolean parameters in non-streaming mode."""
- with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
- client = server.get_async_client()
- response = await client.chat.completions.create(
- model=LORA_MODEL,
- messages=PRODUCT_MESSAGES,
- tools=PRODUCT_TOOLS,
- tool_choice="auto",
- temperature=0.66,
- )
+ response = await client.chat.completions.create(
+ model=server_config["model_arg"],
+ messages=PRODUCT_MESSAGES,
+ tools=PRODUCT_TOOLS,
+ tool_choice="auto",
+ temperature=0.66,
+ )
- assert response.choices
- choice = response.choices[0]
- message = choice.message
+ assert response.choices
+ choice = response.choices[0]
+ message = choice.message
- assert choice.finish_reason == "tool_calls"
- assert message.tool_calls is not None
+ assert choice.finish_reason == "tool_calls"
+ assert message.tool_calls is not None
- tool_call = message.tool_calls[0]
- assert tool_call.type == "function"
- assert tool_call.function.name == "get_product_info"
+ tool_call = message.tool_calls[0]
+ assert tool_call.type == "function"
+ assert tool_call.function.name == "get_product_info"
- arguments = json.loads(tool_call.function.arguments)
- assert "product_id" in arguments
- assert "inserted" in arguments
+ arguments = json.loads(tool_call.function.arguments)
+ assert "product_id" in arguments
+ assert "inserted" in arguments
- product_id = arguments.get("product_id")
- inserted = arguments.get("inserted")
+ product_id = arguments.get("product_id")
+ inserted = arguments.get("inserted")
- assert isinstance(product_id, int)
- assert product_id == 7355608
- assert isinstance(inserted, bool)
- assert inserted is True
+ assert isinstance(product_id, int)
+ assert product_id == 7355608
+ assert isinstance(inserted, bool)
+ assert inserted is True
- print("\n[Non-Streaming Product Test Passed]")
- print(f"Tool Call: {tool_call.function.name}")
- print(f"Arguments: {arguments}")
+ print("\n[Non-Streaming Product Test Passed]")
+ print(f"Tool Call: {tool_call.function.name}")
+ print(f"Arguments: {arguments}")
@pytest.mark.asyncio
-async def test_streaming_product_tool_call():
+async def test_streaming_product_tool_call(
+ client: openai.AsyncOpenAI, server_config: ServerConfig
+):
"""Test tool call integer and boolean parameters in streaming mode."""
- with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server:
- client = server.get_async_client()
- stream = await client.chat.completions.create(
- model=LORA_MODEL,
- messages=PRODUCT_MESSAGES,
- tools=PRODUCT_TOOLS,
- tool_choice="auto",
- temperature=0.66,
- stream=True,
- )
+ stream = await client.chat.completions.create(
+ model=server_config["model_arg"],
+ messages=PRODUCT_MESSAGES,
+ tools=PRODUCT_TOOLS,
+ tool_choice="auto",
+ temperature=0.66,
+ stream=True,
+ )
- tool_call_chunks = {}
- async for chunk in stream:
- if not chunk.choices:
- continue
+ tool_call_chunks = {}
+ async for chunk in stream:
+ if not chunk.choices:
+ continue
- delta = chunk.choices[0].delta
- if not delta or not delta.tool_calls:
- continue
+ delta = chunk.choices[0].delta
+ if not delta or not delta.tool_calls:
+ continue
- for tool_chunk in delta.tool_calls:
- index = tool_chunk.index
- if index not in tool_call_chunks:
- tool_call_chunks[index] = {"name": "", "arguments": ""}
+ for tool_chunk in delta.tool_calls:
+ index = tool_chunk.index
+ if index not in tool_call_chunks:
+ tool_call_chunks[index] = {"name": "", "arguments": ""}
- if tool_chunk.function.name:
- tool_call_chunks[index]["name"] += tool_chunk.function.name
- if tool_chunk.function.arguments:
- tool_call_chunks[index]["arguments"] += (
- tool_chunk.function.arguments
- )
+ if tool_chunk.function.name:
+ tool_call_chunks[index]["name"] += tool_chunk.function.name
+ if tool_chunk.function.arguments:
+ tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments
- assert len(tool_call_chunks) == 1
- reconstructed_tool_call = tool_call_chunks[0]
+ assert len(tool_call_chunks) == 1
+ reconstructed_tool_call = tool_call_chunks[0]
- assert reconstructed_tool_call["name"] == "get_product_info"
+ assert reconstructed_tool_call["name"] == "get_product_info"
- arguments = json.loads(reconstructed_tool_call["arguments"])
- assert "product_id" in arguments
- assert "inserted" in arguments
+ arguments = json.loads(reconstructed_tool_call["arguments"])
+ assert "product_id" in arguments
+ assert "inserted" in arguments
- # Handle type coercion for streaming test as well
- product_id = arguments.get("product_id")
- inserted = arguments.get("inserted")
+ # Handle type coercion for streaming test as well
+ product_id = arguments.get("product_id")
+ inserted = arguments.get("inserted")
- assert isinstance(product_id, int)
- assert product_id == 7355608
- assert isinstance(inserted, bool)
- assert inserted is True
+ assert isinstance(product_id, int)
+ assert product_id == 7355608
+ assert isinstance(inserted, bool)
+ assert inserted is True
- print("\n[Streaming Product Test Passed]")
- print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
- print(f"Reconstructed Arguments: {arguments}")
+ print("\n[Streaming Product Test Passed]")
+ print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}")
+ print(f"Reconstructed Arguments: {arguments}")
@pytest.fixture
@@ -276,9 +334,10 @@ def qwen_tokenizer() -> TokenizerLike:
return get_tokenizer("Qwen/Qwen3-32B")
-@pytest.fixture
-def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser:
- return Hermes2ProToolParser(qwen_tokenizer)
+@pytest.fixture(params=CONFIGS.keys())
+def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser:
+ config = CONFIGS[request.param]
+ return config["tool_parser"](qwen_tokenizer)
@pytest.fixture
@@ -292,7 +351,7 @@ def any_chat_request() -> ChatCompletionRequest:
def test_hermes_parser_streaming_just_forward_text(
qwen_tokenizer: TokenizerLike,
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """This is some prior text that has nothing to do with tool calling."""
@@ -324,7 +383,7 @@ def test_hermes_parser_streaming_just_forward_text(
def test_hermes_parser_streaming_failure_case_bug_19056(
qwen_tokenizer: TokenizerLike,
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """
@@ -358,7 +417,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056(
def test_hermes_parser_streaming(
qwen_tokenizer: TokenizerLike,
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = '\
@@ -387,16 +446,20 @@ def test_hermes_parser_streaming(
delta_messages.append(delta)
print(delta_messages)
assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature"
- tool_call_args = "".join(
- delta.tool_calls[0].function.arguments or "" for delta in delta_messages
- )
- assert tool_call_args == (
- '{"location":"San Francisco, California, United States", "unit": "celsius"}'
+ # load to normalize whitespace
+ tool_call_args = json.loads(
+ "".join(
+ delta.tool_calls[0].function.arguments or "" for delta in delta_messages
+ )
)
+ assert tool_call_args == {
+ "location": "San Francisco, California, United States",
+ "unit": "celsius",
+ }
def test_hermes_parser_non_streaming_no_tool_call(
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """This is not a tool call."""
@@ -410,7 +473,7 @@ def test_hermes_parser_non_streaming_no_tool_call(
def test_hermes_parser_non_streaming_tool_call_between_tags(
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
text = """
@@ -428,9 +491,12 @@ def test_hermes_parser_non_streaming_tool_call_between_tags(
def test_hermes_parser_non_streaming_tool_call_until_eos(
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
+ if isinstance(hermes_parser, Granite4ToolParser):
+ pytest.skip(reason="The Granite4 tool parser enforces a complete response")
+
text = """
{"name": "final_answer", "arguments": {"trigger": true}}"""
tool_call = hermes_parser.extract_tool_calls(
@@ -445,7 +511,7 @@ def test_hermes_parser_non_streaming_tool_call_until_eos(
def test_hermes_parser_non_streaming_tool_call_invalid_json(
- hermes_parser: Hermes2ProToolParser,
+ hermes_parser: ToolParser,
any_chat_request: ChatCompletionRequest,
) -> None:
# Missing closing brace to trigger exception
diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py
index c1a39f2af..f480a635c 100644
--- a/vllm/tool_parsers/__init__.py
+++ b/vllm/tool_parsers/__init__.py
@@ -54,6 +54,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"granite_tool_parser",
"GraniteToolParser",
),
+ "granite4": (
+ "granite4_tool_parser",
+ "Granite4ToolParser",
+ ),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
diff --git a/vllm/tool_parsers/granite4_tool_parser.py b/vllm/tool_parsers/granite4_tool_parser.py
new file mode 100644
index 000000000..693c4dc8f
--- /dev/null
+++ b/vllm/tool_parsers/granite4_tool_parser.py
@@ -0,0 +1,252 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Sequence
+from typing import Any, Protocol, TypeVar
+
+import regex as re
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.engine.protocol import (
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ ExtractedToolCallInformation,
+ FunctionCall,
+ ToolCall,
+)
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import (
+ ToolParser,
+)
+
+logger = init_logger(__name__)
+
+
+def dump_args(args: None | dict[str, Any] | str) -> str | None:
+ if args is None or isinstance(args, str):
+ return args
+ else:
+ return json.dumps(args, ensure_ascii=False)
+
+
+class _FunctionCallCtor(Protocol):
+ def __init__(self, *, name: str, arguments: str | None): ...
+
+
+FuncT = TypeVar("FuncT", bound=_FunctionCallCtor)
+
+
+class Granite4ToolParser(ToolParser):
+ def __init__(self, tokenizer: TokenizerLike):
+ super().__init__(tokenizer)
+
+ self.prev_tool_call_arr: list[dict] = []
+ self.current_tool_id: int = -1
+ self.streamed_args_for_tool = list[str]()
+
+ self.look_ahead = ""
+ self.in_tc = False
+
+ self.tc_start = ""
+ self.tc_end = ""
+ self.start_regex = re.compile(self.tc_start)
+ self.end_regex = re.compile(self.tc_end)
+
+ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+ request = super().adjust_request(request)
+ if request.tools and request.tool_choice != "none":
+ # do not skip special tokens because the tool_call tokens are
+ # marked "special" in some models. Since they are skipped
+ # prior to the call to the tool parser, it breaks tool calling.
+ request.skip_special_tokens = False
+ return request
+
+ def _collect_results(
+ self, text_segments: list[str], tc_segments: list[str], cls: type[FuncT]
+ ) -> tuple[str, list[FuncT]]:
+ tool_calls_json: list[dict[str, Any]] = [
+ json.loads(tc_text) for tc_text in tc_segments
+ ]
+ tool_calls = []
+ for tc in tool_calls_json:
+ assert isinstance(tc, dict)
+ self.prev_tool_call_arr.append(tc)
+ tool_calls.append(
+ cls(
+ name=tc["name"],
+ arguments=dump_args(tc["arguments"]),
+ )
+ )
+ return "".join(text_segments), tool_calls
+
+ def extract_tool_calls(
+ self,
+ model_output: str,
+ request: ChatCompletionRequest,
+ ) -> ExtractedToolCallInformation:
+ msg = ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+ try:
+ delimiters = [("TC_START", self.tc_start), ("TC_END", self.tc_end)]
+ pattern = "|".join(f"(?P<{name}>{pattern})" for name, pattern in delimiters)
+ regex = re.compile(pattern)
+
+ text_segments = list[str]()
+ tc_segments = list[str]()
+ last_cut_loc = 0
+
+ for match in regex.finditer(model_output):
+ match_type = match.lastgroup
+ if match_type == "TC_START":
+ assert not self.in_tc, "Two tool call start tokens found in a row"
+ if preceding_text := model_output[last_cut_loc : match.start()]:
+ text_segments.append(preceding_text)
+ self.in_tc = True
+ elif match_type == "TC_END":
+ assert self.in_tc, (
+ "Tool call end token found without corresponding start token"
+ )
+ tool_text = model_output[last_cut_loc : match.start()]
+ assert tool_text, (
+ "Expected the model to generate text between tool call tokens"
+ )
+ tc_segments.append(tool_text)
+ self.in_tc = False
+ else:
+ raise ValueError("Unexpected match")
+ last_cut_loc = match.end()
+ assert not self.in_tc, "The model generated an incomplete tool call"
+ if final_text := model_output[last_cut_loc:]:
+ text_segments.append(final_text)
+
+ content, tool_call_funcs = self._collect_results(
+ text_segments, tc_segments, FunctionCall
+ )
+ tool_calls = [
+ ToolCall(
+ type="function",
+ function=func,
+ )
+ for func in tool_call_funcs
+ ]
+ msg.tools_called = bool(tool_calls)
+ msg.tool_calls = tool_calls
+ msg.content = content or None
+ except Exception:
+ logger.exception("Error in extracting tool call from response.")
+ return msg
+
+ def _tool_extraction_step(
+ self,
+ delta_text: str,
+ ) -> tuple[bool, str, str]:
+ start_token_pos = start_token_end = end_token_pos = end_token_end = -1
+
+ if start_match := self.start_regex.search(delta_text, partial=True):
+ if not start_match.partial:
+ start_token_pos, start_token_end = start_match.span()
+ elif start_match.end() > start_match.start():
+ start_token_pos = -2
+
+ if end_match := self.end_regex.search(delta_text):
+ end_token_pos, end_token_end = end_match.span()
+
+ # Done means that we've exhausted the current buffer
+ # and need more output from the model
+ done = True
+ content = tc_text = ""
+
+ if start_token_pos < 0:
+ # just streaming text so far
+ if start_token_pos == -2:
+ # There is a partial match
+ content = delta_text[: start_match.start()]
+ self.look_ahead = delta_text[start_match.start() :]
+ else:
+ content = delta_text
+
+ elif not self.in_tc:
+ # we're entering a new tool call
+ self.in_tc = True
+
+ content = delta_text[:start_token_pos]
+ if end_token_pos > 0:
+ self.start_in_tc = False
+ tc_text = delta_text[start_token_end:end_token_pos]
+ self.look_ahead = delta_text[end_token_end:]
+ done = False # There could be more content already buffered
+ else:
+ self.look_ahead = delta_text[start_token_pos:]
+
+ elif end_token_pos < 0:
+ # we're in between the start and the end token
+ assert self.in_tc
+ self.look_ahead = delta_text
+ else:
+ # We have found the end
+ assert self.in_tc
+ tc_text = delta_text[start_token_end:end_token_pos]
+ self.in_tc = False
+ self.look_ahead = delta_text[end_token_end:]
+ done = False # There could be more content already buffered
+ return done, content, tc_text
+
+ def extract_tool_calls_streaming(
+ self,
+ previous_text: str,
+ current_text: str,
+ delta_text: str,
+ previous_token_ids: Sequence[int],
+ current_token_ids: Sequence[int],
+ delta_token_ids: Sequence[int],
+ request: ChatCompletionRequest,
+ ) -> DeltaMessage | None:
+ try:
+ done = False
+ text_segments = list[str]()
+ tc_segments = list[str]()
+
+ while not done:
+ delta_text = self.look_ahead + delta_text
+ self.look_ahead = ""
+ done, content, tc_text = self._tool_extraction_step(delta_text)
+ if content:
+ text_segments.append(content)
+ if tc_text:
+ tc_segments.append(tc_text)
+ delta_text = ""
+
+ content, tool_call_funcs = self._collect_results(
+ text_segments, tc_segments, DeltaFunctionCall
+ )
+
+ delta_tool_calls = list[DeltaToolCall]()
+ for function in tool_call_funcs:
+ self.current_tool_id += 1
+ delta_tool_calls.append(
+ DeltaToolCall(
+ id=make_tool_call_id(),
+ type="function",
+ index=self.current_tool_id,
+ function=function.model_dump(exclude_none=True),
+ )
+ )
+ self.streamed_args_for_tool.append(function.arguments or "")
+
+ assert self.current_tool_id + 1 == len(self.prev_tool_call_arr)
+ assert self.current_tool_id + 1 == len(self.streamed_args_for_tool)
+
+ msg = DeltaMessage(content=content or None, tool_calls=delta_tool_calls)
+ if msg.content or msg.tool_calls:
+ return msg
+
+ except Exception:
+ logger.exception("Error trying to handle streaming tool call.")
+ return None