[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)

This commit is contained in:
Aaron Pham
2025-09-06 15:10:32 -04:00
committed by GitHub
parent 6024d115cd
commit fb691ee4e7
2 changed files with 83 additions and 38 deletions

View File

@@ -36,21 +36,41 @@ def monkeypatch_module():
mpatch.undo()
@pytest.fixture(scope="module",
params=[True, False],
ids=["with_tool_parser", "without_tool_parser"])
def with_tool_parser(request) -> bool:
return request.param
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
args = [
"--enforce-eager",
"--max-model-len",
"8192",
def default_server_args(with_tool_parser: bool):
args = [
# use half precision for speed and memory savings in CI environment
"--enforce-eager",
"--max-model-len",
"4096",
"--reasoning-parser",
"openai_gptoss",
"--gpu-memory-utilization",
"0.8",
]
if with_tool_parser:
args.extend([
"--tool-call-parser",
"openai",
"--reasoning-parser",
"openai_gptoss",
"--enable-auto-tool-choice",
]
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
])
return args
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
default_server_args: list[str]):
with monkeypatch_module.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
default_server_args) as remote_server:
yield remote_server
@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
with_tool_parser: bool):
tools = [{
"type": "function",
"function": {
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
]
stream = await gptoss_client.chat.completions.create(
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
model=GPT_OSS_MODEL_NAME,
messages=messages,
tools=tools if with_tool_parser else None,
stream=True)
name = None
args_buf = ""
content_buf = ""
async for chunk in stream:
delta = chunk.choices[0].delta
if delta.tool_calls:
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
name = tc.function.name
if tc.function and tc.function.arguments:
args_buf += tc.function.arguments
assert name is not None
assert len(args_buf) > 0
if getattr(delta, "content", None):
content_buf += delta.content
if with_tool_parser:
assert name is not None
assert len(args_buf) > 0
else:
assert name is None
assert len(args_buf) == 0
assert len(content_buf) > 0
@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
with_tool_parser: bool):
if not with_tool_parser:
pytest.skip("skip non-tool for multi-turn tests")
tools = [{
"type": "function",
"function": {
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
)
second_msg = second.choices[0].message
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
MODEL_NAME = "openai-community/gpt2"