diff --git a/vllm_middleware.py b/vllm_middleware.py index d4caffc..aedd967 100644 --- a/vllm_middleware.py +++ b/vllm_middleware.py @@ -31,19 +31,19 @@ LISTEN_PORT = int(os.environ.get("MIDDLEWARE_PORT", "8002")) # Extend this set as more incompatibilities are discovered. STRIP_PARAMS = {"logprobs", "top_logprobs"} -app = FastAPI() client: httpx.AsyncClient | None = None _sglang_ready = False -@app.on_event("startup") -async def startup(): +async def _lifespan(app_instance): global client client = httpx.AsyncClient( timeout=httpx.Timeout(300.0, connect=10.0), ) # Background task: wait for SGLang to become available asyncio.create_task(_wait_for_sglang()) + yield + await client.aclose() async def _wait_for_sglang(): @@ -64,14 +64,13 @@ async def _wait_for_sglang(): await asyncio.sleep(2) -@app.on_event("shutdown") -async def shutdown(): - await client.aclose() +app = FastAPI(lifespan=_lifespan) @app.get("/health") async def health(): """Health check — haproxy polls this. Returns 200 only if SGLang is up.""" + global _sglang_ready if not _sglang_ready: return Response(content="SGLang not ready", status_code=503) try: