#!/usr/bin/env python3
"""
vLLM -> SGLang Python shim.

Catches `python -m vllm.entrypoints.openai.api_server` (and similar)
and launches SGLang behind haproxy + middleware instead.

Dynamically translates vLLM CLI args to SGLang equivalents.
No hardcoded model name or tensor-parallel size.

Architecture:
    haproxy on the vLLM port (front door)
        /metrics -> 200 empty (stub)
        /health  -> 200 if SGLang backend is up, 503 if not (instant)
        /*       -> proxy to middleware on port+2
    middleware on port+2 (strips vLLM-only params, fixes tool schemas)
    SGLang on port+1 (internal)
"""
|
|
import os
import sys
import subprocess
import time
import datetime


# ── vLLM → SGLang argument mapping ──────────────────────────
# Key = vLLM flag, value = (sglang_flag, has_value)
# has_value=True  → the flag takes an argument (e.g. --port 8000)
# has_value=False → boolean flag (e.g. --trust-remote-code)
# sglang_flag=None → vLLM-only flag with no SGLang equivalent; it is
# dropped (together with its value, if any) and reported in `skipped`.
ARG_MAP = {
    # Direct renames (vLLM name → SGLang name).  Both the underscore and
    # kebab-case spellings vLLM accepts are listed.
    "--tensor-parallel-size": ("--tp", True),
    "--tensor_parallel_size": ("--tp", True),
    "--gpu_memory_utilization": ("--mem-fraction-static", True),
    "--gpu-memory-utilization": ("--mem-fraction-static", True),
    # vLLM's max model len is the context window; SGLang calls that
    # --context-length.  (NOT --max-running-requests, which limits the
    # number of concurrent requests — a completely different knob.)
    "--max_model_len": ("--context-length", True),
    "--max-model-len": ("--context-length", True),
    # --enforce-eager has no SGLang equivalent (its *opposite* would be
    # --enable-torch-compile); skip it rather than invert the intent.
    "--enforce_eager": (None, False),
    "--enforce-eager": (None, False),
    "--trust_remote_code": ("--trust-remote-code", False),
    "--trust-remote-code": ("--trust-remote-code", False),

    # vLLM flags with no SGLang equivalent → skip
    "--no-enable-prefix-caching": (None, False),
    "--enable-prefix-caching": (None, False),
    "--enable-chunked-prefill": (None, False),
    "--no-enable-chunked-prefill": (None, False),
    "--disable-log-requests": (None, False),
    "--disable-log-stats": (None, False),
    "--swap-space": (None, True),
    "--block-size": (None, True),
    "--num-gpu-blocks-override": (None, True),
    "--num-cpu-blocks-override": (None, True),
    "--max-num-seqs": (None, True),
    "--max-num-batched-tokens": (None, True),
    "--distributed-executor-backend": (None, True),
    "--pipeline-parallel-size": (None, True),
    "--data-parallel-size": (None, True),
    "--revision": (None, True),
    "--code-revision": (None, True),
    "--tokenizer-revision": (None, True),
    "--tokenizer-mode": (None, True),
    "--quantization": (None, True),
    "--dtype": (None, True),
    "--max-seq-len-to-capture": (None, True),
    "--enable-lora": (None, False),
    "--max-lora-rank": (None, True),
    "--max-cpu-loras": (None, True),
    "--lora-dtype": (None, True),
    "--enable-prompt-adapter": (None, False),
    "--scheduler-delay-factor": (None, True),
    "--enable-multi-modal": (None, False),
    "--limit-mm-per-prompt": (None, True),
}

# Default tool-call-parser; override with SGLANG_TOOL_CALL_PARSER env var
DEFAULT_TOOL_CALL_PARSER = "qwen3_coder"


def parse_vllm_args(args):
    """Translate a vLLM CLI argument list into SGLang terms.

    Args:
        args: the raw argv tail, e.g. ["serve", "org/model", "--port", "8000"].

    Returns:
        Tuple (model, host, port, sglang_extra, skipped):
        model: positional model name/path, or None if absent;
        host/port: strings for the front-door bind (defaults "0.0.0.0"/"8000");
        sglang_extra: translated/forwarded flags destined for SGLang;
        skipped: vLLM-only flags (and their values) that were dropped.

    Unknown "--flag" args are forwarded untouched — they may be valid for
    SGLang too.  The cursor advances on every iteration, so a trailing
    flag with a missing value (or a stray positional/single-dash token)
    can no longer hang the loop, and a single-dash token is never
    mistaken for the model.
    """
    model = None
    host = "0.0.0.0"
    port = "8000"
    sglang_extra = []   # translated args for SGLang
    skipped = []        # vLLM args we're ignoring

    n = len(args)
    i = 0
    while i < n:
        arg = args[i]

        # 'serve' subcommand — skip
        if arg == "serve":
            i += 1
            continue

        # Positional argument: the first one is the model; extras ignored.
        if not arg.startswith("-"):
            if model is None:
                model = arg
            i += 1
            continue

        # --flag=value form
        if arg.startswith("--") and "=" in arg:
            flag, val = arg.split("=", 1)
            if flag == "--host":
                host = val
            elif flag == "--port":
                port = val
            elif flag in ARG_MAP:
                sglang_flag, has_val = ARG_MAP[flag]
                if sglang_flag is None:
                    skipped.append(arg)
                elif has_val:
                    sglang_extra.extend([sglang_flag, val])
                else:
                    # boolean flag with =value (unusual but valid)
                    sglang_extra.append(sglang_flag)
            else:
                # Unknown flag — pass through as-is (might be a SGLang flag too)
                sglang_extra.append(arg)
            i += 1
            continue

        # True when a separate non-flag value token follows this flag.
        value_follows = i + 1 < n and not args[i + 1].startswith("-")

        # --flag value form for the front-door address
        if arg in ("--host", "--port"):
            if i + 1 < n:
                if arg == "--host":
                    host = args[i + 1]
                else:
                    port = args[i + 1]
                i += 2
            else:
                i += 1  # dangling flag at end of argv; nothing to consume
            continue

        if arg in ARG_MAP:
            sglang_flag, has_val = ARG_MAP[arg]
            if sglang_flag is None:
                # Skip the flag, and its value if it takes one.
                skipped.append(arg)
                if has_val and value_follows:
                    skipped.append(args[i + 1])
                    i += 2
                else:
                    i += 1
            elif has_val:
                if i + 1 < n:
                    sglang_extra.extend([sglang_flag, args[i + 1]])
                    i += 2
                else:
                    i += 1  # dangling flag with no value: drop it
            else:
                sglang_extra.append(sglang_flag)
                i += 1
            continue

        # --tool-call-parser — pass through to SGLang
        if arg == "--tool-call-parser":
            if i + 1 < n:
                sglang_extra.extend(["--tool-call-parser", args[i + 1]])
                i += 2
            else:
                i += 1
            continue

        # Unknown flag — pass through (with its value, if one follows);
        # single-dash oddballs are dropped.  The cursor always advances.
        if arg.startswith("--") and value_follows:
            sglang_extra.extend([arg, args[i + 1]])
            i += 2
        elif arg.startswith("--"):
            sglang_extra.append(arg)
            i += 1
        else:
            i += 1

    return model, host, port, sglang_extra, skipped
|
|
|
|
|
|
def main():
    """Entry point: translate vLLM args and run the SGLang stack.

    Orchestration (see module docstring for the port layout):
      1. log and echo the raw argv,
      2. parse/translate the vLLM args,
      3. write the haproxy error files and config,
      4. spawn SGLang, the middleware, and haproxy,
      5. babysit all three and hard-exit when any one dies.
    """
    args = sys.argv[1:]

    # Append-only shim log; path overridable via VLLM_SHIM_LOG.
    log_path = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
    with open(log_path, "a") as f:
        f.write(f"\n{datetime.datetime.now().isoformat()} vLLM -> SGLang Shim (Python module)\n")
        f.write(f" Invoked as: python -m {__name__} {' '.join(args)}\n")
        f.write(" All arguments received:\n")
        for i, arg in enumerate(args, 1):
            f.write(f" [{i}] {arg}\n")
        f.write("\n")

    # Mirror the same banner to stdout for interactive runs / container logs.
    print()
    print("==========================================")
    print(" vLLM -> SGLang Shim (Python module)")
    print("==========================================")
    print(f" Invoked as: python -m {__name__} {' '.join(args)}")
    print()
    print(" All arguments received:")
    for i, arg in enumerate(args, 1):
        print(f" [{i}] {arg}")
    print("==========================================")
    print()

    model, host, port, sglang_extra, skipped = parse_vllm_args(args)

    if not model:
        print("ERROR: No model specified in vLLM args!")
        # os._exit skips atexit handlers and stream flushing — presumably
        # chosen to die hard inside a container; TODO confirm.
        os._exit(1)

    # SGLang port scheme: original+1 = SGLang, original+2 = middleware
    sglang_port = str(int(port) + 1)
    middleware_port = str(int(port) + 2)

    # Build SGLang command
    sglang_cmd = [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", model,
        "--host", host,
        "--port", sglang_port,
    ]

    # Add tool-call-parser (env override or default)
    tcp = os.environ.get("SGLANG_TOOL_CALL_PARSER", DEFAULT_TOOL_CALL_PARSER)
    if tcp:
        sglang_cmd.extend(["--tool-call-parser", tcp])

    # Add translated/forwarded args
    sglang_cmd.extend(sglang_extra)

    print(f"Model: {model}")
    print(f"SGLang host: {host}:{sglang_port}")
    print(f"Middleware: {host}:{middleware_port}")
    print(f"haproxy: {host}:{port}")
    if sglang_extra:
        print(f"Translated args: {' '.join(sglang_extra)}")
    if skipped:
        print(f"Skipped (no SGLang equivalent): {' '.join(skipped)}")
    print()
    print(f"SGLang command: {' '.join(sglang_cmd)}")
    print()

    # ── haproxy setup ────────────────────────────────────────
    # Canned raw-HTTP responses served via haproxy's errorfile mechanism.
    os.makedirs("/tmp/haproxy-errors", exist_ok=True)
    with open("/tmp/haproxy-errors/200-empty.http", "w") as f:
        f.write("HTTP/1.0 200 OK\r\nContent-Length: 0\r\nConnection: close\r\n\r\n")
    with open("/tmp/haproxy-errors/503-sglang.http", "w") as f:
        # Content-Length: 16 == len("SGLang not ready"); keep in sync.
        f.write("HTTP/1.0 503 Service Unavailable\r\nContent-Length: 16\r\nConnection: close\r\nContent-Type: text/plain\r\n\r\nSGLang not ready")

    # haproxy is the front door on the original vLLM port: /metrics is
    # stubbed, /health reflects backend state via nbsrv(), everything
    # else proxies to the middleware.  (haproxy ignores indentation.)
    haproxy_cfg = "/tmp/haproxy-shim.cfg"
    with open(haproxy_cfg, "w") as f:
        f.write(f"""global
    maxconn 4096

defaults
    mode http
    timeout connect 5s
    timeout client 300s
    timeout server 300s

frontend proxy
    bind {host}:{port}

    # /metrics stub — instant 200 empty (vLLM stack expects this)
    acl is_metrics path /metrics
    http-request deny deny_status 200 if is_metrics
    errorfile 200 /tmp/haproxy-errors/200-empty.http

    # /health — instant response based on SGLang backend state
    acl is_health path /health
    acl sglang_up nbsrv(sglang) gt 0
    http-request deny deny_status 200 if is_health sglang_up
    http-request deny deny_status 503 if is_health
    errorfile 503 /tmp/haproxy-errors/503-sglang.http

    default_backend sglang

backend sglang
    option httpchk GET /health
    http-check expect status 200
    server s1 127.0.0.1:{middleware_port} check inter 5s fall 3 rise 2
""")

    with open(log_path, "a") as f:
        f.write(f"haproxy config written to {haproxy_cfg}\n")
        f.write(f"Model: {model}, SGLang port: {sglang_port}, middleware port: {middleware_port}, haproxy port: {port}\n")
        f.write(f"SGLang command: {' '.join(sglang_cmd)}\n")
        if skipped:
            f.write(f"Skipped vLLM args: {' '.join(skipped)}\n")

    # ── Launch processes ─────────────────────────────────────
    sglang_proc = subprocess.Popen(sglang_cmd)

    # The middleware discovers its upstream and its own port via env vars.
    middleware_env = os.environ.copy()
    middleware_env["SGLANG_HOST"] = host
    middleware_env["SGLANG_PORT"] = sglang_port
    middleware_env["MIDDLEWARE_PORT"] = middleware_port
    middleware_proc = subprocess.Popen(
        [sys.executable, "/opt/vllm-shim/vllm_middleware.py"],
        env=middleware_env,
    )

    # Brief pause so the backends can start binding before haproxy
    # begins health-checking them.
    time.sleep(2)

    haproxy_proc = subprocess.Popen(["haproxy", "-f", haproxy_cfg])

    with open(log_path, "a") as f:
        f.write(f"SGLang PID: {sglang_proc.pid}, middleware PID: {middleware_proc.pid}, haproxy PID: {haproxy_proc.pid}\n")

    # Wait for whichever dies first; terminate the survivors and exit
    # with the dead process's return code (poll() is non-blocking).
    while True:
        sglang_ret = sglang_proc.poll()
        middleware_ret = middleware_proc.poll()
        haproxy_ret = haproxy_proc.poll()
        if sglang_ret is not None:
            print(f"SGLang exited (code {sglang_ret}), shutting down")
            middleware_proc.terminate()
            haproxy_proc.terminate()
            os._exit(sglang_ret)
        if middleware_ret is not None:
            print(f"Middleware exited (code {middleware_ret}), shutting down")
            sglang_proc.terminate()
            haproxy_proc.terminate()
            os._exit(middleware_ret)
        if haproxy_ret is not None:
            print(f"haproxy exited (code {haproxy_ret}), shutting down")
            sglang_proc.terminate()
            middleware_proc.terminate()
            os._exit(haproxy_ret)
        time.sleep(1)
|
|
|
|
if __name__ == "__main__":
    main()

# Also run if imported as a module (some invocation paths just import the file)
# NOTE(review): this makes *importing* the module launch the full server stack
# as a side effect.  When run as a script the guarded call above never returns
# (main() blocks in its watch loop and exits via os._exit), so this second
# call is only reached on import.  Confirm import-time execution is really
# required by some invocation path before removing.
main()
|