[gpt-oss][1a] create_responses stream outputs BaseModel type, api server is SSE still (#24759)

Signed-off-by: Andrew Xia <axia@meta.com>
This commit is contained in:
Andrew Xia
2025-09-15 13:08:08 -07:00
committed by GitHub
parent 25aba2b6a3
commit 73df49ef3a
2 changed files with 90 additions and 71 deletions

View File

@@ -15,7 +15,7 @@ import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncIterator, Awaitable
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
@@ -29,6 +29,7 @@ from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import BaseModel
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
@@ -577,6 +578,18 @@ async def show_version():
return JSONResponse(content=ver)
async def _convert_stream_to_sse_events(
        generator: AsyncGenerator[BaseModel,
                                  None]) -> AsyncGenerator[str, None]:
    """Convert the generator to a stream of events in SSE format"""
    # Wire format reference:
    # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    async for model_event in generator:
        # Events without a `type` attribute are labeled "unknown" rather
        # than dropped, so the client still receives their payload.
        event_name = getattr(model_event, 'type', 'unknown')
        payload = model_event.model_dump_json(indent=None)
        yield (f"event: {event_name}\n"
               f"data: {payload}\n\n")
@router.post("/v1/responses",
dependencies=[Depends(validate_json_request)],
responses={
@@ -612,7 +625,9 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
status_code=generator.error.code)
elif isinstance(generator, ResponsesResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
return StreamingResponse(content=_convert_stream_to_sse_events(generator),
media_type="text/event-stream")
@router.get("/v1/responses/{response_id}")
@@ -640,10 +655,10 @@ async def retrieve_responses(
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.error.code)
elif stream:
return StreamingResponse(content=response,
media_type="text/event-stream")
return JSONResponse(content=response.model_dump())
elif isinstance(response, ResponsesResponse):
return JSONResponse(content=response.model_dump())
return StreamingResponse(content=_convert_stream_to_sse_events(response),
media_type="text/event-stream")
@router.post("/v1/responses/{response_id}/cancel")