"""vLLM → SGLang request middleware.

Sits between haproxy and SGLang and rewrites vLLM-style request bodies that
would otherwise make SGLang return 422/400 errors.

For POST requests whose path contains ``chat/completions`` it:
  * strips every parameter listed in ``STRIP_PARAMS`` (currently
    ``logprobs``, ``top_logprobs``, ``chat_template_kwargs``, ``guided_json``
    and ``guided_regex``) — vLLM accepts these but SGLang rejects them
    (e.g. SGLang's Mistral tool-call parser rejects ``logprobs`` /
    ``top_logprobs``);
  * repairs malformed tool ``parameters`` JSON Schemas (see ``_fix_schema``).

Any JSON-object POST body with a truthy ``"stream"`` field is proxied as a
stream; all other requests are forwarded and answered in one piece. When
SGLang answers with HTTP >= 400, the request and response are appended to
``ERROR_LOG`` (env ``VLLM_SHIM_LOG``).

Architecture:
    haproxy (port N) → middleware (port N+2) → SGLang (port N+1)

haproxy still handles the /metrics stub and /health instant responses.
This middleware only touches the proxied request bodies.
"""
import asyncio
import json
import os
from contextlib import asynccontextmanager
from datetime import datetime

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse

SGLANG_HOST = os.environ.get("SGLANG_HOST", "127.0.0.1")
SGLANG_PORT = int(os.environ.get("SGLANG_PORT", "8001"))
LISTEN_PORT = int(os.environ.get("MIDDLEWARE_PORT", "8002"))

# Params that vLLM accepts but SGLang rejects.
# Extend this set as more incompatibilities are discovered.
STRIP_PARAMS = {"logprobs", "top_logprobs", "chat_template_kwargs", "guided_json", "guided_regex"}

# File that _dump_error() appends failed request/response payloads to.
ERROR_LOG = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")

# Request headers that are not copied to the upstream request: the standard
# hop-by-hop headers, plus host/content-length which are recomputed for SGLang.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "content-length",
    "host",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Shared upstream client; created and closed by _lifespan().
client: httpx.AsyncClient | None = None
# True once SGLang's /health answered 200; cleared by health() when SGLang is unreachable.
_sglang_ready = False
# Strong references to running background tasks. asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


def _start_readiness_poller() -> None:
    """Start a background ``_wait_for_sglang()`` task unless one is already running."""
    if any(not task.done() for task in _background_tasks):
        return
    task = asyncio.create_task(_wait_for_sglang())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


@asynccontextmanager
async def _lifespan(app_instance):
    """Create the shared httpx client on startup; stop pollers and close it on shutdown."""
    global client
    client = httpx.AsyncClient(
        timeout=httpx.Timeout(300.0, connect=10.0),
    )
    # Background task: wait for SGLang to become available.
    _start_readiness_poller()
    try:
        yield
    finally:
        # Cancel pollers first so none of them uses the client after it is closed.
        pending = list(_background_tasks)
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        await client.aclose()


async def _wait_for_sglang():
    """Poll SGLang's /health every 2 s until it answers 200, then mark ready.

    Any transport-level failure (refused connection, timeout, protocol error
    while SGLang is still starting) is treated as "not ready yet" so the
    poller never dies early.
    """
    global _sglang_ready
    health_url = f"http://{SGLANG_HOST}:{SGLANG_PORT}/health"
    while True:
        try:
            resp = await client.get(
                health_url,
                timeout=httpx.Timeout(5.0, connect=2.0),
            )
        except httpx.TransportError:
            pass
        else:
            if resp.status_code == 200:
                _sglang_ready = True
                print(f"Middleware: SGLang is ready at {SGLANG_HOST}:{SGLANG_PORT}")
                return
        await asyncio.sleep(2)


app = FastAPI(lifespan=_lifespan)


@app.get("/health")
async def health():
    """Health check — haproxy polls this. Returns 200 only if SGLang is up.

    While SGLang is not known to be ready, answers 503 without contacting it.
    Otherwise relays SGLang's own /health response (body, status, content
    type). If SGLang cannot be reached, readiness is cleared and a single
    background poller is started to detect when it comes back.
    """
    global _sglang_ready
    if not _sglang_ready:
        return Response(content="SGLang not ready", status_code=503)
    try:
        resp = await client.get(
            f"http://{SGLANG_HOST}:{SGLANG_PORT}/health",
            timeout=httpx.Timeout(5.0, connect=2.0),
        )
    except httpx.TransportError:
        _sglang_ready = False
        # Re-trigger background wait (at most one poller runs at a time).
        _start_readiness_poller()
        return Response(content="SGLang not ready", status_code=503)
    return Response(content=resp.content, status_code=resp.status_code,
                    media_type=resp.headers.get("content-type"))


def _fix_schema(schema: dict) -> bool:
    """Recursively repair a JSON Schema dict in place.

    At every nesting level:
      * a ``properties`` value that is not an object (e.g. ``[]`` or ``null``)
        is replaced with ``{}``;
      * a ``required`` value that is not a list is removed, and non-string
        entries inside a ``required`` list are dropped.

    Recurses into property schemas, ``items`` (single schema or list of
    schemas), ``anyOf``/``allOf``/``oneOf``, ``additionalProperties`` (when it
    is a schema) and ``$defs``/``definitions``.

    Returns:
        True if anything was changed anywhere in the tree.
    """
    fixed = False

    # 'properties' must be a dict, not an array/null.
    if "properties" in schema and not isinstance(schema["properties"], dict):
        schema["properties"] = {}
        fixed = True

    # 'required' must be a list of strings, or absent.
    if "required" in schema:
        required = schema["required"]
        if not isinstance(required, list):
            del schema["required"]
            fixed = True
        elif not all(isinstance(name, str) for name in required):
            schema["required"] = [name for name in required if isinstance(name, str)]
            fixed = True

    # Collect every nested sub-schema, then repair each one.
    children = []
    properties = schema.get("properties")
    if isinstance(properties, dict):
        children.extend(properties.values())
    items = schema.get("items")
    if isinstance(items, dict):
        children.append(items)
    elif isinstance(items, list):
        children.extend(items)
    for key in ("anyOf", "allOf", "oneOf"):
        if isinstance(schema.get(key), list):
            children.extend(schema[key])
    if isinstance(schema.get("additionalProperties"), dict):
        children.append(schema["additionalProperties"])
    for key in ("$defs", "definitions"):
        if isinstance(schema.get(key), dict):
            children.extend(schema[key].values())

    for child in children:
        # Call _fix_schema on every dict child; don't short-circuit on `fixed`.
        if isinstance(child, dict) and _fix_schema(child):
            fixed = True
    return fixed


def _sanitize_chat_request(data: dict) -> bool:
    """Strip STRIP_PARAMS and repair tool parameter schemas in a chat request.

    Mutates ``data`` in place. A tool whose ``function.parameters`` is not an
    object gets an empty object schema.

    Returns:
        True if ``data`` was modified and must be re-serialized.
    """
    modified = False
    for key in STRIP_PARAMS:
        if key in data:
            del data[key]
            modified = True

    tools = data.get("tools")
    if isinstance(tools, list):
        for tool in tools:
            func = tool.get("function") if isinstance(tool, dict) else None
            if not isinstance(func, dict):
                continue
            if not isinstance(func.get("parameters"), dict):
                func["parameters"] = {"type": "object", "properties": {}}
                modified = True
            if _fix_schema(func["parameters"]):
                modified = True
    return modified


def _dump_error(request_body: bytes, status_code: int, resp_headers: dict, resp_body_raw: bytes, path: str = ""):
    """Log full request + response payload when SGLang returns an error (4xx/5xx).

    Appends a record to ERROR_LOG with a timestamp, the path, the request body
    (pretty-printed when it is JSON, truncated to 8000 chars), the response
    headers, and the response body (pretty-printed when it is JSON, truncated
    to 4000 chars). Best-effort: any failure is printed and swallowed so
    logging can never break the proxied request.
    """
    try:
        ts = datetime.now().isoformat()

        try:
            req_json = json.loads(request_body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            req_json = None
        if req_json is not None:
            req_text = json.dumps(req_json, indent=2, ensure_ascii=False)[:8000]
        else:
            req_text = request_body.decode("utf-8", errors="replace")[:8000]

        resp_text = resp_body_raw.decode("utf-8", errors="replace")[:4000]
        try:
            resp_json = json.loads(resp_text)
        except json.JSONDecodeError:
            resp_json = None
        if resp_json is not None:
            resp_out = json.dumps(resp_json, indent=2, ensure_ascii=False)[:4000]
        else:
            resp_out = resp_text

        header_text = "".join(f"{k}: {v}\n" for k, v in (resp_headers or {}).items())
        sep = "=" * 60

        # Explicit utf-8: the ensure_ascii=False output may contain characters
        # that the locale's default encoding cannot represent.
        with open(ERROR_LOG, "a", encoding="utf-8") as f:
            f.write(f"\n{sep}\n")
            f.write(f"[{ts}] ERROR DUMP — SGLang returned HTTP {status_code}\n")
            f.write(f"Path: {path}\n")
            f.write("--- Request Body ---\n")
            f.write(req_text)
            f.write("\n--- Response Headers ---\n")
            f.write(header_text)
            f.write(f"--- Response (HTTP {status_code}) ---\n")
            f.write(resp_out)
            f.write(f"\n{sep}\n")
        print(f"[{ts}] ERROR DUMP: HTTP {status_code} on {path} — full payload written to {ERROR_LOG}")
    except Exception as e:
        print(f"_dump_error failed: {e}")


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def proxy(path: str, request: Request):
    """Forward any request to SGLang, sanitizing chat-completion bodies.

    * POST bodies that parse as a JSON object are checked for a truthy
      ``"stream"`` field; such requests are relayed as a byte stream.
    * For paths containing ``chat/completions`` the body is additionally
      passed through ``_sanitize_chat_request`` and re-serialized if changed.
    * Non-JSON bodies, and JSON that is not an object, are forwarded untouched.
    * Upstream responses with status >= 400 are logged via ``_dump_error``.
    * Transport failures talking to SGLang are answered with a 503 JSON error.
    """
    body = await request.body()
    is_streaming = False

    if request.method == "POST" and body:
        try:
            data = json.loads(body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            data = None  # Not JSON: forward untouched.
        # Only JSON objects carry request params; arrays/scalars pass through as-is.
        if isinstance(data, dict):
            is_streaming = bool(data.get("stream", False))
            if "chat/completions" in path and _sanitize_chat_request(data):
                body = json.dumps(data).encode()

    # Forward headers, skipping hop-by-hop ones and those we recompute.
    fwd_headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in _HOP_BY_HOP_HEADERS
    }
    fwd_headers["content-length"] = str(len(body))

    url = f"http://{SGLANG_HOST}:{SGLANG_PORT}/{path}"
    # Raw query string, so SGLang sees exactly what the client sent.
    query = request.url.query
    if query:
        url += f"?{query}"

    try:
        if is_streaming:
            req = client.build_request(request.method, url, content=body, headers=fwd_headers)
            resp = await client.send(req, stream=True)

            # Errors are buffered, logged and returned whole instead of streamed.
            if resp.status_code >= 400:
                try:
                    error_body = await resp.aread()
                finally:
                    await resp.aclose()
                _dump_error(body, resp.status_code, resp_headers=dict(resp.headers),
                            resp_body_raw=error_body, path=path)
                return Response(
                    content=error_body,
                    status_code=resp.status_code,
                    media_type=resp.headers.get("content-type"),
                )

            async def stream_body():
                # Close the upstream response even if the client disconnects mid-stream.
                try:
                    async for chunk in resp.aiter_bytes():
                        yield chunk
                finally:
                    await resp.aclose()

            return StreamingResponse(
                stream_body(),
                status_code=resp.status_code,
                headers={"content-type": resp.headers.get("content-type", "text/event-stream")},
            )

        resp = await client.request(request.method, url, content=body, headers=fwd_headers)
        if resp.status_code >= 400:
            _dump_error(body, resp.status_code, resp_headers=dict(resp.headers),
                        resp_body_raw=resp.content, path=path)
        return Response(
            content=resp.content,
            status_code=resp.status_code,
            media_type=resp.headers.get("content-type"),
        )
    except httpx.TransportError as e:
        return Response(
            content=json.dumps({"error": {"message": f"SGLang backend unavailable: {e}", "type": "backend_error"}}),
            status_code=503,
            media_type="application/json",
        )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=LISTEN_PORT, log_level="warning")