[gpt-oss] Support chat completion api (#22342)
This commit is contained in:
@@ -12,6 +12,7 @@ import jinja2
|
||||
import partial_json_parser
|
||||
import regex as re
|
||||
from fastapi import Request
|
||||
from openai_harmony import Message as OpenAIMessage
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
@@ -19,6 +20,10 @@ from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
|
||||
ConversationMessage,
|
||||
random_tool_call_id)
|
||||
from vllm.entrypoints.harmony_utils import (
|
||||
get_developer_message, get_stop_tokens_for_assistant_actions,
|
||||
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
|
||||
parse_chat_output, render_for_completion)
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionLogProb, ChatCompletionLogProbs,
|
||||
@@ -35,6 +40,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall)
|
||||
from vllm.entrypoints.utils import get_max_tokens
|
||||
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
@@ -125,6 +131,23 @@ class OpenAIServingChat(OpenAIServing):
|
||||
logger.info("Using default chat sampling params from %s: %s",
|
||||
source, self.default_sampling_params)
|
||||
|
||||
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
|
||||
if self.use_harmony:
|
||||
if "stop_token_ids" not in self.default_sampling_params:
|
||||
self.default_sampling_params["stop_token_ids"] = []
|
||||
self.default_sampling_params["stop_token_ids"].extend(
|
||||
get_stop_tokens_for_assistant_actions())
|
||||
|
||||
# NOTE(woosuk): While OpenAI's chat completion API supports browsing
|
||||
# for some models, currently vLLM doesn't support it. Please use the
|
||||
# Responses API instead.
|
||||
self.supports_browsing = False
|
||||
self.browser_tool = None
|
||||
# NOTE(woosuk): Chat completion API does not support code interpreter.
|
||||
# Please use the Responses API instead.
|
||||
self.supports_code_interpreter = False
|
||||
self.python_tool = None
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
@@ -169,7 +192,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
if (request.tool_choice == "auto" and
|
||||
not (self.enable_auto_tools and tool_parser is not None)
|
||||
and not isinstance(tokenizer, MistralTokenizer)):
|
||||
and not isinstance(tokenizer, MistralTokenizer)
|
||||
and not self.use_harmony):
|
||||
# for hf tokenizers, "auto" tools requires
|
||||
# --enable-auto-tool-choice and --tool-call-parser
|
||||
return self.create_error_response(
|
||||
@@ -184,25 +208,35 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
tool_dicts = [tool.model_dump() for tool in request.tools]
|
||||
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tool_dicts=tool_dicts,
|
||||
documents=request.documents,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
tool_parser=tool_parser,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
if not self.use_harmony:
|
||||
# Common case.
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = await self._preprocess_chat(
|
||||
request,
|
||||
tokenizer,
|
||||
request.messages,
|
||||
chat_template=request.chat_template or self.chat_template,
|
||||
chat_template_content_format=self.
|
||||
chat_template_content_format,
|
||||
add_generation_prompt=request.add_generation_prompt,
|
||||
continue_final_message=request.continue_final_message,
|
||||
tool_dicts=tool_dicts,
|
||||
documents=request.documents,
|
||||
chat_template_kwargs=request.chat_template_kwargs,
|
||||
tool_parser=tool_parser,
|
||||
truncate_prompt_tokens=request.truncate_prompt_tokens,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
# For GPT-OSS.
|
||||
(
|
||||
conversation,
|
||||
request_prompts,
|
||||
engine_prompts,
|
||||
) = self._make_request_with_harmony(request)
|
||||
except (ValueError, TypeError, RuntimeError,
|
||||
jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
@@ -436,6 +470,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
finish_reason_sent = [False] * num_choices
|
||||
num_prompt_tokens = 0
|
||||
num_cached_tokens = None
|
||||
if self.use_harmony:
|
||||
harmony_parsers = [
|
||||
get_streamable_parser_for_assistant()
|
||||
for _ in range(num_choices)
|
||||
]
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@@ -597,7 +636,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
delta_text = output.text
|
||||
if self.use_harmony:
|
||||
harmony_parser = harmony_parsers[i]
|
||||
for token_id in output.token_ids:
|
||||
harmony_parser.process(token_id)
|
||||
# FIXME(woosuk): Support function calling
|
||||
is_final = harmony_parser.current_channel == "final"
|
||||
if not (request.include_reasoning or is_final):
|
||||
# Skip the reasoning content.
|
||||
continue
|
||||
delta_text = harmony_parser.last_content_delta or ""
|
||||
else:
|
||||
delta_text = output.text
|
||||
|
||||
if not delta_text and not output.token_ids and \
|
||||
not previous_num_tokens[i]:
|
||||
@@ -607,7 +657,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_message: Optional[DeltaMessage]
|
||||
|
||||
# just update previous_texts and previous_token_ids
|
||||
if tool_choice_auto or self.reasoning_parser:
|
||||
if ((tool_choice_auto or self.reasoning_parser)
|
||||
and not self.use_harmony):
|
||||
assert previous_texts is not None
|
||||
assert all_previous_token_ids is not None
|
||||
previous_text = previous_texts[i]
|
||||
@@ -621,8 +672,14 @@ class OpenAIServingChat(OpenAIServing):
|
||||
else:
|
||||
current_token_ids = list(output.token_ids)
|
||||
|
||||
if self.use_harmony:
|
||||
if is_final:
|
||||
delta_message = DeltaMessage(content=delta_text)
|
||||
else:
|
||||
delta_message = DeltaMessage(
|
||||
reasoning_content=delta_text)
|
||||
# handle streaming deltas for tools with named tool_choice
|
||||
if tool_choice_function_name:
|
||||
elif tool_choice_function_name:
|
||||
if (self.reasoning_parser and not reasoning_end_arr[i]
|
||||
and not reasoning_parser.is_reasoning_end(
|
||||
previous_token_ids)):
|
||||
@@ -990,7 +1047,38 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
auto_tools_called = False
|
||||
|
||||
if self.use_harmony:
|
||||
reasoning_content, final_content, is_tool_call = (
|
||||
parse_chat_output(token_ids))
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
|
||||
if is_tool_call:
|
||||
# TODO(woosuk): Implement tool call for gpt-oss.
|
||||
# For now, only Responses API supports tool call for
|
||||
# gpt-oss.
|
||||
raise NotImplementedError(
|
||||
"Tool call in Chat Completion API is not supported "
|
||||
"for gpt-oss yet. Please use Responses API instead.")
|
||||
else:
|
||||
# Normal message
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
reasoning_content=reasoning_content,
|
||||
content=final_content,
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=message,
|
||||
logprobs=logprobs,
|
||||
finish_reason="tool_calls" if is_tool_call else
|
||||
output.finish_reason if output.finish_reason else "stop",
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
continue
|
||||
|
||||
if self.reasoning_parser:
|
||||
try:
|
||||
@@ -1003,10 +1091,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
reasoning_content, content = (
|
||||
reasoning_parser.extract_reasoning_content(
|
||||
output.text, request=request))
|
||||
if not request.include_reasoning:
|
||||
reasoning_content = None
|
||||
else:
|
||||
reasoning_content = None
|
||||
content = output.text
|
||||
|
||||
auto_tools_called = False
|
||||
# if auto tools are not enabled, and a named tool choice using
|
||||
# outlines is not being used
|
||||
if (not self.enable_auto_tools or not self.tool_parser) and \
|
||||
@@ -1261,3 +1352,33 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and delta_message.tool_calls[0].function
|
||||
and delta_message.tool_calls[0].function.arguments is not None
|
||||
)
|
||||
|
||||
def _make_request_with_harmony(
    self,
    request: ChatCompletionRequest,
):
    """Build the Harmony prompt for a chat completion request.

    The conversation is assembled as: one system message (carrying the
    request's ``reasoning_effort``), one developer message, then every
    entry of ``request.messages`` converted with ``parse_chat_input``.
    The assembled conversation is rendered to prompt token ids with
    ``render_for_completion``.

    Args:
        request: The incoming chat completion request.

    Returns:
        A 3-tuple ``(messages, request_prompts, engine_prompts)``:
        the list of Harmony messages, a one-element list holding the
        rendered prompt token ids, and a one-element list holding the
        ``EngineTokensPrompt`` built from those token ids.
    """
    # NOTE: In Chat Completion API, browsing is enabled by default
    # if the model supports it. TODO: Support browsing.
    # Neither built-in tool is wired up on this path, so the system
    # message below is created without any tool descriptions.
    assert not self.supports_browsing
    assert not self.supports_code_interpreter

    system_message = get_system_message(
        reasoning_effort=request.reasoning_effort,
        browser_description=None,
        python_description=None)
    developer_message = get_developer_message()

    # System and developer messages lead the conversation; the messages
    # supplied in the request follow in their original order.
    messages: list[OpenAIMessage] = [
        system_message,
        developer_message,
        *(parse_chat_input(chat_msg) for chat_msg in request.messages),
    ]

    # Render the whole conversation to prompt token ids.
    token_ids = render_for_completion(messages)
    engine_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
    return messages, [token_ids], [engine_prompt]
|
||||
|
||||
Reference in New Issue
Block a user