diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index fe95735b9..b590b33e9 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -219,7 +219,7 @@ Supported models: * `ibm-granite/granite-4.0-h-small` and other Granite 4.0 models - Recommended flags: `--tool-call-parser hermes` + Recommended flags: `--tool-call-parser granite4` * `ibm-granite/granite-3.0-8b-instruct` diff --git a/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py new file mode 100644 index 000000000..27e7a8c5d --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_granite4_tool_parser.py @@ -0,0 +1,360 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import random +from typing import Any + +import openai +import pytest +from transformers import AutoTokenizer + +from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest +from vllm.entrypoints.openai.engine.protocol import ( + DeltaMessage, +) +from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser + +from ....utils import RemoteOpenAIServer + +MODEL = "ibm-granite/granite-4.0-h-tiny" + + +@pytest.fixture(scope="module") +def server(): + model = MODEL + args_for_model = [ + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "granite4", + "--tokenizer", + "ibm-granite/granite-4.0-h-tiny", + "--max-model-len", + "4096", + "--max-num-seqs", + "2", + ] + with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server: + yield server + + +def create_complex_input(create_string_args: bool): + coord_arg: dict | str = { + "coordinates": [[23.54, 43.1], [-12.2, 54.3], [4, 5]], + "coordinate_type": "latlong", + } + if create_string_args: + # test granite behavior + coord_arg = json.dumps(coord_arg) + return [ + {"name": "find_bbox", "arguments": coord_arg}, + { + 
"name": "get_stock_price", + "arguments": { + "symbol": "AAPL", + "start_date": "2021-01-01", + "end_date": "2021-12-31", + }, + }, + {"name": "find_bbox", "arguments": coord_arg}, + ] + + +def random_chunks(s: str, min_len: int, max_len: int): + chunks = [] + i = 0 + n = len(s) + + while i < n: + size = random.randint(min_len, max_len) + chunks.append(s[i : i + size]) + i += size + + return chunks + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL) + + +# create a variety of input chunk sizes +@pytest.mark.parametrize( + "min_chunk, max_chunk", + [ + (1, 1), + (1, 2), + (5, 7), + (6, 20), + ], +) +def test_tool_call_parser_complex(min_chunk: int, max_chunk: int, tokenizer): + input_dicts = create_complex_input(True) + + formatted_tcs = [ + " " + json.dumps(call) + " " for call in input_dicts + ] + + text_messages = [ + "Here goes the bbox call: \n", + " Now the stock price call: \n ", + " Now another bbox call: \n ", + " See? I'm a helpful assistant.", + ] + + test_input = ( + text_messages[0] + + formatted_tcs[0] + + text_messages[1] + + formatted_tcs[1] + + text_messages[2] + + formatted_tcs[2] + + text_messages[3] + ) + + any_chat_request = ChatCompletionRequest( + seed=42, + model=MODEL, + messages=[], + ) + + parser = Granite4ToolParser(tokenizer=tokenizer) + + delta_messages = list[DeltaMessage]() + for text in random_chunks(test_input, min_chunk, max_chunk): + delta = parser.extract_tool_calls_streaming( + previous_text="", + current_text="", + delta_text=text, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=any_chat_request, + ) + if delta is not None: + delta_messages.append(delta) + + content = "" + tool_calls = list[dict[str, Any]]() + + current_name = "__start__" + current_args = "" + + for msg in delta_messages: + if msg.content: + content += msg.content + for tool_call in msg.tool_calls: + if delta_func := tool_call.function: + if delta_func.name is not None: + if 
current_name == "__start__": + current_name = delta_func.name + + if delta_func.name != current_name: + tool_calls.append( + { + "name": current_name, + "arguments": json.loads(current_args), + } + ) + current_name = delta_func.name + current_args = "" + + if delta_func.arguments: + current_args += delta_func.arguments + + if current_name != "__start__": + tool_calls.append({"name": current_name, "arguments": json.loads(current_args)}) + + assert content == "".join(text_messages) + assert tool_calls == create_complex_input(False) + + +tools = [ + { + "type": "function", + "function": { + "name": "get_acme_region_name_for_transaction_id", + "description": "Returns ACME transaction/transaction ID information" + " including ACME regions\n\nArgs:\n start_time " + "(str): Start date and time in datetime format " + '"%Y-%m-%dT%H:%M:%S.%f"\n end_time (str): End ' + "date and time in datetime format " + '"%Y-%m-%dT%H:%M:%S.%f"\n size (int, optional): ' + "Number of ACME Transaction IDs to return\n " + "order (str, optional): Sort by most run " + "transaction IDs. 
The value can be 'asc' for " + "ascending or 'desc' for descending\n " + "transaction_id (str, optional): ACME Transaction " + "ID to filter on\n acme_region (str, optional): " + "ACME Region to filter on\nReturns:\n - A " + "dictionary containing a list of ACME transaction " + "ids and the ACME regions they run in:\n {\n" + ' "Number of transaction IDs" : int,\n' + ' "Total transaction IDs available": int' + ',\n "ACME Transaction IDs": [\n ' + ' {\n "Transaction ID": ' + 'str,\n "Number of runs": int,\n' + ' "ACME Regions": [str],\n ' + " },\n ...\n ]," + '\n "Start time" : datetime,\n ' + ' "End time" : datetime,\n ' + ' "Order" : str\n }\n ' + " - If no ACME region found for transaction id, " + 'returns:\n {"Success": "No ACME region ' + 'found for transaction id."}\n - If an error ' + 'occurs, returns:\n {"Error": "{exception' + ' message}"}', + "parameters": { + "properties": { + "start_time": {}, + "end_time": {}, + "size": {"default": 500}, + "order": {"default": "desc"}, + "transaction_id": {"default": None}, + "acme_region": {"default": None}, + }, + "required": ["start_time", "end_time"], + "type": "object", + }, + }, + } +] + +tools2 = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The city and state, e.g. San Francisco, CA", + "type": "string", + } + }, + "required": ["location"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_stock_price", + "description": "Retrieves the current stock price for a given " + "ticker symbol. The ticker symbol must be a valid " + "symbol for a publicly traded company on a major US" + " stock exchange like NYSE or NASDAQ. The tool will" + " return the latest trade price in USD. It should " + "be used when the user asks about the current or " + "most recent price of a specific stock. 
It will not" + " provide any other information about the stock or" + " company.", + "parameters": { + "type": "object", + "properties": { + "ticker": { + "description": "The stock ticker symbol, e.g." + " AAPL for Apple Inc.", + "type": "string", + } + }, + }, + }, + }, +] + +messages = [ + { + "content": "\n\nSystem: You are a helpful, precise, and methodical AI" + " assistant that uses tool outputs provided inline.\nAlways" + " assume the current datetime is 2026-01-29T13:59:09.238901" + "+00:00.\n\nIf you receive a ToolMessage with `tool_call_id" + '` equal to "get_time_range" (or "time_range_tool"), you ' + "MUST:\n 1. Parse that JSON and use the values `start` and" + " `end` directly when calling other tools.\n 2. Do not " + "re-call or re-compute the time range.\n 3. Pass resolved " + "values (ISO strings) as arguments to any subsequent tool " + "(do not pass function metadata or placeholders).\n 4. If " + "a tool requires datetime objects rather than strings, " + "convert the ISO strings into language-native datetime " + "objects before invoking.\n\nAlways return fully resolved " + "arguments in correct types (e.g., ISO datetime strings or" + " datetime objects) and never include placeholders like " + '"".\n\n', + "role": "system", + }, + { + "content": "What are the transaction IDs that ran in the" + " ACME region A9345 over the last two months?", + "role": "user", + }, + { + "content": '["2026-01-26T09: 51: 55.467722Z", "2026-01-27T09: 51: 55.467722Z"]', + "role": "tool", + "tool_call_id": "time_range_tool", + }, +] +messages2 = [{"role": "user", "content": "What's stock price for IBM?"}] + +messages3 = [{"role": "user", "content": "What's the current weather in New York?"}] + + +def get_args(client: openai.OpenAI, _tools, _messages, _stop): + response = client.chat.completions.create( + model=MODEL, + messages=_messages, + temperature=0, + tools=_tools, + max_tokens=200, + stop=_stop, + tool_choice="auto", + ) + + return 
response.choices[0].message.tool_calls[0].function.arguments + + +async def get_args_streaming( + async_client: openai.AsyncOpenAI, _tools, _messages, _stop +): + stream = await async_client.chat.completions.create( + model=MODEL, + messages=_messages, + temperature=0, + tools=_tools, + max_tokens=200, + stop=_stop, + tool_choice="auto", + stream=True, + ) + full_call = [] + async for chunk in stream: + tc = chunk.choices[0].delta.tool_calls + if tc and tc[0].function.arguments: + full_call.append(tc[0].function.arguments) + return "".join(full_call) + + +async def run_scenario(server: RemoteOpenAIServer, _tools, _messages, _stop): + non_streaming = get_args(server.get_client(), _tools, _messages, _stop) + json.loads(non_streaming) # verify that it is json loadable + streaming = await get_args_streaming( + server.get_async_client(), _tools, _messages, _stop + ) + json.loads(streaming) + assert non_streaming == streaming, f"{non_streaming=}, {streaming=}" + + +@pytest.mark.asyncio +async def test_stop_sequence_interference(server: RemoteOpenAIServer): + print("Testing scenario 1") + await run_scenario(server, tools, messages, "veroniqueprattyushveroniqueprattyush") + + print("Testing scenario 2") + await run_scenario( + server, tools2, messages2, "veroniqueprattyushveroniqueprattyush" + ) + + print("Testing scenario 3") + await run_scenario(server, tools2, messages3, "prattyush") diff --git a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py index 626d845e1..be910fbb1 100644 --- a/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py +++ b/tests/entrypoints/openai/tool_parsers/test_hermes_tool_parser.py @@ -3,29 +3,22 @@ import json +import openai import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download +from typing_extensions import TypedDict from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.tokenizers 
import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ToolParser +from vllm.tool_parsers.granite4_tool_parser import Granite4ToolParser from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from ....utils import RemoteOpenAIServer -MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" LORA_MODEL = "minpeter/LoRA-Llama-3.2-1B-tool-vllm-ci" -SERVER_ARGS = [ - "--enforce-eager", - "--enable-auto-tool-choice", - "--tool-call-parser", - "hermes", - "--enable-lora", - "--lora-modules", - f"{LORA_MODEL}={LORA_MODEL}", - "--tokenizer", - f"{LORA_MODEL}", -] - TOOLS = [ { "type": "function", @@ -50,6 +43,75 @@ TOOLS = [ } ] + +class ServerConfig(TypedDict, total=False): + model: str + arguments: list[str] + model_arg: str + tool_parser: ToolParser + + +CONFIGS: dict[str, ServerConfig] = { + "llama": { + "model": "meta-llama/Llama-3.2-1B-Instruct", + "arguments": [ + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "hermes", + "--enable-lora", + "--lora-modules", + f"{LORA_MODEL}={LORA_MODEL}", + "--tokenizer", + f"{LORA_MODEL}", + ], + "model_arg": LORA_MODEL, + "tool_parser": Hermes2ProToolParser, + }, + "granite4": { + "model": "ibm-granite/granite-4.0-h-tiny", + "arguments": [ + "--enforce-eager", + "--enable-auto-tool-choice", + "--tool-call-parser", + "granite4", + "--tokenizer", + "ibm-granite/granite-4.0-h-tiny", + "--max-model-len", + "4096", + "--max-num-seqs", + "2", + ], + "model_arg": "ibm-granite/granite-4.0-h-tiny", + "tool_parser": Granite4ToolParser, + }, +} + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +@pytest.fixture(scope="module") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = 
server_config["arguments"] + with RemoteOpenAIServer(model, args_for_model, max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client + + PRODUCT_TOOLS = [ { "type": "function", @@ -87,186 +149,182 @@ PRODUCT_MESSAGES = [ @pytest.mark.asyncio -async def test_non_streaming_tool_call(): +async def test_non_streaming_tool_call( + client: openai.AsyncOpenAI, server_config: ServerConfig +): """Test tool call in non-streaming mode.""" - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: - client = server.get_async_client() - response = await client.chat.completions.create( - model=LORA_MODEL, - messages=MESSAGES, - tools=TOOLS, - tool_choice="auto", - temperature=0.0, - ) + response = await client.chat.completions.create( + model=server_config["model_arg"], + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + ) - assert response.choices - choice = response.choices[0] - message = choice.message + assert response.choices + choice = response.choices[0] + message = choice.message - assert choice.finish_reason == "tool_calls" - assert message.tool_calls is not None + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None - tool_call = message.tool_calls[0] - assert tool_call.type == "function" - assert tool_call.function.name == "get_current_weather" + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + assert tool_call.function.name == "get_current_weather" - arguments = json.loads(tool_call.function.arguments) - assert "location" in arguments - assert "Boston" in arguments["location"] - print("\n[Non-Streaming Test Passed]") - print(f"Tool Call: {tool_call.function.name}") - print(f"Arguments: {arguments}") + arguments = json.loads(tool_call.function.arguments) + assert "location" in arguments + assert "Boston" in arguments["location"] + 
print("\n[Non-Streaming Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") @pytest.mark.asyncio -async def test_streaming_tool_call(): +async def test_streaming_tool_call( + client: openai.AsyncOpenAI, server_config: ServerConfig +): """Test tool call in streaming mode.""" - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: - client = server.get_async_client() - stream = await client.chat.completions.create( - model=LORA_MODEL, - messages=MESSAGES, - tools=TOOLS, - tool_choice="auto", - temperature=0.0, - stream=True, - ) + stream = await client.chat.completions.create( + model=server_config["model_arg"], + messages=MESSAGES, + tools=TOOLS, + tool_choice="auto", + temperature=0.0, + stream=True, + ) - tool_call_chunks = {} - async for chunk in stream: - if not chunk.choices: - continue + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue - delta = chunk.choices[0].delta - if not delta or not delta.tool_calls: - continue + delta = chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue - for tool_chunk in delta.tool_calls: - index = tool_chunk.index - if index not in tool_call_chunks: - tool_call_chunks[index] = {"name": "", "arguments": ""} + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} - if tool_chunk.function.name: - tool_call_chunks[index]["name"] += tool_chunk.function.name - if tool_chunk.function.arguments: - tool_call_chunks[index]["arguments"] += ( - tool_chunk.function.arguments - ) + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments - assert len(tool_call_chunks) == 1 - reconstructed_tool_call = tool_call_chunks[0] + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = 
tool_call_chunks[0] - assert reconstructed_tool_call["name"] == "get_current_weather" + assert reconstructed_tool_call["name"] == "get_current_weather" - arguments = json.loads(reconstructed_tool_call["arguments"]) - assert "location" in arguments - assert "Boston" in arguments["location"] - print("\n[Streaming Test Passed]") - print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") - print(f"Reconstructed Arguments: {arguments}") + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "location" in arguments + assert "Boston" in arguments["location"] + print("\n[Streaming Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") @pytest.mark.asyncio -async def test_non_streaming_product_tool_call(): +async def test_non_streaming_product_tool_call( + client: openai.AsyncOpenAI, server_config: ServerConfig +): """Test tool call integer and boolean parameters in non-streaming mode.""" - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: - client = server.get_async_client() - response = await client.chat.completions.create( - model=LORA_MODEL, - messages=PRODUCT_MESSAGES, - tools=PRODUCT_TOOLS, - tool_choice="auto", - temperature=0.66, - ) + response = await client.chat.completions.create( + model=server_config["model_arg"], + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + ) - assert response.choices - choice = response.choices[0] - message = choice.message + assert response.choices + choice = response.choices[0] + message = choice.message - assert choice.finish_reason == "tool_calls" - assert message.tool_calls is not None + assert choice.finish_reason == "tool_calls" + assert message.tool_calls is not None - tool_call = message.tool_calls[0] - assert tool_call.type == "function" - assert tool_call.function.name == "get_product_info" + tool_call = message.tool_calls[0] + assert tool_call.type == "function" + 
assert tool_call.function.name == "get_product_info" - arguments = json.loads(tool_call.function.arguments) - assert "product_id" in arguments - assert "inserted" in arguments + arguments = json.loads(tool_call.function.arguments) + assert "product_id" in arguments + assert "inserted" in arguments - product_id = arguments.get("product_id") - inserted = arguments.get("inserted") + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") - assert isinstance(product_id, int) - assert product_id == 7355608 - assert isinstance(inserted, bool) - assert inserted is True + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True - print("\n[Non-Streaming Product Test Passed]") - print(f"Tool Call: {tool_call.function.name}") - print(f"Arguments: {arguments}") + print("\n[Non-Streaming Product Test Passed]") + print(f"Tool Call: {tool_call.function.name}") + print(f"Arguments: {arguments}") @pytest.mark.asyncio -async def test_streaming_product_tool_call(): +async def test_streaming_product_tool_call( + client: openai.AsyncOpenAI, server_config: ServerConfig +): """Test tool call integer and boolean parameters in streaming mode.""" - with RemoteOpenAIServer(MODEL_NAME, SERVER_ARGS) as server: - client = server.get_async_client() - stream = await client.chat.completions.create( - model=LORA_MODEL, - messages=PRODUCT_MESSAGES, - tools=PRODUCT_TOOLS, - tool_choice="auto", - temperature=0.66, - stream=True, - ) + stream = await client.chat.completions.create( + model=server_config["model_arg"], + messages=PRODUCT_MESSAGES, + tools=PRODUCT_TOOLS, + tool_choice="auto", + temperature=0.66, + stream=True, + ) - tool_call_chunks = {} - async for chunk in stream: - if not chunk.choices: - continue + tool_call_chunks = {} + async for chunk in stream: + if not chunk.choices: + continue - delta = chunk.choices[0].delta - if not delta or not delta.tool_calls: - continue + delta = 
chunk.choices[0].delta + if not delta or not delta.tool_calls: + continue - for tool_chunk in delta.tool_calls: - index = tool_chunk.index - if index not in tool_call_chunks: - tool_call_chunks[index] = {"name": "", "arguments": ""} + for tool_chunk in delta.tool_calls: + index = tool_chunk.index + if index not in tool_call_chunks: + tool_call_chunks[index] = {"name": "", "arguments": ""} - if tool_chunk.function.name: - tool_call_chunks[index]["name"] += tool_chunk.function.name - if tool_chunk.function.arguments: - tool_call_chunks[index]["arguments"] += ( - tool_chunk.function.arguments - ) + if tool_chunk.function.name: + tool_call_chunks[index]["name"] += tool_chunk.function.name + if tool_chunk.function.arguments: + tool_call_chunks[index]["arguments"] += tool_chunk.function.arguments - assert len(tool_call_chunks) == 1 - reconstructed_tool_call = tool_call_chunks[0] + assert len(tool_call_chunks) == 1 + reconstructed_tool_call = tool_call_chunks[0] - assert reconstructed_tool_call["name"] == "get_product_info" + assert reconstructed_tool_call["name"] == "get_product_info" - arguments = json.loads(reconstructed_tool_call["arguments"]) - assert "product_id" in arguments - assert "inserted" in arguments + arguments = json.loads(reconstructed_tool_call["arguments"]) + assert "product_id" in arguments + assert "inserted" in arguments - # Handle type coercion for streaming test as well - product_id = arguments.get("product_id") - inserted = arguments.get("inserted") + # Handle type coercion for streaming test as well + product_id = arguments.get("product_id") + inserted = arguments.get("inserted") - assert isinstance(product_id, int) - assert product_id == 7355608 - assert isinstance(inserted, bool) - assert inserted is True + assert isinstance(product_id, int) + assert product_id == 7355608 + assert isinstance(inserted, bool) + assert inserted is True - print("\n[Streaming Product Test Passed]") - print(f"Reconstructed Tool Call: 
{reconstructed_tool_call['name']}") - print(f"Reconstructed Arguments: {arguments}") + print("\n[Streaming Product Test Passed]") + print(f"Reconstructed Tool Call: {reconstructed_tool_call['name']}") + print(f"Reconstructed Arguments: {arguments}") @pytest.fixture @@ -276,9 +334,10 @@ def qwen_tokenizer() -> TokenizerLike: return get_tokenizer("Qwen/Qwen3-32B") -@pytest.fixture -def hermes_parser(qwen_tokenizer: TokenizerLike) -> Hermes2ProToolParser: - return Hermes2ProToolParser(qwen_tokenizer) +@pytest.fixture(params=CONFIGS.keys()) +def hermes_parser(request, qwen_tokenizer: TokenizerLike) -> ToolParser: + config = CONFIGS[request.param] + return config["tool_parser"](qwen_tokenizer) @pytest.fixture @@ -292,7 +351,7 @@ def any_chat_request() -> ChatCompletionRequest: def test_hermes_parser_streaming_just_forward_text( qwen_tokenizer: TokenizerLike, - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: text = """This is some prior text that has nothing to do with tool calling.""" @@ -324,7 +383,7 @@ def test_hermes_parser_streaming_just_forward_text( def test_hermes_parser_streaming_failure_case_bug_19056( qwen_tokenizer: TokenizerLike, - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: text = """ @@ -358,7 +417,7 @@ def test_hermes_parser_streaming_failure_case_bug_19056( def test_hermes_parser_streaming( qwen_tokenizer: TokenizerLike, - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: text = '\ @@ -387,16 +446,20 @@ def test_hermes_parser_streaming( delta_messages.append(delta) print(delta_messages) assert delta_messages[0].tool_calls[0].function.name == "get_current_temperature" - tool_call_args = "".join( - delta.tool_calls[0].function.arguments or "" for delta in delta_messages - ) - assert tool_call_args == ( - '{"location":"San Francisco, California, 
United States", "unit": "celsius"}' + # load to normalize whitespace + tool_call_args = json.loads( + "".join( + delta.tool_calls[0].function.arguments or "" for delta in delta_messages + ) ) + assert tool_call_args == { + "location": "San Francisco, California, United States", + "unit": "celsius", + } def test_hermes_parser_non_streaming_no_tool_call( - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: text = """This is not a tool call.""" @@ -410,7 +473,7 @@ def test_hermes_parser_non_streaming_no_tool_call( def test_hermes_parser_non_streaming_tool_call_between_tags( - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: text = """ @@ -428,9 +491,12 @@ def test_hermes_parser_non_streaming_tool_call_between_tags( def test_hermes_parser_non_streaming_tool_call_until_eos( - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: + if isinstance(hermes_parser, Granite4ToolParser): + pytest.skip(reason="The Granite4 tool parser enforces a complete response") + text = """ {"name": "final_answer", "arguments": {"trigger": true}}""" tool_call = hermes_parser.extract_tool_calls( @@ -445,7 +511,7 @@ def test_hermes_parser_non_streaming_tool_call_until_eos( def test_hermes_parser_non_streaming_tool_call_invalid_json( - hermes_parser: Hermes2ProToolParser, + hermes_parser: ToolParser, any_chat_request: ChatCompletionRequest, ) -> None: # Missing closing brace to trigger exception diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py index c1a39f2af..f480a635c 100644 --- a/vllm/tool_parsers/__init__.py +++ b/vllm/tool_parsers/__init__.py @@ -54,6 +54,10 @@ _TOOL_PARSERS_TO_REGISTER = { "granite_tool_parser", "GraniteToolParser", ), + "granite4": ( + "granite4_tool_parser", + "Granite4ToolParser", + ), "hermes": ( "hermes_tool_parser", "Hermes2ProToolParser", diff 
--git a/vllm/tool_parsers/granite4_tool_parser.py b/vllm/tool_parsers/granite4_tool_parser.py new file mode 100644 index 000000000..693c4dc8f --- /dev/null +++ b/vllm/tool_parsers/granite4_tool_parser.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Any, Protocol, TypeVar + +import regex as re + +from vllm.entrypoints.chat_utils import make_tool_call_id +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, +) +from vllm.entrypoints.openai.engine.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, + ToolCall, +) +from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers.abstract_tool_parser import ( + ToolParser, +) + +logger = init_logger(__name__) + + +def dump_args(args: None | dict[str, Any] | str) -> str | None: + if args is None or isinstance(args, str): + return args + else: + return json.dumps(args, ensure_ascii=False) + + +class _FunctionCallCtor(Protocol): + def __init__(self, *, name: str, arguments: str | None): ... 
FuncT = TypeVar("FuncT", bound=_FunctionCallCtor)


class Granite4ToolParser(ToolParser):
    """Tool parser for model output that wraps each call in tool-call tags.

    Every tool call is expected to be a JSON object with ``name`` and
    ``arguments`` keys placed between ``tc_start`` and ``tc_end``. Text
    outside of the tags is returned as regular content.

    The parser supports two modes:

    * ``extract_tool_calls`` parses a complete response in one go and
      falls back to returning the raw text when the response is malformed.
    * ``extract_tool_calls_streaming`` consumes arbitrary text deltas. It
      buffers incomplete delimiters / tool calls in ``look_ahead`` and only
      emits a tool call once its closing tag has been received.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Bookkeeping shared with the serving layer: one entry per tool call
        # that has been emitted so far.
        self.prev_tool_call_arr: list[dict] = []
        self.current_tool_id: int = -1
        self.streamed_args_for_tool = list[str]()

        # Streaming state: text held back until more output arrives, and
        # whether the held-back text starts inside an unfinished tool call.
        self.look_ahead = ""
        self.in_tc = False

        # NOTE(review): the delimiters must be non-empty. An empty pattern
        # matches at every position, which makes the non-streaming path
        # reject every response and the streaming path never see a tool call.
        # These are the tags used by the Hermes-style format whose tests this
        # parser shares; confirm against the model's chat template.
        self.tc_start = "<tool_call>"
        self.tc_end = "</tool_call>"
        self.start_regex = re.compile(self.tc_start)
        self.end_regex = re.compile(self.tc_end)
        # Compiled once instead of on every extract_tool_calls() invocation.
        self.delimiter_regex = re.compile(
            f"(?P<TC_START>{self.tc_start})|(?P<TC_END>{self.tc_end})"
        )

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the decoded text when tools are in use."""
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # do not skip special tokens because the tool_call tokens are
            # marked "special" in some models. Since they are skipped
            # prior to the call to the tool parser, it breaks tool calling.
            request.skip_special_tokens = False
        return request

    def _collect_results(
        self, text_segments: list[str], tc_segments: list[str], cls: type[FuncT]
    ) -> tuple[str, list[FuncT]]:
        """Turn raw segments into content text and function-call objects.

        Args:
            text_segments: Plain-text pieces, in output order.
            tc_segments: Raw JSON payloads found between the tool-call tags.
            cls: Constructor for the function-call object to build
                (called with ``name=`` and ``arguments=`` keywords).

        Returns:
            The concatenated content and one ``cls`` instance per payload.

        Raises:
            ValueError: If a payload is not valid JSON or not a JSON object.
            KeyError: If a payload lacks ``name`` or ``arguments``.
        """
        parsed_calls: list[dict[str, Any]] = []
        tool_calls: list[FuncT] = []
        for tc_text in tc_segments:
            tc = json.loads(tc_text)
            if not isinstance(tc, dict):
                raise ValueError("Expected the tool call to be a JSON object")
            tool_calls.append(
                cls(
                    name=tc["name"],
                    arguments=dump_args(tc["arguments"]),
                )
            )
            parsed_calls.append(tc)

        # Commit to the shared bookkeeping only after every payload was
        # accepted, so that a malformed call cannot leave prev_tool_call_arr
        # out of sync with current_tool_id / streamed_args_for_tool.
        self.prev_tool_call_arr.extend(parsed_calls)
        return "".join(text_segments), tool_calls

    def _split_segments(self, model_output: str) -> tuple[list[str], list[str]]:
        """Split a complete response into text pieces and tool-call payloads.

        Raises:
            ValueError: If the tags are unbalanced, nested, or enclose no text.
        """
        text_segments = list[str]()
        tc_segments = list[str]()
        # Local on purpose: a failed parse must not leak an "inside a tool
        # call" state into later calls on the same parser instance.
        inside_call = False
        last_cut_loc = 0

        for match in self.delimiter_regex.finditer(model_output):
            match_type = match.lastgroup
            if match_type == "TC_START":
                if inside_call:
                    raise ValueError("Two tool call start tokens found in a row")
                if preceding_text := model_output[last_cut_loc : match.start()]:
                    text_segments.append(preceding_text)
                inside_call = True
            elif match_type == "TC_END":
                if not inside_call:
                    raise ValueError(
                        "Tool call end token found without corresponding start token"
                    )
                tool_text = model_output[last_cut_loc : match.start()]
                if not tool_text:
                    raise ValueError(
                        "Expected the model to generate text between tool call tokens"
                    )
                tc_segments.append(tool_text)
                inside_call = False
            else:
                raise ValueError("Unexpected match")
            last_cut_loc = match.end()

        if inside_call:
            raise ValueError("The model generated an incomplete tool call")
        if final_text := model_output[last_cut_loc:]:
            text_segments.append(final_text)
        return text_segments, tc_segments

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streaming) response.

        Returns:
            The extracted tool calls plus the remaining text as content. If
            the response cannot be parsed (unbalanced tags, invalid JSON,
            missing keys) the error is logged and the whole ``model_output``
            is returned as content with ``tools_called=False``.
        """
        msg = ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )
        try:
            text_segments, tc_segments = self._split_segments(model_output)
            content, tool_call_funcs = self._collect_results(
                text_segments, tc_segments, FunctionCall
            )
            tool_calls = [
                ToolCall(
                    type="function",
                    function=func,
                )
                for func in tool_call_funcs
            ]
            msg.tools_called = bool(tool_calls)
            msg.tool_calls = tool_calls
            msg.content = content or None
        except Exception:
            # Best effort: hand the untouched text back to the caller.
            logger.exception("Error in extracting tool call from response.")
        return msg

    def _tool_extraction_step(
        self,
        delta_text: str,
    ) -> tuple[bool, str, str]:
        """Consume as much of the buffered text as can be decided right now.

        Args:
            delta_text: Previously held-back text followed by the new delta.

        Returns:
            ``(done, content, tc_text)`` where ``content`` is plain text that
            can be emitted, ``tc_text`` is the payload of one complete tool
            call (or ``""``), and ``done`` is ``False`` when ``look_ahead``
            still holds text that should be processed by another step.
        """
        start_token_pos = start_token_end = -1
        partial_start_pos = -1

        # partial=True also reports a start tag that is cut off by the end of
        # the buffer, so that its first characters are not emitted as content.
        if start_match := self.start_regex.search(delta_text, partial=True):
            if not start_match.partial:
                start_token_pos, start_token_end = start_match.span()
            elif start_match.end() > start_match.start():
                partial_start_pos = start_match.start()

        if start_token_pos < 0:
            # just streaming text so far
            if partial_start_pos >= 0:
                # Hold back the possible beginning of a start tag.
                self.look_ahead = delta_text[partial_start_pos:]
                return True, delta_text[:partial_start_pos], ""
            return True, delta_text, ""

        # Text in front of the start tag is regular content. While a tool
        # call is pending the buffer begins with the start tag, so this is
        # empty in that case.
        content = delta_text[:start_token_pos]

        # Only an end tag located after the start tag can close this call.
        end_match = self.end_regex.search(delta_text, start_token_end)
        if end_match is None:
            # The call is still incomplete: keep everything from the start
            # tag onwards and wait for more output from the model.
            self.in_tc = True
            self.look_ahead = delta_text[start_token_pos:]
            return True, content, ""

        # A complete tool call is available.
        end_token_pos, end_token_end = end_match.span()
        self.in_tc = False
        self.look_ahead = delta_text[end_token_end:]
        # done=False: there could be more content already buffered
        return False, content, delta_text[start_token_end:end_token_pos]

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Process one streamed text delta.

        Only ``delta_text`` is used; the parser keeps its own buffer in
        ``look_ahead``. A tool call is emitted in a single delta (name and
        full arguments) once its closing tag has arrived.

        Returns:
            A ``DeltaMessage`` with the content and/or tool calls that became
            available, or ``None`` if nothing can be emitted yet or an error
            occurred (the error is logged).
        """
        try:
            text_segments = list[str]()
            tc_segments = list[str]()

            pending = delta_text
            done = False
            while not done:
                buffer = self.look_ahead + pending
                self.look_ahead = ""
                pending = ""
                done, content, tc_text = self._tool_extraction_step(buffer)
                if content:
                    text_segments.append(content)
                if tc_text:
                    tc_segments.append(tc_text)

            content, tool_call_funcs = self._collect_results(
                text_segments, tc_segments, DeltaFunctionCall
            )

            delta_tool_calls = list[DeltaToolCall]()
            for function in tool_call_funcs:
                self.current_tool_id += 1
                delta_tool_calls.append(
                    DeltaToolCall(
                        id=make_tool_call_id(),
                        type="function",
                        index=self.current_tool_id,
                        function=function.model_dump(exclude_none=True),
                    )
                )
                self.streamed_args_for_tool.append(function.arguments or "")

            # Internal invariants of the per-call bookkeeping.
            assert self.current_tool_id + 1 == len(self.prev_tool_call_arr)
            assert self.current_tool_id + 1 == len(self.streamed_args_for_tool)

            msg = DeltaMessage(content=content or None, tool_calls=delta_tool_calls)
            if msg.content or msg.tool_calls:
                return msg

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
        return None