[gpt-oss] tool parser supports for /chat/completions [1/n] (#22386)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
@@ -1,13 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.config import MultiModalConfig
|
||||
from vllm.engine.multiprocessing.client import MQLLMEngineClient
|
||||
@@ -17,6 +20,164 @@ from vllm.entrypoints.openai.serving_models import (BaseModelPath,
|
||||
OpenAIServingModels)
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import OpenAI
|
||||
|
||||
# Model identifier for the gpt-oss tests in this file: it is the model the
# `gptoss_server` fixture launches and the `model=` value sent with each
# chat-completions request.
GPT_OSS_MODEL_NAME = "openai/gpt-oss-20b"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def monkeypatch_module():
    """Yield a module-scoped ``MonkeyPatch`` whose patches are undone at
    fixture teardown.

    pytest's built-in ``monkeypatch`` fixture is function-scoped, so a
    module-scoped fixture cannot request it; this fixture fills that gap.

    Yields:
        pytest.MonkeyPatch: a fresh patcher that lives for the whole module.
    """
    # Use the public ``pytest.MonkeyPatch`` rather than reaching into the
    # private ``_pytest`` package. ``context()`` calls ``undo()`` on exit,
    # including when the fixture generator is closed early.
    with pytest.MonkeyPatch.context() as mpatch:
        yield mpatch
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def gptoss_server(monkeypatch_module: pytest.MonkeyPatch):
    """Start one gpt-oss server for the module and yield its handle.

    The server is launched with the ``openai`` tool-call parser, the
    ``openai_gptoss`` reasoning parser and automatic tool choice enabled,
    while ``VLLM_ATTENTION_BACKEND`` is set to ``TRITON_ATTN_VLLM_V1``.

    Yields:
        The ``RemoteOpenAIServer`` context object for ``GPT_OSS_MODEL_NAME``.
    """
    server_args = [
        "--enforce-eager",
        "--max-model-len",
        "8192",
        "--tool-call-parser",
        "openai",
        "--reasoning-parser",
        "openai_gptoss",
        "--enable-auto-tool-choice",
    ]
    with monkeypatch_module.context() as env_patch:
        # The variable must be in place before the server is launched; the
        # patch is reverted when this context exits at fixture teardown.
        env_patch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN_VLLM_V1")
        with RemoteOpenAIServer(GPT_OSS_MODEL_NAME, server_args) as server:
            yield server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def gptoss_client(gptoss_server):
    """Yield an async client obtained from ``gptoss_server``.

    The client's async context is exited when the fixture is torn down.
    """
    async with gptoss_server.get_async_client() as client:
        yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gpt_oss_chat_tool_call_streaming(gptoss_client: OpenAI):
    """A streamed chat completion must carry a tool-call name and arguments.

    Sends a single weather question with one tool available, walks the
    stream, and checks that some chunk named a function and that at least
    one non-empty argument fragment arrived.
    """
    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": weather_parameters,
        },
    }
    conversation = [
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]

    stream = await gptoss_client.chat.completions.create(
        model=GPT_OSS_MODEL_NAME,
        messages=conversation,
        tools=[weather_tool],
        stream=True,
    )

    tool_name = None
    arg_fragments = []
    async for chunk in stream:
        tool_calls = chunk.choices[0].delta.tool_calls
        if not tool_calls:
            # Chunk carries no tool-call delta; nothing to record.
            continue
        # Only the first tool-call entry of each chunk is inspected.
        function = tool_calls[0].function
        if not function:
            continue
        if function.name:
            tool_name = function.name
        if function.arguments:
            arg_fragments.append(function.arguments)

    assert tool_name is not None
    assert len("".join(arg_fragments)) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gpt_oss_multi_turn_chat(gptoss_client: OpenAI):
    """Two-turn chat: the first turn must call the tool, the second must
    produce either content or another tool call.

    The arguments string from the first turn's tool call is fed back as
    plain assistant content before the follow-up user message is sent.
    """
    weather_parameters = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "state": {"type": "string"},
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
            },
        },
        "required": ["city", "state", "unit"],
    }
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": weather_parameters,
        },
    }
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "What is the weather in Dallas, TX?"},
    ]
    # Both turns use the same model, tool list and sampling temperature.
    request_kwargs = {
        "model": GPT_OSS_MODEL_NAME,
        "tools": [weather_tool],
        "temperature": 0.0,
    }

    # Turn 1: expect a call to the weather tool with non-empty arguments.
    first_response = await gptoss_client.chat.completions.create(
        messages=conversation, **request_kwargs)
    first_reply = first_response.choices[0].message
    assert first_reply.tool_calls is not None
    assert len(first_reply.tool_calls) > 0
    call = first_reply.tool_calls[0]
    assert call.function is not None
    assert call.function.name == "get_current_weather"
    tool_args = call.function.arguments
    assert tool_args is not None
    assert len(tool_args) > 0

    conversation.extend([
        {
            "role": "assistant",
            "content": tool_args,
        },
        {
            "role": "user",
            "content": "Now convert to celsius and return JSON only",
        },
    ])

    # Turn 2: any non-empty reply (text or tool call) is accepted.
    second_response = await gptoss_client.chat.completions.create(
        messages=conversation, **request_kwargs)
    second_reply = second_response.choices[0].message
    has_content = bool(second_reply.content)
    has_tool_calls = bool(second_reply.tool_calls)
    assert has_content or has_tool_calls
|
||||
|
||||
|
||||
# Model identifier used for both the name and the path of the single
# BASE_MODEL_PATHS entry below.
MODEL_NAME = "openai-community/gpt2"

# Placeholder chat-template text; the trailing "{}" looks like a format
# slot, but its consumer is outside this view -- TODO confirm.
CHAT_TEMPLATE = "Dummy chat template for testing {}"

# One-element model registry whose served name and model path are identical.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
|
||||
|
||||
Reference in New Issue
Block a user