diff --git a/vllm_middleware.py b/vllm_middleware.py
index 088dbe9..2c1a66b 100644
--- a/vllm_middleware.py
+++ b/vllm_middleware.py
@@ -8,7 +8,7 @@ Currently strips: logprobs, top_logprobs
 (SGLang's Mistral tool-call parser rejects these; vLLM accepts them.)
 
 Architecture:
-  haproxy (original port) → middleware (port+2) → SGLang (port+1)
+  haproxy (port N) → middleware (port N+2) → SGLang (port N+1)
 
 haproxy still handles /metrics stub and /health instant responses.
 This middleware only touches the proxied request bodies.
@@ -16,11 +16,13 @@ This middleware only touches the proxied request bodies.
 
 import json
 import os
+import asyncio
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse, Response
 import uvicorn
 
+SGLANG_HOST = os.environ.get("SGLANG_HOST", "127.0.0.1")
 SGLANG_PORT = int(os.environ.get("SGLANG_PORT", "8001"))
 LISTEN_PORT = int(os.environ.get("MIDDLEWARE_PORT", "8002"))
 
@@ -30,15 +32,35 @@ STRIP_PARAMS = {"logprobs", "top_logprobs"}
 
 app = FastAPI()
 client: httpx.AsyncClient | None = None
+_sglang_ready = False
 
 
 @app.on_event("startup")
 async def startup():
     global client
     client = httpx.AsyncClient(
-        base_url=f"http://127.0.0.1:{SGLANG_PORT}",
-        timeout=httpx.Timeout(300.0),
+        timeout=httpx.Timeout(300.0, connect=10.0),
     )
+    # Background task: wait for SGLang to become available
+    asyncio.create_task(_wait_for_sglang())
+
+
+async def _wait_for_sglang():
+    """Poll SGLang until it's accepting connections, then mark ready."""
+    global _sglang_ready
+    while True:
+        try:
+            resp = await client.get(
+                f"http://{SGLANG_HOST}:{SGLANG_PORT}/health",
+                timeout=httpx.Timeout(5.0, connect=2.0),
+            )
+            if resp.status_code == 200:
+                _sglang_ready = True
+                print(f"Middleware: SGLang is ready at {SGLANG_HOST}:{SGLANG_PORT}")
+                return
+        except (httpx.ConnectError, httpx.TimeoutException):
+            pass
+        await asyncio.sleep(2)
 
 
 @app.on_event("shutdown")
@@ -46,6 +68,26 @@ async def shutdown():
     await client.aclose()
 
 
+@app.get("/health")
+async def health():
+    """Health check — haproxy polls this. Returns 200 only if SGLang is up."""
+    global _sglang_ready  # rebound in the except branch; without this the read below raises UnboundLocalError
+    if not _sglang_ready:
+        return Response(content="SGLang not ready", status_code=503)
+    try:
+        resp = await client.get(
+            f"http://{SGLANG_HOST}:{SGLANG_PORT}/health",
+            timeout=httpx.Timeout(5.0, connect=2.0),
+        )
+        return Response(content=resp.content, status_code=resp.status_code,
+                        media_type=resp.headers.get("content-type"))
+    except (httpx.ConnectError, httpx.TimeoutException):
+        _sglang_ready = False
+        # Re-trigger background wait
+        asyncio.create_task(_wait_for_sglang())
+        return Response(content="SGLang not ready", status_code=503)
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
 async def proxy(path: str, request: Request):
     body = await request.body()
@@ -73,32 +115,39 @@ async def proxy(path: str, request: Request):
     }
     fwd_headers["content-length"] = str(len(body))
 
-    url = f"http://127.0.0.1:{SGLANG_PORT}/{path}"
+    url = f"http://{SGLANG_HOST}:{SGLANG_PORT}/{path}"
     if request.query_params:
         url += f"?{request.query_params}"
 
-    if is_streaming:
-        req = client.build_request(request.method, url, content=body, headers=fwd_headers)
-        resp = await client.send(req, stream=True)
+    try:
+        if is_streaming:
+            req = client.build_request(request.method, url, content=body, headers=fwd_headers)
+            resp = await client.send(req, stream=True)
 
-        async def stream_body():
-            try:
-                async for chunk in resp.aiter_bytes():
-                    yield chunk
-            finally:
-                await resp.aclose()
+            async def stream_body():
+                try:
+                    async for chunk in resp.aiter_bytes():
+                        yield chunk
+                finally:
+                    await resp.aclose()
 
-        return StreamingResponse(
-            stream_body(),
-            status_code=resp.status_code,
-            headers={"content-type": resp.headers.get("content-type", "text/event-stream")},
-        )
-    else:
-        resp = await client.request(request.method, url, content=body, headers=fwd_headers)
+            return StreamingResponse(
+                stream_body(),
+                status_code=resp.status_code,
+                headers={"content-type": resp.headers.get("content-type", "text/event-stream")},
+            )
+        else:
+            resp = await client.request(request.method, url, content=body, headers=fwd_headers)
+            return Response(
+                content=resp.content,
+                status_code=resp.status_code,
+                media_type=resp.headers.get("content-type"),
+            )
+    except (httpx.ConnectError, httpx.TimeoutException) as e:
         return Response(
-            content=resp.content,
-            status_code=resp.status_code,
-            media_type=resp.headers.get("content-type"),
+            content=json.dumps({"error": {"message": f"SGLang backend unavailable: {e}", "type": "backend_error"}}),
+            status_code=503,
+            media_type="application/json",
         )
 
 