[MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding (#26291)
Signed-off-by: Aleksandr Samarin <astrlrd@nebius.com> Signed-off-by: southfreebird <yvorott@gmail.com> Co-authored-by: southfreebird <yvorott@gmail.com>
This commit is contained in:
committed by
GitHub
parent
3a612322eb
commit
d084e9fca7
@@ -35,6 +35,7 @@ from .utils import (
|
||||
)
|
||||
|
||||
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
GPT_OSS_SPECULATOR_NAME = "RedHatAI/gpt-oss-20b-speculator.eagle3"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -66,7 +67,8 @@ def exclude_tools_when_tool_choice_none(request) -> bool:
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args(
|
||||
with_tool_parser: bool, exclude_tools_when_tool_choice_none: bool
|
||||
with_tool_parser: bool,
|
||||
exclude_tools_when_tool_choice_none: bool,
|
||||
):
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
@@ -76,7 +78,7 @@ def default_server_args(
|
||||
"--reasoning-parser",
|
||||
"openai_gptoss",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
"0.85",
|
||||
]
|
||||
if with_tool_parser:
|
||||
args.extend(
|
||||
@@ -91,327 +93,385 @@ def default_server_args(
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture(scope="class")
|
||||
def gptoss_server(default_server_args: list[str]):
|
||||
server_args = default_server_args + ["--attention-backend=TRITON_ATTN"]
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="class")
def gptoss_speculative_server(default_server_args: list[str]):
    """Start a gpt-oss server with a speculative-decoding configuration.

    The launch arguments are ``default_server_args`` followed by a
    ``--speculative-config`` JSON payload (speculator model name, method
    ``eagle3``, 3 speculative tokens) and ``--attention-backend=TRITON_ATTN``.
    The running ``RemoteOpenAIServer`` is yielded and torn down when the
    context manager exits. The fixture is class-scoped.
    """
    # Doubled braces are literal JSON braces inside the f-strings.
    speculative_config = (
        f'{{"model": "{GPT_OSS_SPECULATOR_NAME}", '
        f'"method": "eagle3", "num_speculative_tokens": 3}}'
    )
    launch_args = [
        *default_server_args,
        "--speculative-config",
        speculative_config,
        "--attention-backend=TRITON_ATTN",
    ]
    with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, launch_args) as server:
        yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from the ``gptoss_server`` fixture.

    The client is created via ``gptoss_server.get_async_client()`` and is
    released when the ``async with`` block exits after the test finishes.
    """
    client_context = gptoss_server.get_async_client()
    async with client_context as client:
        yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_chat_tool_call_streaming(
|
||||
gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
@pytest_asyncio.fixture
async def gptoss_speculative_client(gptoss_speculative_server):
    """Yield an async client obtained from ``gptoss_speculative_server``.

    The client is created via
    ``gptoss_speculative_server.get_async_client()`` and is released when the
    ``async with`` block exits after the test finishes.
    """
    client_context = gptoss_speculative_server.get_async_client()
    async with client_context as client:
        yield client
|
||||
|
||||
|
||||
class TestGPTOSSChat:
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_chat_tool_call_streaming(
|
||||
self, gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "What is the weather in Dallas, TX?"},
|
||||
]
|
||||
messages = [
|
||||
{"role": "user", "content": "What is the weather in Dallas, TX?"},
|
||||
]
|
||||
|
||||
stream = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools if with_tool_parser else None,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
name = None
|
||||
args_buf = ""
|
||||
content_buf = ""
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.tool_calls:
|
||||
tc = delta.tool_calls[0]
|
||||
if tc.function and tc.function.name:
|
||||
name = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
args_buf += tc.function.arguments
|
||||
if getattr(delta, "content", None):
|
||||
content_buf += delta.content
|
||||
if with_tool_parser:
|
||||
assert name is not None
|
||||
assert len(args_buf) > 0
|
||||
else:
|
||||
assert name is None
|
||||
assert len(args_buf) == 0
|
||||
assert len(content_buf) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI, with_tool_parser: bool):
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for multi-turn tests")
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{"role": "user", "content": "What is the weather in Dallas, TX with celsius?"},
|
||||
]
|
||||
|
||||
first = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
first_msg = first.choices[0].message
|
||||
assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
|
||||
tc = first_msg.tool_calls[0]
|
||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||
args1 = tc.function.arguments
|
||||
assert args1 is not None and len(args1) > 0
|
||||
assert not first_msg.content
|
||||
|
||||
messages.append({"role": "assistant", "content": args1})
|
||||
messages.append(
|
||||
{"role": "user", "content": "Now convert to celsius and return JSON only"}
|
||||
)
|
||||
|
||||
second = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
second_msg = second.choices[0].message
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or (
|
||||
second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_message_array_content(
|
||||
gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
"""Test that tool messages support both string and array content formats."""
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for array content tests")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test 1: Tool message with string content
|
||||
messages_string = [
|
||||
{"role": "user", "content": "What's the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"},
|
||||
]
|
||||
|
||||
response_string = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_string,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_string is not None
|
||||
assert response_string.choices[0].message is not None
|
||||
|
||||
# Test 2: Tool message with array content
|
||||
messages_array = [
|
||||
{"role": "user", "content": "What's the weather in Dallas?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_456",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Dallas", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response_array = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_array,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_array is not None
|
||||
assert response_array.choices[0].message is not None
|
||||
|
||||
# Test 3: Tool message with multiple array content items
|
||||
messages_multi_array = [
|
||||
{"role": "user", "content": "Search for information"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_789",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Austin", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "Weather data: "},
|
||||
{"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
|
||||
{"type": "text", "text": " with 60% humidity"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response_multi_array = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_multi_array,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_multi_array is not None
|
||||
assert response_multi_array.choices[0].message is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_choice_none(
|
||||
gptoss_client: OpenAI,
|
||||
with_tool_parser: bool,
|
||||
exclude_tools_when_tool_choice_none: bool,
|
||||
):
|
||||
if not (with_tool_parser and exclude_tools_when_tool_choice_none):
|
||||
pytest.skip(
|
||||
"skip tool_choice tests when non-tool or "
|
||||
"--exclude-tools-when-tool-choice-none not set"
|
||||
stream = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools if with_tool_parser else None,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
name = None
|
||||
args_buf = ""
|
||||
content_buf = ""
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.tool_calls:
|
||||
tc = delta.tool_calls[0]
|
||||
if tc.function and tc.function.name:
|
||||
name = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
args_buf += tc.function.arguments
|
||||
if getattr(delta, "content", None):
|
||||
content_buf += delta.content
|
||||
if with_tool_parser:
|
||||
assert name is not None
|
||||
assert len(args_buf) > 0
|
||||
else:
|
||||
assert name is None
|
||||
assert len(args_buf) == 0
|
||||
assert len(content_buf) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_multi_turn_chat(
|
||||
self, gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for multi-turn tests")
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "you are a helpful assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is the weather in Dallas, TX with celsius?",
|
||||
},
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the temperature(in degrees Celsius) in Dallas?",
|
||||
},
|
||||
]
|
||||
first = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
first_msg = first.choices[0].message
|
||||
assert first_msg.tool_calls is not None and len(first_msg.tool_calls) > 0
|
||||
tc = first_msg.tool_calls[0]
|
||||
assert tc.function is not None and tc.function.name == "get_current_weather"
|
||||
args1 = tc.function.arguments
|
||||
assert args1 is not None and len(args1) > 0
|
||||
assert not first_msg.content
|
||||
|
||||
tool_choice_auto = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
msg = tool_choice_auto.choices[0].message
|
||||
assert len(msg.tool_calls) == 1
|
||||
messages.append({"role": "assistant", "content": args1})
|
||||
messages.append(
|
||||
{"role": "user", "content": "Now convert to celsius and return JSON only"}
|
||||
)
|
||||
|
||||
tool_choice_none = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="none",
|
||||
temperature=0.0,
|
||||
)
|
||||
second = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
second_msg = second.choices[0].message
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or (
|
||||
second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0
|
||||
)
|
||||
|
||||
msg = tool_choice_none.choices[0].message
|
||||
assert len(msg.tool_calls) == 0
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_message_array_content(
|
||||
self, gptoss_client: OpenAI, with_tool_parser: bool
|
||||
):
|
||||
"""Test that tool messages support both string and array content formats."""
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for array content tests")
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
},
|
||||
"required": ["city", "state"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Test 1: Tool message with string content
|
||||
messages_string = [
|
||||
{"role": "user", "content": "What's the weather in Paris?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Paris", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "content": "The weather in Paris, TX is sunny, 22°C"},
|
||||
]
|
||||
|
||||
response_string = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_string,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_string is not None
|
||||
assert response_string.choices[0].message is not None
|
||||
|
||||
# Test 2: Tool message with array content
|
||||
messages_array = [
|
||||
{"role": "user", "content": "What's the weather in Dallas?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_456",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Dallas", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "f2e897a7-2705-4337-8193-2a8f57b81618"}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response_array = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_array,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_array is not None
|
||||
assert response_array.choices[0].message is not None
|
||||
|
||||
# Test 3: Tool message with multiple array content items
|
||||
messages_multi_array = [
|
||||
{"role": "user", "content": "Search for information"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_789",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"city": "Austin", "state": "TX"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": [
|
||||
{"type": "text", "text": "Weather data: "},
|
||||
{"type": "text", "text": "Austin, TX - Partly cloudy, 25°C"},
|
||||
{"type": "text", "text": " with 60% humidity"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
response_multi_array = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages_multi_array,
|
||||
tools=tools,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
assert response_multi_array is not None
|
||||
assert response_multi_array.choices[0].message is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_tool_choice_none(
|
||||
self,
|
||||
gptoss_client: OpenAI,
|
||||
with_tool_parser: bool,
|
||||
exclude_tools_when_tool_choice_none: bool,
|
||||
):
|
||||
if not (with_tool_parser and exclude_tools_when_tool_choice_none):
|
||||
pytest.skip(
|
||||
"skip tool_choice tests when non-tool or "
|
||||
"--exclude-tools-when-tool-choice-none not set"
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {"type": "string"},
|
||||
"state": {"type": "string"},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
},
|
||||
},
|
||||
"required": ["city", "state", "unit"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the temperature(in degrees Celsius) in Dallas?",
|
||||
},
|
||||
]
|
||||
|
||||
tool_choice_auto = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
temperature=0.0,
|
||||
)
|
||||
msg = tool_choice_auto.choices[0].message
|
||||
assert len(msg.tool_calls) == 1
|
||||
|
||||
tool_choice_none = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="none",
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
msg = tool_choice_none.choices[0].message
|
||||
assert len(msg.tool_calls) == 0
|
||||
|
||||
|
||||
class TestGPTOSSSpeculativeChat:
    """Streaming chat tests that run against the speculative gpt-oss server."""

    @pytest.mark.asyncio
    async def test_gpt_oss_speculative_reasoning_leakage(
        self,
        gptoss_speculative_client: OpenAI,
        with_tool_parser: bool,
    ):
        """Stream a trivial prompt and check content and reasoning separately.

        The test accumulates ``delta.content`` and ``delta.reasoning`` across
        all streamed chunks, then asserts that some reasoning was produced and
        that the final content, once stripped, is exactly ``"4"``. Any
        reasoning text that ended up in ``content`` would break the second
        assertion.
        """
        if not with_tool_parser:
            # The request uses no tools; the skip only drops the non-tool
            # parametrization. The reason text names this test rather than
            # the array-content test it was copied from.
            pytest.skip("skip non-tool for speculative reasoning tests")

        messages = [
            {"role": "user", "content": "Calculate 2+2. Return the answer 4 only."},
        ]

        stream = await gptoss_speculative_client.chat.completions.create(
            model=GPT_OSS_MODEL_NAME,
            messages=messages,
            stream=True,
            temperature=0.0,
        )

        content_parts: list[str] = []
        reasoning_parts: list[str] = []
        async for chunk in stream:
            delta = chunk.choices[0].delta
            if delta.content:
                content_parts.append(delta.content)

            # ``reasoning`` may be absent on a delta, so read it defensively
            # and reuse the fetched value instead of re-reading the attribute.
            chunk_reasoning = getattr(delta, "reasoning", None)
            if chunk_reasoning:
                reasoning_parts.append(chunk_reasoning)

        content = "".join(content_parts)
        reasoning_content = "".join(reasoning_parts)

        assert len(reasoning_content) > 0, "No reasoning was generated."
        assert content.strip() == "4"
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
|
||||
TokenState,
|
||||
extract_harmony_streaming_delta,
|
||||
)
|
||||
|
||||
@@ -42,12 +43,14 @@ class TestExtractHarmonyStreamingDelta:
|
||||
def test_final_channel_returns_content_delta(self, delta_text, expected_content):
|
||||
"""Test that final channel returns a DeltaMessage with content."""
|
||||
parser = MockStreamableParser()
|
||||
|
||||
# Updated to use TokenState list
|
||||
token_states = [TokenState(channel="final", recipient=None, text=delta_text)]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel="final",
|
||||
cur_recipient=None,
|
||||
token_states=token_states,
|
||||
prev_recipient=None,
|
||||
delta_text=delta_text,
|
||||
include_reasoning=False,
|
||||
)
|
||||
|
||||
@@ -65,18 +68,19 @@ class TestExtractHarmonyStreamingDelta:
|
||||
def test_analysis_channel_reasoning(self, include_reasoning, expected_has_message):
|
||||
"""Test analysis channel respects include_reasoning flag."""
|
||||
parser = MockStreamableParser()
|
||||
text = "Let me think..."
|
||||
token_states = [TokenState(channel="analysis", recipient=None, text=text)]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel="analysis",
|
||||
cur_recipient=None,
|
||||
token_states=token_states,
|
||||
prev_recipient=None,
|
||||
delta_text="Let me think...",
|
||||
include_reasoning=include_reasoning,
|
||||
)
|
||||
|
||||
if expected_has_message:
|
||||
assert delta_message is not None
|
||||
assert delta_message.reasoning == "Let me think..."
|
||||
assert delta_message.reasoning == text
|
||||
else:
|
||||
assert delta_message is None
|
||||
assert tools_streamed is False
|
||||
@@ -88,12 +92,14 @@ class TestExtractHarmonyStreamingDelta:
|
||||
mock_make_tool_call_id.return_value = "call_test123"
|
||||
parser = MockStreamableParser()
|
||||
|
||||
token_states = [
|
||||
TokenState(channel=channel, recipient="functions.get_weather", text="")
|
||||
]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel=channel,
|
||||
cur_recipient="functions.get_weather",
|
||||
token_states=token_states,
|
||||
prev_recipient=None,
|
||||
delta_text="",
|
||||
include_reasoning=False,
|
||||
)
|
||||
|
||||
@@ -111,20 +117,25 @@ class TestExtractHarmonyStreamingDelta:
|
||||
def test_tool_call_argument_streaming(self, channel):
|
||||
"""Test streaming tool call arguments (same recipient)."""
|
||||
parser = MockStreamableParser()
|
||||
args_text = '{"location": "Paris"}'
|
||||
|
||||
token_states = [
|
||||
TokenState(
|
||||
channel=channel, recipient="functions.get_weather", text=args_text
|
||||
)
|
||||
]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel=channel,
|
||||
cur_recipient="functions.get_weather",
|
||||
token_states=token_states,
|
||||
prev_recipient="functions.get_weather",
|
||||
delta_text='{"location": "Paris"}',
|
||||
include_reasoning=False,
|
||||
)
|
||||
|
||||
assert delta_message is not None
|
||||
tool_call = delta_message.tool_calls[0]
|
||||
assert tool_call.id is None
|
||||
assert tool_call.function.arguments == '{"location": "Paris"}'
|
||||
assert tool_call.function.arguments == args_text
|
||||
assert tool_call.index == 0
|
||||
assert tools_streamed is True
|
||||
|
||||
@@ -133,12 +144,14 @@ class TestExtractHarmonyStreamingDelta:
|
||||
"""Test empty delta_text with same recipient returns None."""
|
||||
parser = MockStreamableParser()
|
||||
|
||||
token_states = [
|
||||
TokenState(channel=channel, recipient="functions.get_weather", text="")
|
||||
]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel=channel,
|
||||
cur_recipient="functions.get_weather",
|
||||
token_states=token_states,
|
||||
prev_recipient="functions.get_weather",
|
||||
delta_text="",
|
||||
include_reasoning=False,
|
||||
)
|
||||
|
||||
@@ -154,12 +167,14 @@ class TestExtractHarmonyStreamingDelta:
|
||||
]
|
||||
parser = MockStreamableParser(messages=messages)
|
||||
|
||||
token_states = [
|
||||
TokenState(channel="commentary", recipient="functions.tool2", text="args")
|
||||
]
|
||||
|
||||
delta_message, _ = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel="commentary",
|
||||
cur_recipient="functions.tool2",
|
||||
token_states=token_states,
|
||||
prev_recipient="functions.tool2",
|
||||
delta_text="args",
|
||||
include_reasoning=False,
|
||||
)
|
||||
|
||||
@@ -173,15 +188,18 @@ class TestExtractHarmonyStreamingDelta:
|
||||
],
|
||||
)
|
||||
def test_returns_tool_call_preambles(self, channel, recipient):
|
||||
"""Test that invalid channel/recipient combinations return None."""
|
||||
"""Test that invalid tool recipient on commentary is treated as content."""
|
||||
parser = MockStreamableParser()
|
||||
delta_text = "some text"
|
||||
|
||||
token_states = [
|
||||
TokenState(channel=channel, recipient=recipient, text=delta_text)
|
||||
]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel=channel,
|
||||
cur_recipient=recipient,
|
||||
token_states=token_states,
|
||||
prev_recipient=None,
|
||||
delta_text=delta_text,
|
||||
include_reasoning=True,
|
||||
)
|
||||
|
||||
@@ -199,14 +217,140 @@ class TestExtractHarmonyStreamingDelta:
|
||||
"""Test that invalid channel/recipient combinations return None."""
|
||||
parser = MockStreamableParser()
|
||||
|
||||
token_states = [
|
||||
TokenState(channel=channel, recipient=recipient, text="some text")
|
||||
]
|
||||
|
||||
delta_message, tools_streamed = extract_harmony_streaming_delta(
|
||||
harmony_parser=parser,
|
||||
cur_channel=channel,
|
||||
cur_recipient=recipient,
|
||||
token_states=token_states,
|
||||
prev_recipient=None,
|
||||
delta_text="some text",
|
||||
include_reasoning=True,
|
||||
)
|
||||
|
||||
assert delta_message is None
|
||||
assert tools_streamed is False
|
||||
|
||||
def test_consecutive_token_grouping(self):
    """
    Feed five consecutive ``final``-channel tokens (no recipient) in one
    batch and expect a single delta whose content is their concatenation.
    """
    parser = MockStreamableParser()
    pieces = ("H", "el", "lo", ",", " World")
    token_states = [TokenState("final", None, piece) for piece in pieces]

    delta_message, _ = extract_harmony_streaming_delta(
        harmony_parser=parser,
        token_states=token_states,
        prev_recipient=None,
        include_reasoning=False,
    )

    assert delta_message is not None
    assert delta_message.content == "Hello, World"
|
||||
|
||||
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_complex_batch_permutation(self, mock_make_id):
    """
    One batch containing reasoning, then a tool call, then final content.

    All three kinds of output are expected on the single returned
    DeltaMessage.
    """
    mock_make_id.return_value = "call_batch_test"
    parser = MockStreamableParser()

    reasoning_text = "Reasoning about query..."
    argument_fragments = ('{"query":', ' "vllm"}')
    token_states = [
        # Reasoning on the analysis channel.
        TokenState("analysis", None, reasoning_text),
        # Tool-call arguments, split over two tokens for the same recipient.
        *(
            TokenState("commentary", "functions.search", fragment)
            for fragment in argument_fragments
        ),
        # Final content.
        TokenState("final", None, "."),
    ]

    delta_message, tools_streamed = extract_harmony_streaming_delta(
        harmony_parser=parser,
        token_states=token_states,
        prev_recipient=None,
        include_reasoning=True,
    )

    assert delta_message is not None

    assert delta_message.reasoning == reasoning_text

    # One logical tool call is expected as two entries: a header carrying
    # id/name, followed by the arguments payload.
    assert len(delta_message.tool_calls) == 2
    header, payload = delta_message.tool_calls

    assert header.function.name == "search"
    assert header.id == "call_batch_test"
    assert header.index == 0

    assert payload.index == 0
    assert payload.function.arguments == '{"query": "vllm"}'

    assert delta_message.content == "."
    assert tools_streamed is True
|
||||
|
||||
@patch("vllm.entrypoints.openai.chat_completion.stream_harmony.make_tool_call_id")
def test_tool_call_index_consistency_with_ongoing_call(self, mock_make_id):
    """
    Continue an in-flight tool call, then start two new ones, with final
    content interleaved between them, and check the indices stay consistent.

    The parser is seeded with one earlier tool-call message and
    ``prev_recipient`` equals the first token's recipient. The test expects
    the continued call at index 1 with no id, and the two new calls at
    indices 2 and 3.
    """
    mock_make_id.side_effect = ["id_b", "id_c"]

    earlier_messages = [
        MockMessage(channel="commentary", recipient="functions.previous_tool")
    ]
    parser = MockStreamableParser(messages=earlier_messages)

    token_states = [
        TokenState("commentary", "functions.tool_a", '{"key_a": "val_a"}'),
        TokenState("final", None, "Thinking..."),
        TokenState("commentary", "functions.tool_b", '{"key_b": "val_b"}'),
        TokenState("final", None, " Thinking again..."),
        TokenState("commentary", "functions.tool_c", '{"key_c": "val_c"}'),
    ]

    delta_message, _ = extract_harmony_streaming_delta(
        harmony_parser=parser,
        token_states=token_states,
        prev_recipient="functions.tool_a",
        include_reasoning=False,
    )

    assert delta_message is not None

    # Continuation of the ongoing call: arguments only, no new id.
    tool_a_deltas = [t for t in delta_message.tool_calls if t.index == 1]
    assert len(tool_a_deltas) > 0
    first_tool_a = tool_a_deltas[0]
    assert first_tool_a.id is None
    assert first_tool_a.function.arguments == '{"key_a": "val_a"}'

    # Each new call: a header entry with the generated id, then an
    # id-less entry at the same index carrying the arguments.
    expected_new_calls = (
        ("id_b", 2, '{"key_b": "val_b"}'),
        ("id_c", 3, '{"key_c": "val_c"}'),
    )
    for call_id, expected_index, expected_arguments in expected_new_calls:
        header = next(t for t in delta_message.tool_calls if t.id == call_id)
        assert header.index == expected_index
        arguments_delta = next(
            t
            for t in delta_message.tool_calls
            if t.index == expected_index and t.id is None
        )
        assert arguments_delta.function.arguments == expected_arguments

    assert delta_message.content == "Thinking... Thinking again..."
|
||||
|
||||
@@ -36,6 +36,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatMessage,
|
||||
)
|
||||
from vllm.entrypoints.openai.chat_completion.stream_harmony import (
|
||||
TokenState,
|
||||
extract_harmony_streaming_delta,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
@@ -826,12 +827,22 @@ class OpenAIServingChat(OpenAIServing):
|
||||
if self.use_harmony:
|
||||
harmony_parser = harmony_parsers[i]
|
||||
prev_recipient = harmony_parser.current_recipient
|
||||
delta_text = ""
|
||||
|
||||
# Track accumulated content per token with their state
|
||||
token_states: list[TokenState] = []
|
||||
for token_id in output.token_ids:
|
||||
harmony_parser.process(token_id)
|
||||
delta_text += harmony_parser.last_content_delta or ""
|
||||
token_delta = harmony_parser.last_content_delta or ""
|
||||
token_states.append(
|
||||
TokenState(
|
||||
harmony_parser.current_channel,
|
||||
harmony_parser.current_recipient,
|
||||
token_delta,
|
||||
)
|
||||
)
|
||||
delta_text = "".join(delta for _, _, delta in token_states)
|
||||
cur_channel = harmony_parser.current_channel
|
||||
cur_recipient = harmony_parser.current_recipient
|
||||
|
||||
# handle the case where several tokens were generated at once
|
||||
# including the final token, leading to a delta in the text
|
||||
# but the current channel to be empty (start state)
|
||||
@@ -869,10 +880,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message, tools_streamed_flag = (
|
||||
extract_harmony_streaming_delta(
|
||||
harmony_parser=harmony_parser,
|
||||
cur_channel=cur_channel,
|
||||
cur_recipient=cur_recipient,
|
||||
token_states=token_states,
|
||||
prev_recipient=prev_recipient,
|
||||
delta_text=delta_text,
|
||||
include_reasoning=request.include_reasoning,
|
||||
)
|
||||
)
|
||||
@@ -1139,17 +1148,23 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
# Log streaming delta if output logging is enabled
|
||||
if self.enable_log_outputs and self.request_logger:
|
||||
delta_content = ""
|
||||
delta_content_parts = []
|
||||
if delta_message.content:
|
||||
delta_content = delta_message.content
|
||||
elif delta_message.tool_calls:
|
||||
delta_content = "".join(
|
||||
delta_content_parts.append(delta_message.content)
|
||||
if delta_message.reasoning_content:
|
||||
reasoning = delta_message.reasoning_content
|
||||
delta_content_parts.append(f"[reasoning: {reasoning}]")
|
||||
if delta_message.tool_calls:
|
||||
tool_args = "".join(
|
||||
tc.function.arguments
|
||||
for tc in delta_message.tool_calls
|
||||
if tc.function and tc.function.arguments
|
||||
)
|
||||
if tool_args:
|
||||
delta_content_parts.append(f"[tool_calls: {tool_args}]")
|
||||
|
||||
if delta_content and self.enable_log_deltas:
|
||||
if delta_content_parts and self.enable_log_deltas:
|
||||
delta_content = " ".join(delta_content_parts)
|
||||
self.request_logger.log_outputs(
|
||||
request_id=request_id,
|
||||
outputs=delta_content,
|
||||
|
||||
@@ -7,6 +7,8 @@ This module handles the extraction of DeltaMessage objects from
|
||||
harmony parser state during streaming chat completions.
|
||||
"""
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
from openai_harmony import StreamableParser
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
@@ -17,12 +19,16 @@ from vllm.entrypoints.openai.engine.protocol import (
|
||||
)
|
||||
|
||||
|
||||
class TokenState(NamedTuple):
|
||||
channel: str | None
|
||||
recipient: str | None
|
||||
text: str
|
||||
|
||||
|
||||
def extract_harmony_streaming_delta(
|
||||
harmony_parser: StreamableParser,
|
||||
cur_channel: str | None,
|
||||
cur_recipient: str | None,
|
||||
token_states: list[TokenState],
|
||||
prev_recipient: str | None,
|
||||
delta_text: str,
|
||||
include_reasoning: bool,
|
||||
) -> tuple[DeltaMessage | None, bool]:
|
||||
"""
|
||||
@@ -30,38 +36,81 @@ def extract_harmony_streaming_delta(
|
||||
|
||||
Args:
|
||||
harmony_parser: The StreamableParser instance tracking parse state
|
||||
cur_channel: Current channel ("final", "analysis", "commentary", etc.)
|
||||
cur_recipient: Current recipient (e.g., "functions.my_func")
|
||||
token_states: List of TokenState tuples for each token
|
||||
prev_recipient: Previous recipient for detecting tool call transitions
|
||||
delta_text: The text delta to include in the message
|
||||
include_reasoning: Whether to include reasoning content
|
||||
|
||||
Returns:
|
||||
A tuple of (DeltaMessage or None, tools_streamed_flag)
|
||||
"""
|
||||
|
||||
if not token_states:
|
||||
return None, False
|
||||
|
||||
tools_streamed = False
|
||||
|
||||
if cur_channel == "final":
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
elif (
|
||||
(cur_channel == "commentary" or cur_channel == "analysis")
|
||||
and cur_recipient
|
||||
and cur_recipient.startswith("functions.")
|
||||
):
|
||||
# Count completed tool calls to determine index
|
||||
base_index = 0
|
||||
for msg in harmony_parser.messages:
|
||||
if (
|
||||
(msg.channel == "commentary" or msg.channel == "analysis")
|
||||
and msg.recipient
|
||||
and msg.recipient.startswith("functions.")
|
||||
):
|
||||
base_index += 1
|
||||
# Group consecutive tokens with same channel/recipient
|
||||
groups: list[TokenState] = []
|
||||
|
||||
if prev_recipient != cur_recipient:
|
||||
tool_name = cur_recipient.split("functions.", 1)[1]
|
||||
delta_message = DeltaMessage(
|
||||
tool_calls=[
|
||||
current_channel = token_states[0].channel
|
||||
current_recipient = token_states[0].recipient
|
||||
current_text = token_states[0].text
|
||||
|
||||
for i in range(1, len(token_states)):
|
||||
state = token_states[i]
|
||||
if state.channel == current_channel and state.recipient == current_recipient:
|
||||
current_text += state.text
|
||||
else:
|
||||
groups.append(TokenState(current_channel, current_recipient, current_text))
|
||||
current_channel = state.channel
|
||||
current_recipient = state.recipient
|
||||
current_text = state.text
|
||||
|
||||
groups.append(TokenState(current_channel, current_recipient, current_text))
|
||||
|
||||
# Process each group and create delta messages
|
||||
delta_message = None
|
||||
combined_content = ""
|
||||
combined_reasoning = ""
|
||||
tool_messages = []
|
||||
content_encountered = False
|
||||
|
||||
# Calculate base_index once before the loop
|
||||
# This counts completed tool calls in messages
|
||||
base_index = 0
|
||||
for msg in harmony_parser.messages:
|
||||
if (
|
||||
(msg.channel == "commentary" or msg.channel == "analysis")
|
||||
and msg.recipient
|
||||
and msg.recipient.startswith("functions.")
|
||||
):
|
||||
base_index += 1
|
||||
|
||||
# If there's an ongoing tool call from previous chunk,
|
||||
# the next new tool call starts at base_index + 1
|
||||
if prev_recipient and prev_recipient.startswith("functions."):
|
||||
next_tool_index = base_index + 1
|
||||
# Ongoing call is at base_index
|
||||
ongoing_tool_index = base_index
|
||||
else:
|
||||
# No ongoing call, next new call is at base_index
|
||||
next_tool_index = base_index
|
||||
ongoing_tool_index = None
|
||||
|
||||
for group in groups:
|
||||
if group.channel == "final":
|
||||
combined_content += group.text
|
||||
content_encountered = True
|
||||
elif (
|
||||
(group.channel == "commentary" or group.channel == "analysis")
|
||||
and group.recipient
|
||||
and group.recipient.startswith("functions.")
|
||||
):
|
||||
opened_new_call = False
|
||||
if prev_recipient != group.recipient:
|
||||
# New tool call - emit the opening message
|
||||
tool_name = group.recipient.split("functions.", 1)[1]
|
||||
tool_messages.append(
|
||||
DeltaToolCall(
|
||||
id=make_tool_call_id(),
|
||||
type="function",
|
||||
@@ -69,32 +118,53 @@ def extract_harmony_streaming_delta(
|
||||
name=tool_name,
|
||||
arguments="",
|
||||
),
|
||||
index=base_index,
|
||||
index=next_tool_index,
|
||||
)
|
||||
]
|
||||
)
|
||||
elif delta_text:
|
||||
delta_message = DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=base_index,
|
||||
function=DeltaFunctionCall(arguments=delta_text),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
delta_message = None
|
||||
)
|
||||
opened_new_call = True
|
||||
prev_recipient = group.recipient
|
||||
# Increment for subsequent new tool calls
|
||||
next_tool_index += 1
|
||||
|
||||
if delta_message is not None:
|
||||
if group.text:
|
||||
# Stream arguments for the ongoing tool call
|
||||
if opened_new_call:
|
||||
# Just opened in this group
|
||||
tool_call_index = next_tool_index - 1
|
||||
else:
|
||||
# Continuing from previous chunk
|
||||
# If ongoing_tool_index is None here, it means
|
||||
# we're continuing a call but prev_recipient
|
||||
# wasn't a function. Use base_index.
|
||||
tool_call_index = (
|
||||
ongoing_tool_index
|
||||
if ongoing_tool_index is not None
|
||||
else base_index
|
||||
)
|
||||
tool_messages.append(
|
||||
DeltaToolCall(
|
||||
index=tool_call_index,
|
||||
function=DeltaFunctionCall(arguments=group.text),
|
||||
)
|
||||
)
|
||||
elif group.channel == "commentary":
|
||||
# Tool call preambles meant to be shown to the user
|
||||
combined_content += group.text
|
||||
content_encountered = True
|
||||
elif group.channel == "analysis" and include_reasoning:
|
||||
combined_reasoning += group.text
|
||||
|
||||
# Combine all non-empty fields into a single message
|
||||
if content_encountered or combined_reasoning or tool_messages:
|
||||
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
|
||||
if content_encountered:
|
||||
delta_kwargs["content"] = combined_content
|
||||
if combined_reasoning:
|
||||
delta_kwargs["reasoning"] = combined_reasoning
|
||||
if tool_messages:
|
||||
delta_kwargs["tool_calls"] = tool_messages
|
||||
tools_streamed = True
|
||||
elif cur_channel == "commentary":
|
||||
# Tool call preambles meant to be shown to the user
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
elif cur_channel == "analysis":
|
||||
if include_reasoning:
|
||||
delta_message = DeltaMessage(reasoning=delta_text)
|
||||
else:
|
||||
delta_message = None
|
||||
delta_message = DeltaMessage(**delta_kwargs)
|
||||
else:
|
||||
delta_message = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user