261 lines
9.5 KiB
Python
261 lines
9.5 KiB
Python
"""
|
|
vLLM → SGLang request middleware.
|
|
|
|
Sits between haproxy and SGLang to strip vLLM-only parameters
|
|
that cause SGLang to return 422/400 errors.
|
|
|
|
Currently strips: logprobs, top_logprobs, chat_template_kwargs, guided_json, guided_regex
|
|
(SGLang's Mistral tool-call parser rejects these; vLLM accepts them.)
|
|
|
|
Architecture:
|
|
haproxy (port N) → middleware (port N+2) → SGLang (port N+1)
|
|
|
|
haproxy still handles /metrics stub and /health instant responses.
|
|
This middleware only touches the proxied request bodies.
|
|
"""
|
|
|
|
import asyncio
import json
import os
from contextlib import asynccontextmanager
from datetime import datetime

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, Response
|
|
|
|
# Where the SGLang backend lives and which port this middleware listens on.
# Each value can be overridden through the environment.
SGLANG_HOST = os.getenv("SGLANG_HOST", "127.0.0.1")
SGLANG_PORT = int(os.getenv("SGLANG_PORT", "8001"))
LISTEN_PORT = int(os.getenv("MIDDLEWARE_PORT", "8002"))

# Top-level request fields removed from chat-completion bodies before they
# are forwarded (accepted by vLLM, rejected by SGLang).
# Add further names here as new incompatibilities turn up.
STRIP_PARAMS = {
    "logprobs",
    "top_logprobs",
    "chat_template_kwargs",
    "guided_json",
    "guided_regex",
}
|
|
|
|
# Shared HTTP client for every call to SGLang; assigned in _lifespan at
# startup and closed there on shutdown. None until the app has started.
client: httpx.AsyncClient | None = None
|
|
# Set to True by _wait_for_sglang once SGLang's /health returns 200;
# reset to False by health() when its own probe of SGLang fails.
_sglang_ready = False
|
|
|
|
|
|
@asynccontextmanager
async def _lifespan(app_instance):
    """Application lifespan: own the shared HTTP client and the readiness poller.

    On startup this creates the module-level ``client`` and starts
    ``_wait_for_sglang`` in the background. On shutdown it cancels that
    poller (if still running) and closes the client.

    Args:
        app_instance: The FastAPI application (unused; required by the
            lifespan protocol).
    """
    global client
    client = httpx.AsyncClient(
        timeout=httpx.Timeout(300.0, connect=10.0),
    )
    # Hold a strong reference: the event loop keeps only weak references to
    # tasks, so an unreferenced poller could be garbage-collected mid-run.
    poller = asyncio.create_task(_wait_for_sglang())
    try:
        yield
    finally:
        # Stop the poller before closing the client it uses. gather() with
        # return_exceptions=True absorbs the resulting CancelledError (or any
        # error the poller died with) so the client is always closed.
        poller.cancel()
        await asyncio.gather(poller, return_exceptions=True)
        await client.aclose()
|
|
|
|
|
|
async def _wait_for_sglang():
    """Poll SGLang's /health every 2 seconds until it answers 200, then mark it ready.

    Sets the module-level ``_sglang_ready`` flag to True and returns on the
    first 200 response. Any transport-level failure (connection refused,
    timeout, dropped connection, protocol error) is treated as "not up yet"
    and retried, so a transient error cannot kill the poller and leave the
    middleware permanently reporting "not ready".
    """
    global _sglang_ready
    health_url = f"http://{SGLANG_HOST}:{SGLANG_PORT}/health"
    # Short per-probe timeout, independent of the client's long request timeout.
    probe_timeout = httpx.Timeout(5.0, connect=2.0)
    while True:
        try:
            resp = await client.get(health_url, timeout=probe_timeout)
        except httpx.TransportError:
            # Backend not reachable yet; fall through to the sleep and retry.
            pass
        else:
            if resp.status_code == 200:
                _sglang_ready = True
                print(f"Middleware: SGLang is ready at {SGLANG_HOST}:{SGLANG_PORT}")
                return
        await asyncio.sleep(2)
|
|
|
|
|
|
# The ASGI application. _lifespan creates/closes the shared HTTP client and
# starts the initial SGLang readiness poll.
app = FastAPI(lifespan=_lifespan)
|
|
|
|
|
|
# Strong references to readiness pollers started by health(). The event loop
# only keeps weak references to tasks, so without this a poller could be
# garbage-collected before it finishes.
_health_poll_tasks: set = set()


@app.get("/health")
async def health():
    """Health check — haproxy polls this. Returns 200 only if SGLang is up.

    Returns:
        503 with "SGLang not ready" while the readiness flag is unset or when
        the probe of SGLang fails at the transport level; otherwise SGLang's
        own /health response (status, body and content type) is relayed.
    """
    global _sglang_ready
    if not _sglang_ready:
        return Response(content="SGLang not ready", status_code=503)

    try:
        resp = await client.get(
            f"http://{SGLANG_HOST}:{SGLANG_PORT}/health",
            timeout=httpx.Timeout(5.0, connect=2.0),
        )
    except httpx.TransportError:
        # Covers refused connections and timeouts as well as dropped or
        # malformed responses, so none of them surface as a 500.
        if _sglang_ready:
            # Only the first failing probe flips the flag and restarts the
            # background wait; concurrent failing probes must not each spawn
            # their own poller.
            _sglang_ready = False
            poller = asyncio.create_task(_wait_for_sglang())
            _health_poll_tasks.add(poller)
            poller.add_done_callback(_health_poll_tasks.discard)
        return Response(content="SGLang not ready", status_code=503)

    return Response(
        content=resp.content,
        status_code=resp.status_code,
        media_type=resp.headers.get("content-type"),
    )
|
|
|
|
|
|
# File that _dump_error appends request/response dumps to; override with the
# VLLM_SHIM_LOG environment variable.
ERROR_LOG = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
|
|
|
|
|
|
def _fix_schema(schema: dict) -> bool:
|
|
"""Recursively fix a JSON Schema dict: properties must be object, required must be list of strings."""
|
|
fixed = False
|
|
# Fix 'properties' — must be dict, not array/null
|
|
if "properties" in schema and not isinstance(schema["properties"], dict):
|
|
schema["properties"] = {}
|
|
fixed = True
|
|
# Fix 'required' — must be list of strings or absent
|
|
if "required" in schema and not isinstance(schema["required"], list):
|
|
del schema["required"]
|
|
fixed = True
|
|
# Recurse into every property value
|
|
if isinstance(schema.get("properties"), dict):
|
|
for val in schema["properties"].values():
|
|
if isinstance(val, dict):
|
|
if _fix_schema(val):
|
|
fixed = True
|
|
# Recurse into items (for array-of-objects)
|
|
if isinstance(schema.get("items"), dict):
|
|
if _fix_schema(schema["items"]):
|
|
fixed = True
|
|
# Recurse into anyOf, allOf, oneOf
|
|
for key in ("anyOf", "allOf", "oneOf"):
|
|
if isinstance(schema.get(key), list):
|
|
for item in schema[key]:
|
|
if isinstance(item, dict):
|
|
if _fix_schema(item):
|
|
fixed = True
|
|
# Recurse into additionalProperties if it's a schema
|
|
if isinstance(schema.get("additionalProperties"), dict):
|
|
if _fix_schema(schema["additionalProperties"]):
|
|
fixed = True
|
|
return fixed
|
|
|
|
|
|
def _dump_error(request_body: bytes, status_code: int, resp_headers: dict, resp_body_raw: bytes, path: str = ""):
|
|
"""Log full request + response payload when SGLang returns an error (4xx/5xx)."""
|
|
try:
|
|
ts = datetime.now().isoformat()
|
|
req_json = None
|
|
try:
|
|
req_json = json.loads(request_body)
|
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
|
pass
|
|
|
|
resp_text = resp_body_raw.decode("utf-8", errors="replace")[:4000]
|
|
resp_json = None
|
|
try:
|
|
resp_json = json.loads(resp_text)
|
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
|
pass
|
|
|
|
with open(ERROR_LOG, "a") as f:
|
|
f.write(f"\n{'='*60}\n")
|
|
f.write(f"[{ts}] ERROR DUMP — SGLang returned HTTP {status_code}\n")
|
|
f.write(f"Path: {path}\n")
|
|
f.write(f"--- Request Body ---\n")
|
|
if req_json:
|
|
f.write(json.dumps(req_json, indent=2, ensure_ascii=False)[:8000])
|
|
else:
|
|
f.write(request_body.decode("utf-8", errors="replace")[:8000])
|
|
f.write(f"\n--- Response (HTTP {status_code}) ---\n")
|
|
if resp_json:
|
|
f.write(json.dumps(resp_json, indent=2, ensure_ascii=False)[:4000])
|
|
else:
|
|
f.write(resp_text)
|
|
f.write(f"\n{'='*60}\n")
|
|
|
|
print(f"[{ts}] ERROR DUMP: HTTP {status_code} on {path} — full payload written to {ERROR_LOG}")
|
|
except Exception as e:
|
|
print(f"_dump_error failed: {e}")
|
|
|
|
|
|
# Request headers that are not copied to the backend request: the ones this
# proxy recomputes (host, content-length, transfer-encoding) plus the
# hop-by-hop headers, which apply only to the client<->proxy connection.
_UNFORWARDED_HEADERS = frozenset({
    "host",
    "content-length",
    "transfer-encoding",
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "upgrade",
})


def _sanitize_chat_request(body: bytes) -> tuple[bytes, bool]:
    """Strip SGLang-incompatible fields from a chat-completion body.

    Removes every top-level key listed in STRIP_PARAMS, gives each tool
    function a dict ``parameters`` schema and repairs that schema with
    _fix_schema. The body is re-serialised only if something changed.

    Args:
        body: Raw request body.

    Returns:
        ``(body, is_streaming)``. ``body`` is returned untouched when it is
        not valid JSON or is not a JSON object; ``is_streaming`` reflects the
        request's ``stream`` field and is False in those cases.
    """
    try:
        data = json.loads(body)
    except (json.JSONDecodeError, UnicodeDecodeError):
        return body, False
    if not isinstance(data, dict):
        # e.g. a JSON array or string: nothing to strip; let SGLang reject it.
        return body, False

    is_streaming = bool(data.get("stream", False))
    modified = False

    for key in STRIP_PARAMS:
        if key in data:
            del data[key]
            modified = True

    tools = data.get("tools")
    if isinstance(tools, list):
        for tool in tools:
            func = tool.get("function") if isinstance(tool, dict) else None
            if not isinstance(func, dict):
                continue
            if not isinstance(func.get("parameters"), dict):
                # Missing or malformed parameters: substitute an empty object schema.
                func["parameters"] = {"type": "object", "properties": {}}
                modified = True
            if _fix_schema(func["parameters"]):
                modified = True

    if modified:
        body = json.dumps(data).encode()
    return body, is_streaming


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])
async def proxy(path: str, request: Request):
    """Forward any request to SGLang, sanitising chat-completion bodies first.

    POST requests whose path contains ``chat/completions`` are passed through
    _sanitize_chat_request. Responses are relayed as-is; when the request
    asked for ``stream`` the backend response is streamed chunk by chunk.
    Backend responses with status >= 400 are recorded via _dump_error.

    Args:
        path: Request path (without the leading slash) to forward.
        request: The incoming request.

    Returns:
        The backend's response, or a 503 JSON error if the backend cannot be
        reached or the connection fails before a response is available.
    """
    body = await request.body()
    is_streaming = False

    if request.method == "POST" and "chat/completions" in path and body:
        body, is_streaming = _sanitize_chat_request(body)

    fwd_headers = {
        k: v for k, v in request.headers.items()
        if k.lower() not in _UNFORWARDED_HEADERS
    }
    # The body may have been rewritten, so the length is always recomputed.
    fwd_headers["content-length"] = str(len(body))

    url = f"http://{SGLANG_HOST}:{SGLANG_PORT}/{path}"
    if request.query_params:
        url += f"?{request.query_params}"

    try:
        if not is_streaming:
            resp = await client.request(request.method, url, content=body, headers=fwd_headers)
            if resp.status_code >= 400:
                _dump_error(body, resp.status_code, resp_headers=dict(resp.headers), resp_body_raw=resp.content, path=path)
            return Response(
                content=resp.content,
                status_code=resp.status_code,
                media_type=resp.headers.get("content-type"),
            )

        req = client.build_request(request.method, url, content=body, headers=fwd_headers)
        resp = await client.send(req, stream=True)

        if resp.status_code >= 400:
            # Errors are small: read them fully so they can be logged and
            # returned as a normal response. Close the stream even if the
            # read itself fails.
            try:
                error_body = await resp.aread()
            finally:
                await resp.aclose()
            _dump_error(body, resp.status_code, resp_headers=dict(resp.headers), resp_body_raw=error_body, path=path)
            return Response(
                content=error_body,
                status_code=resp.status_code,
                media_type=resp.headers.get("content-type"),
            )

        async def stream_body():
            # Relay chunks as they arrive; always release the backend connection.
            try:
                async for chunk in resp.aiter_bytes():
                    yield chunk
            finally:
                await resp.aclose()

        return StreamingResponse(
            stream_body(),
            status_code=resp.status_code,
            headers={"content-type": resp.headers.get("content-type", "text/event-stream")},
        )
    except httpx.TransportError as e:
        # Connection refused, timeout, dropped connection, protocol error:
        # report the backend as unavailable instead of raising a 500.
        return Response(
            content=json.dumps({"error": {"message": f"SGLang backend unavailable: {e}", "type": "backend_error"}}),
            status_code=503,
            media_type="application/json",
        )
|
|
|
|
|
|
# Script entry point: serve the middleware with uvicorn on all interfaces,
# on the port taken from LISTEN_PORT, logging warnings and above.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=LISTEN_PORT,
        log_level="warning",
    )
|