[Feature][Responses API] Support MCP tool in background mode (#23494)
Signed-off-by: wuhang <wuhang6@huawei.com>
@@ -4,13 +4,15 @@ import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Union
+from contextlib import AsyncExitStack
+from typing import TYPE_CHECKING, Optional, Union
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
 
 from vllm.entrypoints.harmony_utils import (
     get_encoding, get_streamable_parser_for_assistant, render_for_completion)
 from vllm.entrypoints.tool import Tool
+from vllm.entrypoints.tool_server import ToolServer
 from vllm.outputs import RequestOutput
 
 if TYPE_CHECKING:
@@ -37,6 +39,11 @@ class ConversationContext(ABC):
     def render_for_completion(self) -> list[int]:
         pass
 
+    @abstractmethod
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class SimpleContext(ConversationContext):
 
@@ -55,16 +62,21 @@ class SimpleContext(ConversationContext):
     def render_for_completion(self) -> list[int]:
         raise NotImplementedError("Should not be called.")
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        pass
+
 
 class HarmonyContext(ConversationContext):
 
     def __init__(
         self,
         messages: list,
-        tool_sessions: dict[str, Tool],
+        available_tools: list[str],
     ):
         self._messages = messages
-        self.tool_sessions = tool_sessions
+        self.available_tools = available_tools
+        self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
 
         self.parser = get_streamable_parser_for_assistant()
         self.num_init_messages = len(messages)
@@ -116,10 +128,10 @@ class HarmonyContext(ConversationContext):
         if recipient is not None:
             if recipient.startswith("browser."):
                 return await self.call_search_tool(
-                    self.tool_sessions["browser"], last_msg)
+                    self._tool_sessions["browser"], last_msg)
             elif recipient.startswith("python"):
                 return await self.call_python_tool(
-                    self.tool_sessions["python"], last_msg)
+                    self._tool_sessions["python"], last_msg)
         raise ValueError("No tool call found")
 
     def render_for_completion(self) -> list[int]:
@@ -161,6 +173,15 @@ class HarmonyContext(ConversationContext):
                 recipient=Role.ASSISTANT)
         ]
 
+    async def init_tool_sessions(self, tool_server: Optional[ToolServer],
+                                 exit_stack: AsyncExitStack) -> None:
+        if tool_server:
+            for tool_name in self.available_tools:
+                if tool_name not in self._tool_sessions:
+                    self._tool_sessions[
+                        tool_name] = await exit_stack.enter_async_context(
+                            tool_server.new_session(tool_name))
+
 
 class StreamingHarmonyContext(HarmonyContext):
 
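The new `init_tool_sessions` hook moves session creation out of the request handler: a context now carries only the names in `available_tools`, and tool/MCP sessions are opened when generation actually starts, on whatever `AsyncExitStack` the caller provides, so the caller decides how long they live. Below is a minimal, self-contained sketch of that pattern; `DemoToolServer`, `DemoSession` and `DemoContext` are illustrative stand-ins, not vLLM's actual `ToolServer` API.

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Optional


class DemoSession:
    """Illustrative stand-in for an MCP client session."""

    def __init__(self, tool_name: str) -> None:
        self.tool_name = tool_name

    async def call(self, payload: str) -> str:
        return f"{self.tool_name} handled {payload!r}"


class DemoToolServer:
    """Illustrative stand-in for a tool server exposing new_session()."""

    @asynccontextmanager
    async def new_session(self, tool_name: str):
        print(f"opening session for {tool_name}")
        try:
            yield DemoSession(tool_name)
        finally:
            print(f"closing session for {tool_name}")


class DemoContext:
    """Mirrors the init_tool_sessions pattern from the hunk above."""

    def __init__(self, available_tools: list[str]) -> None:
        self.available_tools = available_tools
        self._tool_sessions: dict[str, DemoSession] = {}

    async def init_tool_sessions(self, tool_server: Optional[DemoToolServer],
                                 exit_stack: AsyncExitStack) -> None:
        if tool_server:
            for tool_name in self.available_tools:
                if tool_name not in self._tool_sessions:
                    # Sessions are entered on the caller's stack, so they
                    # stay open until the caller unwinds it.
                    self._tool_sessions[tool_name] = (
                        await exit_stack.enter_async_context(
                            tool_server.new_session(tool_name)))


async def main() -> None:
    ctx = DemoContext(["browser", "python"])
    async with AsyncExitStack() as exit_stack:
        await ctx.init_tool_sessions(DemoToolServer(), exit_stack)
        print(await ctx._tool_sessions["browser"].call("search query"))
    # Both sessions are closed here, when the stack exits.


asyncio.run(main())

Because the stack is passed in rather than created inside the context, a background task or a streaming generator can keep the same sessions open for its entire lifetime and close them all at once when the stack unwinds.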
@@ -8,7 +8,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Sequence
 from contextlib import AsyncExitStack
 from copy import copy
 from http import HTTPStatus
-from typing import Any, Callable, Final, Optional, Union
+from typing import Callable, Final, Optional, Union
 
 import jinja2
 import openai.types.responses as openai_responses_types
@@ -248,10 +248,10 @@ class OpenAIServingResponses(OpenAIServing):
         raw_request.state.request_metadata = request_metadata
 
         if self.tool_server is not None and isinstance(
-                self.tool_server, MCPToolServer
-        ) and (request.background or request.stream) and request.tools and any(
+                self.tool_server,
+                MCPToolServer) and request.stream and request.tools and any(
                     tool.type in ["web_search_preview", "code_interpreter"]
                     for tool in request.tools):
             return self.create_error_response(
                 "MCP tool server is not supported in background mode and "
                 "streaming mode")
@@ -265,103 +265,70 @@ class OpenAIServingResponses(OpenAIServing):
                 builtin_tool_list.append("browser")
             if self.tool_server.has_tool("python"):
                 builtin_tool_list.append("python")
-        async with AsyncExitStack() as exit_stack:
-            try:
-                if self.tool_server is not None:
-                    # TODO: initialize tool sessions lazily when the session
-                    # is actually used.
-                    tool_session_ctxs: dict[str, Any] = {
-                        tool_name:
-                        exit_stack.enter_async_context(
-                            self.tool_server.new_session(tool_name))
-                        for tool_name in builtin_tool_list
-                    }
-                    tool_sessions = {}
-                    for tool_name in builtin_tool_list:
-                        tool_sessions[tool_name] = (
-                            await tool_session_ctxs[tool_name])
-                else:
-                    assert len(builtin_tool_list) == 0
-                    tool_sessions = {}
-                for i, engine_prompt in enumerate(engine_prompts):
-                    default_max_tokens = self.max_model_len - len(
-                        engine_prompt["prompt_token_ids"])
-                    sampling_params = request.to_sampling_params(
-                        default_max_tokens, self.default_sampling_params)
-
-                    trace_headers = (None if raw_request is None else await
-                                     self._get_trace_headers(
-                                         raw_request.headers))
-
-                    context: ConversationContext
-                    if self.use_harmony:
-                        if request.stream:
-                            context = StreamingHarmonyContext(
-                                messages, tool_sessions)
-                        else:
-                            context = HarmonyContext(messages, tool_sessions)
-                    else:
-                        context = SimpleContext()
-                    generator = self._generate_with_builtin_tools(
-                        request_id=request.request_id,
-                        request_prompt=request_prompts[i],
-                        engine_prompt=engine_prompt,
-                        sampling_params=sampling_params,
-                        context=context,
-                        lora_request=lora_request,
-                        priority=request.priority,
-                        trace_headers=trace_headers,
-                    )
-                    generators.append(generator)
-            except ValueError as e:
-                # TODO: Use a vllm-specific Validation Error
-                return self.create_error_response(str(e))
-
-            assert len(generators) == 1
-            result_generator, = generators
-
-            # Store the input messages.
-            if request.store:
-                self.msg_store[request.request_id] = messages
-
-            if request.background:
-                created_time = int(time.time())
-                response = ResponsesResponse.from_request(
-                    request,
-                    sampling_params,
-                    model_name=model_name,
-                    created_time=created_time,
-                    output=[],
-                    status="queued",
-                    usage=None,
-                )
-                async with self.response_store_lock:
-                    self.response_store[response.id] = response
-
-                # Run the request in the background.
-                task = asyncio.create_task(
-                    self._run_background_request(
-                        request,
-                        sampling_params,
-                        result_generator,
-                        context,
-                        model_name,
-                        tokenizer,
-                        request_metadata,
-                        created_time,
-                    ),
-                    name=f"create_{response.id}",
-                )
-
-                # For cleanup.
-                response_id = response.id
-                self.background_tasks[response_id] = task
-                task.add_done_callback(
-                    lambda _: self.background_tasks.pop(response_id, None))
-                return response
-
-            if request.stream:
-                return self.responses_stream_generator(
+
+        if self.tool_server is not None:
+            available_tools = builtin_tool_list
+        else:
+            assert len(builtin_tool_list) == 0
+            available_tools = []
+        try:
+            for i, engine_prompt in enumerate(engine_prompts):
+                default_max_tokens = self.max_model_len - len(
+                    engine_prompt["prompt_token_ids"])
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens, self.default_sampling_params)
+
+                trace_headers = (None if raw_request is None else await
+                                 self._get_trace_headers(raw_request.headers))
+
+                context: ConversationContext
+                if self.use_harmony:
+                    if request.stream:
+                        context = StreamingHarmonyContext(
+                            messages, available_tools)
+                    else:
+                        context = HarmonyContext(messages, available_tools)
+                else:
+                    context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
+                    lora_request=lora_request,
+                    priority=request.priority,
+                    trace_headers=trace_headers,
+                )
+                generators.append(generator)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
+
+        assert len(generators) == 1
+        result_generator, = generators
+
+        # Store the input messages.
+        if request.store:
+            self.msg_store[request.request_id] = messages
+
+        if request.background:
+            created_time = int(time.time())
+            response = ResponsesResponse.from_request(
+                request,
+                sampling_params,
+                model_name=model_name,
+                created_time=created_time,
+                output=[],
+                status="queued",
+                usage=None,
+            )
+            async with self.response_store_lock:
+                self.response_store[response.id] = response
+
+            # Run the request in the background.
+            task = asyncio.create_task(
+                self._run_background_request(
                     request,
                     sampling_params,
                     result_generator,
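The net effect of this hunk is that `create_responses` no longer opens tool sessions itself: it only records the tool names in `available_tools`, builds the context, and, for `request.background`, returns the queued `ResponsesResponse` immediately. The generator that actually drains the request (run inline, or inside the background task started here) opens the sessions on its own `AsyncExitStack`. A rough sketch of that ownership follows; `run_in_background` and `DemoContext` are placeholder names, not part of vLLM.

import asyncio
from contextlib import AsyncExitStack


class DemoContext:
    """Placeholder context; only the init/generate shape matters here."""

    async def init_tool_sessions(self, tool_server,
                                 exit_stack: AsyncExitStack) -> None:
        print("init_tool_sessions called inside the background task")

    async def generate(self) -> None:
        # Stand-in for draining result_generator to completion.
        await asyncio.sleep(0)
        print("generation finished; sessions close when the stack unwinds")


async def run_in_background(context: DemoContext, tool_server) -> None:
    # The task, not the HTTP handler, owns the exit stack, so tool sessions
    # live exactly as long as the generation they serve.
    async with AsyncExitStack() as exit_stack:
        await context.init_tool_sessions(tool_server, exit_stack)
        await context.generate()


async def main() -> None:
    task = asyncio.create_task(
        run_in_background(DemoContext(), tool_server=None))
    print("handler returns a 'queued' response without waiting")
    await task  # kept only so this demo script exits cleanly


asyncio.run(main())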
@@ -369,21 +336,41 @@ class OpenAIServingResponses(OpenAIServing):
                     model_name,
                     tokenizer,
                     request_metadata,
-                )
+                    created_time,
+                ),
+                name=f"create_{response.id}",
+            )
 
-            try:
-                return await self.responses_full_generator(
-                    request,
-                    sampling_params,
-                    result_generator,
-                    context,
-                    model_name,
-                    tokenizer,
-                    request_metadata,
-                )
-            except Exception as e:
-                return self.create_error_response(str(e))
-        return self.create_error_response("Should not reach here")
+            # For cleanup.
+            response_id = response.id
+            self.background_tasks[response_id] = task
+            task.add_done_callback(
+                lambda _: self.background_tasks.pop(response_id, None))
+            return response
+
+        if request.stream:
+            return self.responses_stream_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+
+        try:
+            return await self.responses_full_generator(
+                request,
+                sampling_params,
+                result_generator,
+                context,
+                model_name,
+                tokenizer,
+                request_metadata,
+            )
+        except Exception as e:
+            return self.create_error_response(str(e))
 
     async def _make_request(
         self,
@@ -439,14 +426,16 @@ class OpenAIServingResponses(OpenAIServing):
         if created_time is None:
             created_time = int(time.time())
 
-        try:
-            async for _ in result_generator:
-                pass
-        except asyncio.CancelledError:
-            return self.create_error_response("Client disconnected")
-        except ValueError as e:
-            # TODO: Use a vllm-specific Validation Error
-            return self.create_error_response(str(e))
+        async with AsyncExitStack() as exit_stack:
+            try:
+                await context.init_tool_sessions(self.tool_server, exit_stack)
+                async for _ in result_generator:
+                    pass
+            except asyncio.CancelledError:
+                return self.create_error_response("Client disconnected")
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
 
         if self.use_harmony:
             assert isinstance(context, HarmonyContext)
@@ -838,7 +827,7 @@ class OpenAIServingResponses(OpenAIServing):
             status_code=HTTPStatus.BAD_REQUEST,
         )
 
-    async def responses_stream_generator(
+    async def _process_streaming_events(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
@@ -847,18 +836,8 @@ class OpenAIServingResponses(OpenAIServing):
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
-        created_time: Optional[int] = None,
+        created_time: int,
     ) -> AsyncGenerator[str, None]:
-        # TODO:
-        # 1. Handle disconnect
-
-        if not isinstance(context, StreamingHarmonyContext):
-            raise NotImplementedError(
-                "Streaming is not supported for responses API without Harmony."
-            )
-
-        created_time = created_time or int(time.time())
-
         sequence_number = 0
 
         def _send_event(event: BaseModel):
@@ -1270,3 +1249,31 @@ class OpenAIServingResponses(OpenAIServing):
                 sequence_number=-1,
                 response=final_response.model_dump(),
             ))
+
+    async def responses_stream_generator(
+        self,
+        request: ResponsesRequest,
+        sampling_params: SamplingParams,
+        result_generator: AsyncIterator[Optional[ConversationContext]],
+        context: ConversationContext,
+        model_name: str,
+        tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
+        created_time: Optional[int] = None,
+    ) -> AsyncGenerator[str, None]:
+        # TODO:
+        # 1. Handle disconnect
+
+        if not isinstance(context, StreamingHarmonyContext):
+            raise NotImplementedError(
+                "Streaming is not supported for responses API without Harmony."
+            )
+
+        created_time = created_time or int(time.time())
+
+        async with AsyncExitStack() as exit_stack:
+            await context.init_tool_sessions(self.tool_server, exit_stack)
+            async for event_data in self._process_streaming_events(
+                    request, sampling_params, result_generator, context,
+                    model_name, tokenizer, request_metadata, created_time):
+                yield event_data
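With the session lifetime now owned by the request generators, the built-in MCP-backed tools (`web_search_preview`, `code_interpreter`) can be used with `background=True`. A hedged usage sketch against a vLLM OpenAI-compatible server follows; the base URL, API key, and model name are placeholders and assume a server started with a tool server configured.

import time

from openai import OpenAI

# Placeholder endpoint and model; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.responses.create(
    model="openai/gpt-oss-20b",
    input="Look up the latest vLLM release and summarize the changes.",
    tools=[{"type": "web_search_preview"}],
    background=True,  # returns immediately with status "queued"
)

# Poll the stored response until the background task completes.
while response.status in ("queued", "in_progress"):
    time.sleep(1)
    response = client.responses.retrieve(response.id)

print(response.status)
print(response.output_text)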