[gpt-oss][1a] create_responses stream outputs BaseModel type, api server is SSE still (#24759)
Signed-off-by: Andrew Xia <axia@meta.com>
This commit is contained in:
@@ -15,7 +15,7 @@ import socket
|
||||
import tempfile
|
||||
import uuid
|
||||
from argparse import Namespace
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
@@ -29,6 +29,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from openai import BaseModel
|
||||
from prometheus_client import make_asgi_app
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from starlette.concurrency import iterate_in_threadpool
|
||||
@@ -577,6 +578,18 @@ async def show_version():
|
||||
return JSONResponse(content=ver)
|
||||
|
||||
|
||||
async def _convert_stream_to_sse_events(
    generator: AsyncGenerator[BaseModel, None],
) -> AsyncGenerator[str, None]:
    """Serialize each streamed item into a Server-Sent Events text frame.

    Every item produced by ``generator`` is turned into one string made of
    an ``event:`` line followed by a ``data:`` line and a terminating blank
    line. The event name is the item's ``type`` attribute, falling back to
    ``"unknown"`` when the item has no such attribute. The data payload is
    whatever the item's ``model_dump_json(indent=None)`` returns.

    Args:
        generator: Async generator of items exposing ``model_dump_json``.

    Yields:
        One SSE-formatted string per item, in the order received.
    """
    async for item in generator:
        # Frame layout follows the SSE event stream format:
        # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
        name = getattr(item, "type", "unknown")
        payload = item.model_dump_json(indent=None)
        yield f"event: {name}\ndata: {payload}\n\n"
|
||||
|
||||
|
||||
@router.post("/v1/responses",
|
||||
dependencies=[Depends(validate_json_request)],
|
||||
responses={
|
||||
@@ -612,7 +625,9 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
|
||||
status_code=generator.error.code)
|
||||
elif isinstance(generator, ResponsesResponse):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
return StreamingResponse(content=generator, media_type="text/event-stream")
|
||||
|
||||
return StreamingResponse(content=_convert_stream_to_sse_events(generator),
|
||||
media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/v1/responses/{response_id}")
|
||||
@@ -640,10 +655,10 @@ async def retrieve_responses(
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.error.code)
|
||||
elif stream:
|
||||
return StreamingResponse(content=response,
|
||||
media_type="text/event-stream")
|
||||
return JSONResponse(content=response.model_dump())
|
||||
elif isinstance(response, ResponsesResponse):
|
||||
return JSONResponse(content=response.model_dump())
|
||||
return StreamingResponse(content=_convert_stream_to_sse_events(response),
|
||||
media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post("/v1/responses/{response_id}/cancel")
|
||||
|
||||
Reference in New Issue
Block a user