[gpt-oss] Support chat completion api (#22342)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import datetime
|
import datetime
|
||||||
|
from collections.abc import Iterable
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
from openai.types.responses.tool import Tool
|
from openai.types.responses.tool import Tool
|
||||||
@@ -109,3 +110,36 @@ def get_stop_tokens_for_assistant_actions() -> list[int]:
|
|||||||
|
|
||||||
def get_streamable_parser_for_assistant() -> StreamableParser:
|
def get_streamable_parser_for_assistant() -> StreamableParser:
|
||||||
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
|
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
|
||||||
|
parser = get_streamable_parser_for_assistant()
|
||||||
|
for token_id in token_ids:
|
||||||
|
parser.process(token_id)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chat_output(
|
||||||
|
token_ids: list[int]) -> tuple[Optional[str], Optional[str], bool]:
|
||||||
|
parser = parse_output_into_messages(token_ids)
|
||||||
|
output_msgs = parser.messages
|
||||||
|
if len(output_msgs) == 0:
|
||||||
|
# The generation has stopped during reasoning.
|
||||||
|
is_tool_call = False
|
||||||
|
reasoning_content = parser.current_content
|
||||||
|
final_content = None
|
||||||
|
elif len(output_msgs) == 1:
|
||||||
|
# The generation has stopped during final message.
|
||||||
|
is_tool_call = False
|
||||||
|
reasoning_content = output_msgs[0].content[0].text
|
||||||
|
final_content = parser.current_content
|
||||||
|
else:
|
||||||
|
if len(output_msgs) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected 2 output messages (reasoning and final), "
|
||||||
|
f"but got {len(output_msgs)}.")
|
||||||
|
reasoning_msg, final_msg = output_msgs
|
||||||
|
reasoning_content = reasoning_msg.content[0].text
|
||||||
|
final_content = final_msg.content[0].text
|
||||||
|
is_tool_call = final_msg.recipient is not None
|
||||||
|
return reasoning_content, final_content, is_tool_call
|
||||||
|
|||||||
@@ -323,6 +323,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
if (top_p := self.top_p) is None:
|
if (top_p := self.top_p) is None:
|
||||||
top_p = default_sampling_params.get(
|
top_p = default_sampling_params.get(
|
||||||
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
|
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
|
||||||
|
stop_token_ids = default_sampling_params.get("stop_token_ids")
|
||||||
|
|
||||||
# Structured output
|
# Structured output
|
||||||
guided_decoding = None
|
guided_decoding = None
|
||||||
@@ -340,6 +341,7 @@ class ResponsesRequest(OpenAIBaseModel):
|
|||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
logprobs=self.top_logprobs,
|
logprobs=self.top_logprobs,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
output_kind=(RequestOutputKind.DELTA
|
output_kind=(RequestOutputKind.DELTA
|
||||||
if self.stream else RequestOutputKind.FINAL_ONLY),
|
if self.stream else RequestOutputKind.FINAL_ONLY),
|
||||||
guided_decoding=guided_decoding,
|
guided_decoding=guided_decoding,
|
||||||
@@ -404,6 +406,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
Literal["required"],
|
Literal["required"],
|
||||||
ChatCompletionNamedToolChoiceParam,
|
ChatCompletionNamedToolChoiceParam,
|
||||||
]] = "none"
|
]] = "none"
|
||||||
|
reasoning_effort: Optional[Literal["low", "medium", "high"]] = None
|
||||||
|
include_reasoning: bool = True
|
||||||
|
|
||||||
# NOTE this will be ignored by vLLM -- the model determines the behavior
|
# NOTE this will be ignored by vLLM -- the model determines the behavior
|
||||||
parallel_tool_calls: Optional[bool] = False
|
parallel_tool_calls: Optional[bool] = False
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import jinja2
|
|||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
import regex as re
|
import regex as re
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
from openai_harmony import Message as OpenAIMessage
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
@@ -19,6 +20,10 @@ from vllm.engine.protocol import EngineClient
|
|||||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||||
ConversationMessage,
|
ConversationMessage,
|
||||||
random_tool_call_id)
|
random_tool_call_id)
|
||||||
|
from vllm.entrypoints.harmony_utils import (
|
||||||
|
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||||
|
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
|
||||||
|
parse_chat_output, render_for_completion)
|
||||||
from vllm.entrypoints.logger import RequestLogger
|
from vllm.entrypoints.logger import RequestLogger
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||||
@@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
|||||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||||
MistralToolCall)
|
MistralToolCall)
|
||||||
from vllm.entrypoints.utils import get_max_tokens
|
from vllm.entrypoints.utils import get_max_tokens
|
||||||
|
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.outputs import CompletionOutput, RequestOutput
|
from vllm.outputs import CompletionOutput, RequestOutput
|
||||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||||
@@ -125,6 +131,23 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logger.info("Using default chat sampling params from %s: %s",
|
logger.info("Using default chat sampling params from %s: %s",
|
||||||
source, self.default_sampling_params)
|
source, self.default_sampling_params)
|
||||||
|
|
||||||
|
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
|
||||||
|
if self.use_harmony:
|
||||||
|
if "stop_token_ids" not in self.default_sampling_params:
|
||||||
|
self.default_sampling_params["stop_token_ids"] = []
|
||||||
|
self.default_sampling_params["stop_token_ids"].extend(
|
||||||
|
get_stop_tokens_for_assistant_actions())
|
||||||
|
|
||||||
|
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
|
||||||
|
# for some models, currently vLLM doesn't support it. Please use the
|
||||||
|
# Responses API instead.
|
||||||
|
self.supports_browsing = False
|
||||||
|
self.browser_tool = None
|
||||||
|
# NOTE(woosuk): Chat completion API does not support code interpreter.
|
||||||
|
# Please use the Responses API instead.
|
||||||
|
self.supports_code_interpreter = False
|
||||||
|
self.python_tool = None
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self,
|
self,
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
@@ -169,7 +192,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
|
|
||||||
if (request.tool_choice == "auto" and
|
if (request.tool_choice == "auto" and
|
||||||
not (self.enable_auto_tools and tool_parser is not None)
|
not (self.enable_auto_tools and tool_parser is not None)
|
||||||
and not isinstance(tokenizer, MistralTokenizer)):
|
and not isinstance(tokenizer, MistralTokenizer)
|
||||||
|
and not self.use_harmony):
|
||||||
# for hf tokenizers, "auto" tools requires
|
# for hf tokenizers, "auto" tools requires
|
||||||
# --enable-auto-tool-choice and --tool-call-parser
|
# --enable-auto-tool-choice and --tool-call-parser
|
||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
@@ -184,25 +208,35 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
else:
|
else:
|
||||||
tool_dicts = [tool.model_dump() for tool in request.tools]
|
tool_dicts = [tool.model_dump() for tool in request.tools]
|
||||||
|
|
||||||
(
|
if not self.use_harmony:
|
||||||
conversation,
|
# Common case.
|
||||||
request_prompts,
|
(
|
||||||
engine_prompts,
|
conversation,
|
||||||
) = await self._preprocess_chat(
|
request_prompts,
|
||||||
request,
|
engine_prompts,
|
||||||
tokenizer,
|
) = await self._preprocess_chat(
|
||||||
request.messages,
|
request,
|
||||||
chat_template=request.chat_template or self.chat_template,
|
tokenizer,
|
||||||
chat_template_content_format=self.chat_template_content_format,
|
request.messages,
|
||||||
add_generation_prompt=request.add_generation_prompt,
|
chat_template=request.chat_template or self.chat_template,
|
||||||
continue_final_message=request.continue_final_message,
|
chat_template_content_format=self.
|
||||||
tool_dicts=tool_dicts,
|
chat_template_content_format,
|
||||||
documents=request.documents,
|
add_generation_prompt=request.add_generation_prompt,
|
||||||
chat_template_kwargs=request.chat_template_kwargs,
|
continue_final_message=request.continue_final_message,
|
||||||
tool_parser=tool_parser,
|
tool_dicts=tool_dicts,
|
||||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
documents=request.documents,
|
||||||
add_special_tokens=request.add_special_tokens,
|
chat_template_kwargs=request.chat_template_kwargs,
|
||||||
)
|
tool_parser=tool_parser,
|
||||||
|
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||||
|
add_special_tokens=request.add_special_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For GPT-OSS.
|
||||||
|
(
|
||||||
|
conversation,
|
||||||
|
request_prompts,
|
||||||
|
engine_prompts,
|
||||||
|
) = self._make_request_with_harmony(request)
|
||||||
except (ValueError, TypeError, RuntimeError,
|
except (ValueError, TypeError, RuntimeError,
|
||||||
jinja2.TemplateError) as e:
|
jinja2.TemplateError) as e:
|
||||||
logger.exception("Error in preprocessing prompt inputs")
|
logger.exception("Error in preprocessing prompt inputs")
|
||||||
@@ -436,6 +470,11 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
finish_reason_sent = [False] * num_choices
|
finish_reason_sent = [False] * num_choices
|
||||||
num_prompt_tokens = 0
|
num_prompt_tokens = 0
|
||||||
num_cached_tokens = None
|
num_cached_tokens = None
|
||||||
|
if self.use_harmony:
|
||||||
|
harmony_parsers = [
|
||||||
|
get_streamable_parser_for_assistant()
|
||||||
|
for _ in range(num_choices)
|
||||||
|
]
|
||||||
|
|
||||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||||
tool_choice_function_name = request.tool_choice.function.name
|
tool_choice_function_name = request.tool_choice.function.name
|
||||||
@@ -597,7 +636,18 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
delta_text = output.text
|
if self.use_harmony:
|
||||||
|
harmony_parser = harmony_parsers[i]
|
||||||
|
for token_id in output.token_ids:
|
||||||
|
harmony_parser.process(token_id)
|
||||||
|
# FIXME(woosuk): Support function calling
|
||||||
|
is_final = harmony_parser.current_channel == "final"
|
||||||
|
if not (request.include_reasoning or is_final):
|
||||||
|
# Skip the reasoning content.
|
||||||
|
continue
|
||||||
|
delta_text = harmony_parser.last_content_delta or ""
|
||||||
|
else:
|
||||||
|
delta_text = output.text
|
||||||
|
|
||||||
if not delta_text and not output.token_ids and \
|
if not delta_text and not output.token_ids and \
|
||||||
not previous_num_tokens[i]:
|
not previous_num_tokens[i]:
|
||||||
@@ -607,7 +657,8 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
delta_message: Optional[DeltaMessage]
|
delta_message: Optional[DeltaMessage]
|
||||||
|
|
||||||
# just update previous_texts and previous_token_ids
|
# just update previous_texts and previous_token_ids
|
||||||
if tool_choice_auto or self.reasoning_parser:
|
if ((tool_choice_auto or self.reasoning_parser)
|
||||||
|
and not self.use_harmony):
|
||||||
assert previous_texts is not None
|
assert previous_texts is not None
|
||||||
assert all_previous_token_ids is not None
|
assert all_previous_token_ids is not None
|
||||||
previous_text = previous_texts[i]
|
previous_text = previous_texts[i]
|
||||||
@@ -621,8 +672,14 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
else:
|
else:
|
||||||
current_token_ids = list(output.token_ids)
|
current_token_ids = list(output.token_ids)
|
||||||
|
|
||||||
|
if self.use_harmony:
|
||||||
|
if is_final:
|
||||||
|
delta_message = DeltaMessage(content=delta_text)
|
||||||
|
else:
|
||||||
|
delta_message = DeltaMessage(
|
||||||
|
reasoning_content=delta_text)
|
||||||
# handle streaming deltas for tools with named tool_choice
|
# handle streaming deltas for tools with named tool_choice
|
||||||
if tool_choice_function_name:
|
elif tool_choice_function_name:
|
||||||
if (self.reasoning_parser and not reasoning_end_arr[i]
|
if (self.reasoning_parser and not reasoning_end_arr[i]
|
||||||
and not reasoning_parser.is_reasoning_end(
|
and not reasoning_parser.is_reasoning_end(
|
||||||
previous_token_ids)):
|
previous_token_ids)):
|
||||||
@@ -990,7 +1047,38 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
auto_tools_called = False
|
|
||||||
|
if self.use_harmony:
|
||||||
|
reasoning_content, final_content, is_tool_call = (
|
||||||
|
parse_chat_output(token_ids))
|
||||||
|
if not request.include_reasoning:
|
||||||
|
reasoning_content = None
|
||||||
|
|
||||||
|
if is_tool_call:
|
||||||
|
# TODO(woosuk): Implement tool call for gpt-oss.
|
||||||
|
# For now, only Responses API supports tool call for
|
||||||
|
# gpt-oss.
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Tool call in Chat Completion API is not supported "
|
||||||
|
"for gpt-oss yet. Please use Responses API instead.")
|
||||||
|
else:
|
||||||
|
# Normal message
|
||||||
|
message = ChatMessage(
|
||||||
|
role=role,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
content=final_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
index=output.index,
|
||||||
|
message=message,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason="tool_calls" if is_tool_call else
|
||||||
|
output.finish_reason if output.finish_reason else "stop",
|
||||||
|
stop_reason=output.stop_reason,
|
||||||
|
)
|
||||||
|
choices.append(choice_data)
|
||||||
|
continue
|
||||||
|
|
||||||
if self.reasoning_parser:
|
if self.reasoning_parser:
|
||||||
try:
|
try:
|
||||||
@@ -1003,10 +1091,13 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
reasoning_content, content = (
|
reasoning_content, content = (
|
||||||
reasoning_parser.extract_reasoning_content(
|
reasoning_parser.extract_reasoning_content(
|
||||||
output.text, request=request))
|
output.text, request=request))
|
||||||
|
if not request.include_reasoning:
|
||||||
|
reasoning_content = None
|
||||||
else:
|
else:
|
||||||
reasoning_content = None
|
reasoning_content = None
|
||||||
content = output.text
|
content = output.text
|
||||||
|
|
||||||
|
auto_tools_called = False
|
||||||
# if auto tools are not enabled, and a named tool choice using
|
# if auto tools are not enabled, and a named tool choice using
|
||||||
# outlines is not being used
|
# outlines is not being used
|
||||||
if (not self.enable_auto_tools or not self.tool_parser) and \
|
if (not self.enable_auto_tools or not self.tool_parser) and \
|
||||||
@@ -1261,3 +1352,33 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
and delta_message.tool_calls[0].function
|
and delta_message.tool_calls[0].function
|
||||||
and delta_message.tool_calls[0].function.arguments is not None
|
and delta_message.tool_calls[0].function.arguments is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _make_request_with_harmony(
|
||||||
|
self,
|
||||||
|
request: ChatCompletionRequest,
|
||||||
|
):
|
||||||
|
messages: list[OpenAIMessage] = []
|
||||||
|
|
||||||
|
# Add system message.
|
||||||
|
# NOTE: In Chat Completion API, browsing is enabled by default
|
||||||
|
# if the model supports it. TODO: Support browsing.
|
||||||
|
assert not self.supports_browsing
|
||||||
|
assert not self.supports_code_interpreter
|
||||||
|
sys_msg = get_system_message(
|
||||||
|
reasoning_effort=request.reasoning_effort,
|
||||||
|
browser_description=None,
|
||||||
|
python_description=None)
|
||||||
|
messages.append(sys_msg)
|
||||||
|
|
||||||
|
# Add developer message.
|
||||||
|
dev_msg = get_developer_message()
|
||||||
|
messages.append(dev_msg)
|
||||||
|
|
||||||
|
# Add user message.
|
||||||
|
for chat_msg in request.messages:
|
||||||
|
messages.append(parse_chat_input(chat_msg))
|
||||||
|
|
||||||
|
# Render prompt token ids.
|
||||||
|
prompt_token_ids = render_for_completion(messages)
|
||||||
|
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
|
||||||
|
return messages, [prompt_token_ids], [engine_prompt]
|
||||||
|
|||||||
Reference in New Issue
Block a user