[gpt-oss] Support chat completion api (#22342)

Author: Woosuk Kwon
Date: 2025-08-06 01:57:39 -07:00
Committed by: GitHub
Parent: 54991c548a
Commit: f263a4b53f
3 changed files with 183 additions and 24 deletions

vllm/entrypoints/harmony_utils.py

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import datetime
from collections.abc import Iterable
from typing import Literal, Optional

from openai.types.responses.tool import Tool
@@ -109,3 +110,36 @@ def get_stop_tokens_for_assistant_actions() -> list[int]:
def get_streamable_parser_for_assistant() -> StreamableParser:
    return StreamableParser(get_encoding(), role=Role.ASSISTANT)


def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
    parser = get_streamable_parser_for_assistant()
    for token_id in token_ids:
        parser.process(token_id)
    return parser


def parse_chat_output(
        token_ids: list[int]) -> tuple[Optional[str], Optional[str], bool]:
    parser = parse_output_into_messages(token_ids)
    output_msgs = parser.messages
    if len(output_msgs) == 0:
        # The generation has stopped during reasoning.
        is_tool_call = False
        reasoning_content = parser.current_content
        final_content = None
    elif len(output_msgs) == 1:
        # The generation has stopped during the final message.
        is_tool_call = False
        reasoning_content = output_msgs[0].content[0].text
        final_content = parser.current_content
    else:
        if len(output_msgs) != 2:
            raise ValueError(
                "Expected 2 output messages (reasoning and final), "
                f"but got {len(output_msgs)}.")
        reasoning_msg, final_msg = output_msgs
        reasoning_content = reasoning_msg.content[0].text
        final_content = final_msg.content[0].text
        is_tool_call = final_msg.recipient is not None
    return reasoning_content, final_content, is_tool_call
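
Note (not part of the diff): a minimal sketch of how a caller might consume the triple returned by parse_chat_output, mirroring the serving_chat.py logic further below. The helper name to_chat_fields is hypothetical.

from typing import Optional


def to_chat_fields(reasoning: Optional[str], final: Optional[str],
                   is_tool_call: bool) -> dict:
    # Hypothetical helper mirroring how serving_chat.py consumes the triple.
    if is_tool_call:
        # The final message has a recipient, i.e. it is addressed to a tool;
        # the Chat Completion API rejects this for now.
        raise NotImplementedError("Use the Responses API for tool calls.")
    return {
        "reasoning_content": reasoning,  # analysis-channel text (may be None)
        "content": final,  # final-channel text; None if stopped mid-reasoning
    }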

vllm/entrypoints/openai/protocol.py

@@ -323,6 +323,7 @@ class ResponsesRequest(OpenAIBaseModel):
        if (top_p := self.top_p) is None:
            top_p = default_sampling_params.get(
                "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
        stop_token_ids = default_sampling_params.get("stop_token_ids")

        # Structured output
        guided_decoding = None
@@ -340,6 +341,7 @@ class ResponsesRequest(OpenAIBaseModel):
            top_p=top_p,
            max_tokens=max_tokens,
            logprobs=self.top_logprobs,
            stop_token_ids=stop_token_ids,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
            guided_decoding=guided_decoding,
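
Aside: plumbing stop_token_ids through the defaults has the same effect as passing the Harmony assistant stop tokens explicitly; a sketch of the offline equivalent, assuming a gpt-oss model is loaded:

from vllm import SamplingParams
from vllm.entrypoints.harmony_utils import \
    get_stop_tokens_for_assistant_actions

# Stop at <|return|> / <|call|> so generation ends when the assistant
# finishes its turn or calls a tool.
params = SamplingParams(
    max_tokens=256,
    stop_token_ids=get_stop_tokens_for_assistant_actions(),
)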
@@ -404,6 +406,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
Literal["required"], Literal["required"],
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
]] = "none" ]] = "none"
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
include_reasoning: bool = True
# NOTE this will be ignored by vLLM -- the model determines the behavior # NOTE this will be ignored by vLLM -- the model determines the behavior
parallel_tool_calls: Optional[bool] = False parallel_tool_calls: Optional[bool] = False
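
The two new request fields can be exercised from any OpenAI-compatible client via extra_body; for example (server URL and model name are illustrative):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # assumed model name
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    # vLLM-specific extensions added by this commit:
    extra_body={"reasoning_effort": "low", "include_reasoning": False},
)
print(resp.choices[0].message.content)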

vllm/entrypoints/openai/serving_chat.py

@@ -12,6 +12,7 @@ import jinja2
import partial_json_parser
import regex as re
from fastapi import Request
from openai_harmony import Message as OpenAIMessage
from pydantic import TypeAdapter

from vllm.config import ModelConfig
@@ -19,6 +20,10 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         random_tool_call_id)
from vllm.entrypoints.harmony_utils import (
    get_developer_message, get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
    parse_chat_output, render_for_completion)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    ChatCompletionLogProb, ChatCompletionLogProbs,
@@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall)
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
@@ -125,6 +131,23 @@ class OpenAIServingChat(OpenAIServing):
            logger.info("Using default chat sampling params from %s: %s",
                        source, self.default_sampling_params)

        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        if self.use_harmony:
            if "stop_token_ids" not in self.default_sampling_params:
                self.default_sampling_params["stop_token_ids"] = []
            self.default_sampling_params["stop_token_ids"].extend(
                get_stop_tokens_for_assistant_actions())

        # NOTE(woosuk): While OpenAI's chat completion API supports browsing
        # for some models, currently vLLM doesn't support it. Please use the
        # Responses API instead.
        self.supports_browsing = False
        self.browser_tool = None
        # NOTE(woosuk): Chat completion API does not support code interpreter.
        # Please use the Responses API instead.
        self.supports_code_interpreter = False
        self.python_tool = None

    async def create_chat_completion(
        self,
        request: ChatCompletionRequest,
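
For reference, the stop tokens extended into the defaults above come from the openai_harmony package; roughly (a sketch of the underlying call, per the harmony docs):

from openai_harmony import HarmonyEncodingName, load_harmony_encoding

enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
# Ids of the tokens that end an assistant action (<|return|>, <|call|>);
# these are what get appended to the default stop_token_ids.
print(enc.stop_tokens_for_assistant_actions())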
@@ -169,7 +192,8 @@ class OpenAIServingChat(OpenAIServing):
        if (request.tool_choice == "auto" and
                not (self.enable_auto_tools and tool_parser is not None)
                and not isinstance(tokenizer, MistralTokenizer)
                and not self.use_harmony):
            # for hf tokenizers, "auto" tools requires
            # --enable-auto-tool-choice and --tool-call-parser
            return self.create_error_response(
@@ -184,25 +208,35 @@ class OpenAIServingChat(OpenAIServing):
            else:
                tool_dicts = [tool.model_dump() for tool in request.tools]

            if not self.use_harmony:
                # Common case.
                (
                    conversation,
                    request_prompts,
                    engine_prompts,
                ) = await self._preprocess_chat(
                    request,
                    tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.chat_template,
                    chat_template_content_format=self.
                    chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    continue_final_message=request.continue_final_message,
                    tool_dicts=tool_dicts,
                    documents=request.documents,
                    chat_template_kwargs=request.chat_template_kwargs,
                    tool_parser=tool_parser,
                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                    add_special_tokens=request.add_special_tokens,
                )
            else:
                # For GPT-OSS.
                (
                    conversation,
                    request_prompts,
                    engine_prompts,
                ) = self._make_request_with_harmony(request)
        except (ValueError, TypeError, RuntimeError,
                jinja2.TemplateError) as e:
            logger.exception("Error in preprocessing prompt inputs")
@@ -436,6 +470,11 @@ class OpenAIServingChat(OpenAIServing):
        finish_reason_sent = [False] * num_choices
        num_prompt_tokens = 0
        num_cached_tokens = None

        if self.use_harmony:
            harmony_parsers = [
                get_streamable_parser_for_assistant()
                for _ in range(num_choices)
            ]

        if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
            tool_choice_function_name = request.tool_choice.function.name
@@ -597,7 +636,18 @@ class OpenAIServingChat(OpenAIServing):
                    else:
                        logprobs = None

                    if self.use_harmony:
                        harmony_parser = harmony_parsers[i]
                        for token_id in output.token_ids:
                            harmony_parser.process(token_id)
                        # FIXME(woosuk): Support function calling
                        is_final = harmony_parser.current_channel == "final"
                        if not (request.include_reasoning or is_final):
                            # Skip the reasoning content.
                            continue
                        delta_text = harmony_parser.last_content_delta or ""
                    else:
                        delta_text = output.text

                    if not delta_text and not output.token_ids and \
                            not previous_num_tokens[i]:
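
Conceptually, each choice's tokens are fed through a StreamableParser and only deltas on the final channel become user-visible content; a standalone sketch of the same loop (generated_ids stands in for tokens arriving from the engine):

from openai_harmony import (HarmonyEncodingName, Role, StreamableParser,
                            load_harmony_encoding)

enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
parser = StreamableParser(enc, role=Role.ASSISTANT)

generated_ids: list[int] = []  # placeholder: token ids from the engine
for token_id in generated_ids:
    parser.process(token_id)
    delta = parser.last_content_delta or ""
    if parser.current_channel == "final":
        print(delta, end="")  # user-visible content
    else:
        pass  # analysis-channel (reasoning) delta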
@@ -607,7 +657,8 @@ class OpenAIServingChat(OpenAIServing):
                    delta_message: Optional[DeltaMessage]

                    # just update previous_texts and previous_token_ids
                    if ((tool_choice_auto or self.reasoning_parser)
                            and not self.use_harmony):
                        assert previous_texts is not None
                        assert all_previous_token_ids is not None
                        previous_text = previous_texts[i]
@@ -621,8 +672,14 @@ class OpenAIServingChat(OpenAIServing):
                    else:
                        current_token_ids = list(output.token_ids)

                    if self.use_harmony:
                        if is_final:
                            delta_message = DeltaMessage(content=delta_text)
                        else:
                            delta_message = DeltaMessage(
                                reasoning_content=delta_text)
                    # handle streaming deltas for tools with named tool_choice
                    elif tool_choice_function_name:
                        if (self.reasoning_parser and not reasoning_end_arr[i]
                                and not reasoning_parser.is_reasoning_end(
                                    previous_token_ids)):
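
On the wire, reasoning deltas therefore arrive in the vLLM-specific reasoning_content field and final-channel deltas in content; a client can separate them like so (a sketch reusing the client from the earlier example; getattr guards the non-standard field):

stream = client.chat.completions.create(
    model="openai/gpt-oss-120b",  # assumed model name
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if getattr(delta, "reasoning_content", None):
        pass  # reasoning text; omitted when include_reasoning=False
    elif delta.content:
        print(delta.content, end="")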
@@ -990,7 +1047,38 @@ class OpenAIServingChat(OpenAIServing):
            )
            else:
                logprobs = None

            auto_tools_called = False

            if self.use_harmony:
                reasoning_content, final_content, is_tool_call = (
                    parse_chat_output(token_ids))
                if not request.include_reasoning:
                    reasoning_content = None

                if is_tool_call:
                    # TODO(woosuk): Implement tool call for gpt-oss.
                    # For now, only Responses API supports tool call for
                    # gpt-oss.
                    raise NotImplementedError(
                        "Tool call in Chat Completion API is not supported "
                        "for gpt-oss yet. Please use Responses API instead.")
                else:
                    # Normal message
                    message = ChatMessage(
                        role=role,
                        reasoning_content=reasoning_content,
                        content=final_content,
                    )

                choice_data = ChatCompletionResponseChoice(
                    index=output.index,
                    message=message,
                    logprobs=logprobs,
                    finish_reason="tool_calls" if is_tool_call else
                    output.finish_reason if output.finish_reason else "stop",
                    stop_reason=output.stop_reason,
                )
                choices.append(choice_data)
                continue

            if self.reasoning_parser:
                try:
@@ -1003,10 +1091,13 @@ class OpenAIServingChat(OpenAIServing):
                reasoning_content, content = (
                    reasoning_parser.extract_reasoning_content(
                        output.text, request=request))
                if not request.include_reasoning:
                    reasoning_content = None
            else:
                reasoning_content = None
                content = output.text

            # if auto tools are not enabled, and a named tool choice using
            # outlines is not being used
            if (not self.enable_auto_tools or not self.tool_parser) and \
@@ -1261,3 +1352,33 @@ class OpenAIServingChat(OpenAIServing):
            and delta_message.tool_calls[0].function
            and delta_message.tool_calls[0].function.arguments is not None
        )

    def _make_request_with_harmony(
        self,
        request: ChatCompletionRequest,
    ):
        messages: list[OpenAIMessage] = []

        # Add the system message.
        # NOTE: In Chat Completion API, browsing is enabled by default
        # if the model supports it. TODO: Support browsing.
        assert not self.supports_browsing
        assert not self.supports_code_interpreter
        sys_msg = get_system_message(
            reasoning_effort=request.reasoning_effort,
            browser_description=None,
            python_description=None)
        messages.append(sys_msg)

        # Add the developer message.
        dev_msg = get_developer_message()
        messages.append(dev_msg)

        # Add the conversation messages.
        for chat_msg in request.messages:
            messages.append(parse_chat_input(chat_msg))

        # Render the prompt token ids.
        prompt_token_ids = render_for_completion(messages)
        engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
        return messages, [prompt_token_ids], [engine_prompt]
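
End to end, _make_request_with_harmony builds the same structure as the upstream harmony example; roughly (a sketch using openai_harmony directly, with illustrative instruction text):

from openai_harmony import (Conversation, DeveloperContent,
                            HarmonyEncodingName, Message, Role, SystemContent,
                            load_harmony_encoding)

enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
convo = Conversation.from_messages([
    Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
    Message.from_role_and_content(
        Role.DEVELOPER,
        DeveloperContent.new().with_instructions("Answer concisely.")),
    Message.from_role_and_content(Role.USER, "What is 2 + 2?"),
])
# Prompt token ids for the next assistant turn, as fed to the engine.
prompt_token_ids = enc.render_conversation_for_completion(
    convo, Role.ASSISTANT)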