[Bugfix] Preserve tool call id/type/name in streaming finish chunk (#31438)
Signed-off-by: amittell <mittell@me.com> Signed-off-by: Alex Mittell <mittell@me.com>
This commit is contained in:
@@ -1506,3 +1506,142 @@ async def test_tool_choice_validation_without_parser():
|
||||
assert isinstance(response_named, ErrorResponse)
|
||||
assert "tool_choice" in response_named.error.message
|
||||
assert "--tool-call-parser" in response_named.error.message
|
||||
|
||||
|
||||
class TestCreateRemainingArgsDelta:
|
||||
"""Tests for _create_remaining_args_delta helper function.
|
||||
|
||||
This helper is used when streaming tool calls to preserve id/type/name
|
||||
fields in the finish chunk, which would otherwise be lost.
|
||||
"""
|
||||
|
||||
def test_preserves_id_type_name(self):
|
||||
"""Test that id, type, and name are preserved from original delta."""
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_abc123",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(
|
||||
name="get_weather",
|
||||
arguments='{"location": "Paris"}',
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '", "unit": "celsius"}', 0
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 0
|
||||
assert tc.id == "call_abc123"
|
||||
assert tc.type == "function"
|
||||
assert tc.function.name == "get_weather"
|
||||
assert tc.function.arguments == '", "unit": "celsius"}'
|
||||
|
||||
def test_matches_by_index(self):
|
||||
"""Test that the correct tool call is matched by index."""
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_first",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func_a", arguments="{}"),
|
||||
),
|
||||
DeltaToolCall(
|
||||
index=1,
|
||||
id="call_second",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func_b", arguments="{}"),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"extra": true}', 1
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 1
|
||||
assert tc.id == "call_second"
|
||||
assert tc.function.name == "func_b"
|
||||
|
||||
def test_no_matching_tool_call(self):
|
||||
"""Test graceful handling when no matching tool call is found."""
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_zero",
|
||||
type="function",
|
||||
function=DeltaFunctionCall(name="func", arguments="{}"),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"arg": 1}', 5
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 5
|
||||
assert tc.id is None
|
||||
assert tc.type is None
|
||||
assert tc.function.name is None
|
||||
assert tc.function.arguments == '{"arg": 1}'
|
||||
|
||||
def test_function_is_none(self):
|
||||
"""Test handling when original tool call has no function."""
|
||||
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
original_delta = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=0,
|
||||
id="call_nofunc",
|
||||
type="function",
|
||||
function=None,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
result = OpenAIServingChat._create_remaining_args_delta(
|
||||
original_delta, '{"data": "value"}', 0
|
||||
)
|
||||
|
||||
assert len(result.tool_calls) == 1
|
||||
tc = result.tool_calls[0]
|
||||
assert tc.index == 0
|
||||
assert tc.id == "call_nofunc"
|
||||
assert tc.type == "function"
|
||||
assert tc.function.name is None
|
||||
assert tc.function.arguments == '{"data": "value"}'
|
||||
|
||||
@@ -1208,15 +1208,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# check to see if there's anything left to stream
|
||||
remaining_call = expected_call.replace(actual_call, "", 1)
|
||||
# set that as a delta message
|
||||
delta_message = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=index,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=remaining_call
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
]
|
||||
delta_message = self._create_remaining_args_delta(
|
||||
delta_message, remaining_call, index
|
||||
)
|
||||
|
||||
# Send the finish response for each request.n only once
|
||||
@@ -1803,6 +1796,35 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_remaining_args_delta(
|
||||
delta_message: DeltaMessage,
|
||||
remaining_call: str,
|
||||
index: int,
|
||||
) -> DeltaMessage:
|
||||
"""
|
||||
Create a delta message for remaining tool arguments, preserving
|
||||
id/type/name from the original delta.
|
||||
"""
|
||||
original_tc = next(
|
||||
(tc for tc in delta_message.tool_calls if tc.index == index),
|
||||
None,
|
||||
)
|
||||
original_fn = original_tc.function if original_tc else None
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=index,
|
||||
id=original_tc.id if original_tc else None,
|
||||
type=original_tc.type if original_tc else None,
|
||||
function=DeltaFunctionCall(
|
||||
name=original_fn.name if original_fn else None,
|
||||
arguments=remaining_call,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def _make_request_with_harmony(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
|
||||
Reference in New Issue
Block a user