[Bugfix] Add support for <tool_call> format in streaming mode for XLAM Tool Parser (#22769)

Signed-off-by: Devon Peroutky <devon@kindo.ai>
This commit is contained in:
Code Jesus
2025-08-31 23:07:54 -07:00
committed by GitHub
parent 1cb39dbcdd
commit 422e793fa6
2 changed files with 296 additions and 24 deletions

View File

@@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
assert actual_tool_call.function == expected_tool_call.function
def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: AnyTokenizer,
    model_output: str,
    request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` through the streaming parser token by token.

    The text is encoded once (without special tokens) and then fed to
    ``extract_tool_calls_streaming`` one token at a time, with the text
    for each step produced by ``detokenize_incrementally``.

    Args:
        xlam_tool_parser: Parser whose streaming extraction is exercised.
        xlam_tokenizer: Tokenizer used to encode and incrementally decode.
        model_output: Full model output to replay.
        request: Optional request forwarded to the parser unchanged.

    Yields:
        Every truthy delta message returned by the parser, in order.
    """
    token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)

    text_so_far = ""
    decoded_tokens = None
    prefix_offset = 0
    read_offset = 0

    for position, token_id in enumerate(token_ids):
        ids_before = token_ids[:position]
        ids_through = token_ids[:position + 1]

        step = detokenize_incrementally(
            tokenizer=xlam_tokenizer,
            all_input_ids=ids_through,
            prev_tokens=decoded_tokens,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=False,
            spaces_between_special_tokens=True,
        )
        step_tokens, step_text, next_prefix_offset, next_read_offset = step

        text_with_step = text_so_far + step_text

        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            text_with_step,
            step_text,
            ids_before,
            ids_through,
            [token_id],
            request=request,
        )
        if delta_message:
            yield delta_message

        # Carry the decoding state forward for the next token.
        text_so_far = text_with_step
        if decoded_tokens:
            decoded_tokens = decoded_tokens + step_tokens
        else:
            decoded_tokens = step_tokens
        prefix_offset = next_prefix_offset
        read_offset = next_read_offset
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
],
"I'll check the weather for you.",
),
(
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll help you check the weather.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
@@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize(
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        pytest.param(
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
            id="parallel_tool_calls",
        ),
        pytest.param(
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "<think>I'll help you with that.</think>",
            id="single_tool_with_think_tag",
        ),
        pytest.param(
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
            id="single_tool_with_json_code_block",
        ),
        pytest.param(
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
            id="single_tool_with_tool_calls_tag",
        ),
        pytest.param(
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "I can help with that.",
            id="single_tool_with_tool_call_xml_tags",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Check the chunk-by-chunk output of the XLAM parser in streaming mode.

    The model output is replayed one token at a time. The test then checks
    that a header chunk (id, name, type) is emitted for the first tool call
    and that its arguments arrive over several chunks which concatenate to
    the expected JSON.
    """
    # NOTE(review): expected_content is parametrized but never asserted in
    # this test; only the first expected tool call is verified.
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    chunks = list(
        stream_delta_message_generator(xlam_tool_parser, xlam_tokenizer,
                                       model_output, request))

    # Streaming must produce several chunks rather than a single blob.
    assert len(chunks) >= 3

    expected_first_tool = expected_tool_calls[0]

    # The first chunk carrying a tool call id is the header chunk.
    header_chunk = next(
        (chunk for chunk in chunks
         if chunk.tool_calls and chunk.tool_calls[0].id),
        None,
    )
    assert header_chunk is not None

    header_call = header_chunk.tool_calls[0]
    assert header_call.function.name == expected_first_tool.function.name
    assert header_call.type == "function"
    # The header may omit arguments, or carry an empty string.
    initial_arguments = header_call.function.arguments
    assert initial_arguments is None or initial_arguments == ""

    # Gather the non-empty argument fragments of the first tool call only
    # (the truthiness check already excludes None and "").
    arg_chunks = [
        chunk.tool_calls[0].function.arguments for chunk in chunks
        if chunk.tool_calls and chunk.tool_calls[0].function.arguments
        and chunk.tool_calls[0].index == 0
    ]

    # Arguments must arrive in more than one piece.
    assert len(arg_chunks) > 1

    # Joined together, the fragments must decode to the expected arguments.
    streamed_args = json.loads("".join(arg_chunks))
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert streamed_args == expected_args