[Bugfix] Multiple fixes for gpt-oss Chat Completion prompting (#28729)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
Ben Browning
2025-12-11 23:59:53 -05:00
committed by GitHub
parent fe1787107e
commit 8f8fda261a
6 changed files with 1749 additions and 139 deletions

View File

@@ -11,13 +11,25 @@ import pytest_asyncio
from openai import OpenAI
from vllm.config.multimodal import MultiModalConfig
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.parser.harmony_utils import get_encoding
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.tokenizers import get_tokenizer
from vllm.v1.engine.async_llm import AsyncLLM
from ...utils import RemoteOpenAIServer
from .utils import (
accumulate_streaming_response,
verify_chat_response,
verify_harmony_messages,
)
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
@@ -728,3 +740,635 @@ async def test_serving_chat_data_parallel_rank_extraction():
# Verify that data_parallel_rank defaults to None
assert "data_parallel_rank" in mock_engine.generate.call_args.kwargs
assert mock_engine.generate.call_args.kwargs["data_parallel_rank"] is None
class TestServingChatWithHarmony:
"""
These tests ensure Chat Completion requests are being properly converted into
Harmony messages and Harmony response messages back into Chat Completion responses.
These tests are not exhaustive, but each one was created to cover a specific case
that we got wrong but is now fixed.
Any changes to the tests and their expectations may result in changes to the
accuracy of model prompting and responses generated. It is suggested to run
an evaluation or benchmarking suite (such as bfcl multi_turn) to understand
any impact of changes in how we prompt Harmony models.
"""
@pytest.fixture(params=[False, True], ids=["non_streaming", "streaming"])
def stream(self, request) -> bool:
"""Parameterize tests to run in both non-streaming and streaming modes."""
return request.param
@pytest.fixture()
def mock_engine(self) -> AsyncLLM:
mock_engine = MagicMock(spec=AsyncLLM)
mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
mock_engine.errored = False
mock_engine.model_config = MockModelConfig()
mock_engine.input_processor = MagicMock()
mock_engine.io_processor = MagicMock()
return mock_engine
@pytest.fixture()
def serving_chat(self, mock_engine) -> OpenAIServingChat:
chat = _build_serving_chat(mock_engine)
chat.use_harmony = True
chat.tool_parser = ToolParserManager.get_tool_parser("openai")
return chat
def mock_request_output_from_req_and_token_ids(
self, req: ChatCompletionRequest, token_ids: list[int], finished: bool = False
) -> RequestOutput:
# Our tests don't use most fields, so just get the token ids correct
completion_output = CompletionOutput(
index=0,
text="",
token_ids=token_ids,
cumulative_logprob=0.0,
logprobs=None,
)
return RequestOutput(
request_id=req.request_id,
prompt=[],
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[completion_output],
finished=finished,
)
@pytest.fixture
def weather_tools(self) -> list[dict[str, Any]]:
return [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"},
},
"required": ["location"],
},
},
},
]
@pytest.fixture
def weather_messages_start(self) -> list[dict[str, Any]]:
return [
{
"role": "user",
"content": "What's the weather like in Paris today?",
},
]
async def generate_response_from_harmony_str(
self,
serving_chat: OpenAIServingChat,
req: ChatCompletionRequest,
harmony_str: str,
stream: bool = False,
) -> ChatCompletionResponse:
harmony_token_ids = get_encoding().encode(harmony_str, allowed_special="all")
async def result_generator():
if stream:
for token_id in harmony_token_ids:
yield self.mock_request_output_from_req_and_token_ids(
req, [token_id]
)
yield self.mock_request_output_from_req_and_token_ids(
req, [], finished=True
)
else:
yield self.mock_request_output_from_req_and_token_ids(
req, harmony_token_ids, finished=True
)
generator_func = (
serving_chat.chat_completion_stream_generator
if stream
else serving_chat.chat_completion_full_generator
)
result = generator_func(
request=req,
result_generator=result_generator(),
request_id=req.request_id,
model_name=req.model,
conversation=[],
tokenizer=get_tokenizer(req.model),
request_metadata=RequestResponseMetadata(
request_id=req.request_id,
model_name=req.model,
),
)
if stream:
return await accumulate_streaming_response(result)
return await result
@pytest.mark.asyncio
async def test_simple_chat(self, serving_chat, stream):
messages = [{"role": "user", "content": "what is 1+1?"}]
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "We need to think really hard about this."
final_str = "The answer is 2."
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>{final_str}<|end|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(response, content=final_str, reasoning=reasoning_str)
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
# The analysis message should be dropped on subsequent inputs because
# of the subsequent assistant message to the final channel.
{"role": "assistant", "channel": "final", "content": final_str},
],
)
@pytest.mark.asyncio
async def test_tool_call_response_with_content(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
commentary_str = "We'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>commentary<|message|>{commentary_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
content=commentary_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"content": commentary_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_multi_turn_tools_and_reasoning(
self, serving_chat, stream, weather_tools, weather_messages_start
):
tools = weather_tools
messages = list(weather_messages_start)
# Test the Harmony messages for the first turn's input
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages, tools=tools)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer", "tool_definitions": ["get_weather"]},
{"role": "user", "content": messages[0]["content"]},
],
)
# Test the Chat Completion response for the first turn's output
reasoning_str = "I'll call get_weather."
paris_tool_args_str = '{"location": "Paris"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{paris_tool_args_str}<|call|>"
)
response = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response,
reasoning=reasoning_str,
tool_calls=[("get_weather", paris_tool_args_str)],
)
tool_call = response.choices[0].message.tool_calls[0]
# Add the output messages from the first turn as input to the second turn
for choice in response.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "20 degrees Celsius",
},
)
# Test the Harmony messages for the second turn's input
req_2 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_2, _, _ = serving_chat._make_request_with_harmony(req_2)
verify_harmony_messages(
input_messages_2,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
],
)
# Test the Chat Completion response for the second turn's output
paris_weather_str = "The weather in Paris today is 20 degrees Celsius."
response_str = f"<|channel|>final<|message|>{paris_weather_str}<|end|>"
response_2 = await self.generate_response_from_harmony_str(
serving_chat, req_2, response_str, stream=stream
)
verify_chat_response(response_2, content=paris_weather_str)
# Add the output messages from the second turn as input to the third turn
for choice in response_2.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add a new user message for the third turn
messages.append(
{
"role": "user",
"content": "What's the weather like in Boston today?",
},
)
# Test the Harmony messages for the third turn's input
req_3 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_3, _, _ = serving_chat._make_request_with_harmony(req_3)
verify_harmony_messages(
input_messages_3,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": paris_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "20 degrees Celsius",
},
{
"role": "assistant",
"channel": "final",
"content": paris_weather_str,
},
{"role": "user", "content": messages[-1]["content"]},
],
)
# Test the Chat Completion response for the third turn's output
reasoning_str = "I'll call get_weather."
boston_tool_args_str = '{"location": "Boston"}'
response_str = (
f"<|channel|>analysis<|message|>{reasoning_str}<|end|>"
"<|start|>assistant to=functions.get_weather<|channel|>commentary"
f"<|constrain|>json<|message|>{boston_tool_args_str}<|call|>"
)
response_3 = await self.generate_response_from_harmony_str(
serving_chat, req, response_str, stream=stream
)
verify_chat_response(
response_3,
reasoning=reasoning_str,
tool_calls=[("get_weather", boston_tool_args_str)],
)
tool_call = response_3.choices[0].message.tool_calls[0]
# Add the output messages from the third turn as input to the fourth turn
for choice in response_3.choices:
messages.append(choice.message.model_dump(exclude_none=True))
# Add our tool output message
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": "10 degrees Celsius",
},
)
# Test the Harmony messages for the fourth turn's input
req_4 = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages_4, _, _ = serving_chat._make_request_with_harmony(req_4)
verify_harmony_messages(
input_messages_4,
[
{"role": "system"},
{"role": "developer"},
{"role": "user"},
{"role": "assistant"},
{"role": "tool"},
{
"role": "assistant",
"channel": "final",
},
{"role": "user"},
{
"role": "assistant",
"channel": "analysis",
"content": reasoning_str,
},
{
"role": "assistant",
"channel": "commentary",
"recipient": "functions.get_weather",
"content": boston_tool_args_str,
},
{
"role": "tool",
"author_name": "functions.get_weather",
"channel": "commentary",
"recipient": "assistant",
"content": "10 degrees Celsius",
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "4",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
# The reasoning that would have resulted in an analysis message is
# dropped because of a later assistant message to the final channel.
{
"role": "assistant",
"channel": "final",
"content": messages[1]["content"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": "",
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)
@pytest.mark.asyncio
async def test_non_tool_reasoning_empty_content_list(self, serving_chat):
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": "What's 2+2?",
},
{
"role": "assistant",
"reasoning": "Adding 2 and 2 is easy. The result is 4.",
"content": [],
},
]
req = ChatCompletionRequest(model=MODEL_NAME, messages=messages)
input_messages, _, _ = serving_chat._make_request_with_harmony(req)
verify_harmony_messages(
input_messages,
[
{"role": "system"},
{"role": "developer"},
{"role": "user", "content": messages[0]["content"]},
{
"role": "assistant",
"channel": "analysis",
"content": messages[1]["reasoning"],
},
],
)