[Fix] [gpt-oss] fix non-tool calling path for chat completion (#24324)
This commit is contained in:
@@ -36,21 +36,41 @@ def monkeypatch_module():
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=[True, False],
|
||||
ids=["with_tool_parser", "without_tool_parser"])
|
||||
def with_tool_parser(request) -> bool:
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
||||
args = [
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
def default_server_args(with_tool_parser: bool):
|
||||
args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"4096",
|
||||
"--reasoning-parser",
|
||||
"openai_gptoss",
|
||||
"--gpu-memory-utilization",
|
||||
"0.8",
|
||||
]
|
||||
if with_tool_parser:
|
||||
args.extend([
|
||||
"--tool-call-parser",
|
||||
"openai",
|
||||
"--reasoning-parser",
|
||||
"openai_gptoss",
|
||||
"--enable-auto-tool-choice",
|
||||
]
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, args) as remote_server:
|
||||
])
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch,
|
||||
default_server_args: list[str]):
|
||||
with monkeypatch_module.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
|
||||
with RemoteOpenAIServer(GPT_OSS_MODEL_NAME,
|
||||
default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@@ -61,7 +81,8 @@ async def gptoss_client(gptoss_server):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
|
||||
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI,
|
||||
with_tool_parser: bool):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -94,10 +115,14 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
|
||||
]
|
||||
|
||||
stream = await gptoss_client.chat.completions.create(
|
||||
model=GPT_OSS_MODEL_NAME, messages=messages, tools=tools, stream=True)
|
||||
model=GPT_OSS_MODEL_NAME,
|
||||
messages=messages,
|
||||
tools=tools if with_tool_parser else None,
|
||||
stream=True)
|
||||
|
||||
name = None
|
||||
args_buf = ""
|
||||
content_buf = ""
|
||||
async for chunk in stream:
|
||||
delta = chunk.choices[0].delta
|
||||
if delta.tool_calls:
|
||||
@@ -106,13 +131,22 @@ async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
|
||||
name = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
args_buf += tc.function.arguments
|
||||
|
||||
assert name is not None
|
||||
assert len(args_buf) > 0
|
||||
if getattr(delta, "content", None):
|
||||
content_buf += delta.content
|
||||
if with_tool_parser:
|
||||
assert name is not None
|
||||
assert len(args_buf) > 0
|
||||
else:
|
||||
assert name is None
|
||||
assert len(args_buf) == 0
|
||||
assert len(content_buf) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
|
||||
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI,
|
||||
with_tool_parser: bool):
|
||||
if not with_tool_parser:
|
||||
pytest.skip("skip non-tool for multi-turn tests")
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -175,7 +209,7 @@ async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
|
||||
)
|
||||
second_msg = second.choices[0].message
|
||||
assert (second_msg.content is not None and len(second_msg.content) > 0) or \
|
||||
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0) # noqa: E501
|
||||
(second_msg.tool_calls is not None and len(second_msg.tool_calls) > 0)
|
||||
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
|
||||
Reference in New Issue
Block a user