#!/usr/bin/env python3
"""
vLLM shim with custom weights download.

Intercepts `python -m vllm.entrypoints.openai.api_server` so that
if --model or the positional model arg (after "serve") points to a URL,
we download + extract it to a local cache dir, then replace it with
the local path before handing off to the real vLLM server.

Supported archive formats (detected from URL extension):
    .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
"""
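
# Illustrative example (the URL and local paths below are hypothetical):
#
#   python -m vllm.entrypoints.openai.api_server serve \
#       https://example.com/models/my-model.tar.gz --port 8000
#
# is rewritten by the shim into roughly:
#
#   python -m vllm.entrypoints.openai.api_server serve \
#       /data/weights/my-model.tar/<dir containing *.safetensors> --port 8000
#
# where /data/weights is the default CACHE_DIR and the final component is
# whatever directory find_model_dir() locates inside the extracted archive.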

import os
import sys
import subprocess
import datetime
import shutil
import time
import urllib.parse
import urllib.request

# Where to cache downloaded+extracted weights.
# Production stack mounts the PVC at /data — use a subdir so it persists across pod restarts.
CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")

# The shim dir that shadows the vllm package — must be stripped from PYTHONPATH
# before exec'ing the real vLLM, otherwise we loop forever.
SHIM_DIR = "/opt/vllm-shim"


def log(msg: str):
    """Write to both stdout and the shim log file."""
    log_path = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
    ts = datetime.datetime.now().isoformat()
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(log_path, "a") as f:
            f.write(line + "\n")
    except Exception:
        pass


def is_url(value: str) -> bool:
    return value.startswith("http://") or value.startswith("https://")


def detect_archive_type(url: str) -> str:
    """
    Detect archive type from URL path extension.
    Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip', or '' (unknown).
    """
    path = urllib.parse.urlparse(url).path
    for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
        if path.endswith(ext):
            return ext.lstrip(".")
    _, ext = os.path.splitext(path)
    mapping = {
        ".tar": "tar",
        ".tgz": "tar.gz",
        ".zip": "zip",
    }
    return mapping.get(ext.lower(), "")
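
# A few illustrative mappings (hypothetical URLs):
#   detect_archive_type("https://example.com/m.tar.gz") -> "tar.gz"
#   detect_archive_type("https://example.com/m.tgz")    -> "tar.gz"
#   detect_archive_type("https://example.com/m.bin")    -> ""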


MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
RETRY_DELAY_SECONDS = 5


def download_file(url: str, dest: str):
    """Download url to dest, retrying with a linearly increasing delay."""
    for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
        try:
            log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
            urllib.request.urlretrieve(url, dest)
            log(f"Download complete: {dest}")
            return
        except Exception as e:
            log(f"Download attempt {attempt} failed: {e}")
            if os.path.exists(dest):
                os.remove(dest)
            if attempt < MAX_DOWNLOAD_RETRIES:
                wait = RETRY_DELAY_SECONDS * attempt
                log(f"Retrying in {wait}s...")
                time.sleep(wait)
            else:
                log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
                raise
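
# With the defaults (MAX_DOWNLOAD_RETRIES=5, RETRY_DELAY_SECONDS=5), failed
# attempts are retried after waits of 5s, 10s, 15s, and 20s before giving up.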


def extract_archive(archive_path: str, dest_dir: str, archive_type: str):
    """Extract archive to dest_dir based on archive_type."""
    log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
    if archive_type in ("tar.gz", "tgz"):
        shutil.unpack_archive(archive_path, dest_dir, "gztar")
    elif archive_type == "tar.bz2":
        shutil.unpack_archive(archive_path, dest_dir, "bztar")
    elif archive_type == "tar.xz":
        subprocess.run(
            ["tar", "-xJf", archive_path, "-C", dest_dir],
            check=True,
        )
    elif archive_type == "tar":
        shutil.unpack_archive(archive_path, dest_dir, "tar")
    elif archive_type == "zip":
        shutil.unpack_archive(archive_path, dest_dir, "zip")
    else:
        raise ValueError(f"Unsupported archive type: {archive_type}")
    log(f"Extraction complete: {dest_dir}")


def find_model_dir(extract_dir: str) -> str:
    """
    After extraction, find the directory containing the actual model weights.
    Walks the tree looking for .safetensors files and returns the directory
    that contains one. This handles archives with extra parent dirs,
    nested structures, or flat extractions.
    """
    for root, dirs, files in os.walk(extract_dir):
        if any(f.endswith(".safetensors") for f in files):
            return root
    log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
    entries = [e for e in os.listdir(extract_dir)
               if not e.startswith(".") and e != "__MACOSX"]
    if len(entries) == 1 and os.path.isdir(os.path.join(extract_dir, entries[0])):
        return os.path.join(extract_dir, entries[0])
    return extract_dir
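
# Illustrative layouts (hypothetical file names):
#   my-model/config.json + my-model/model.safetensors  -> returns <extract_dir>/my-model
#   config.json + model-00001-of-00002.safetensors     -> returns <extract_dir> itself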


def download_and_extract_model(url: str) -> str:
    """
    Download a model from URL, extract it, and return the local path.
    Uses a cache keyed by URL filename to avoid re-downloading.
    """
    url_filename = os.path.basename(urllib.parse.urlparse(url).path)
    cache_key = os.path.splitext(url_filename)[0]
    local_dir = os.path.join(CACHE_DIR, cache_key)

    if os.path.isdir(local_dir) and os.listdir(local_dir):
        model_path = find_model_dir(local_dir)
        log(f"Using cached weights: {model_path}")
        return model_path

    os.makedirs(local_dir, exist_ok=True)
    archive_type = detect_archive_type(url)
    if not archive_type:
        raise ValueError(
            f"Cannot determine archive type from URL: {url}\n"
            f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
        )

    tmp_archive = os.path.join(CACHE_DIR, url_filename + ".tmp")
    try:
        download_file(url, tmp_archive)
        extract_archive(tmp_archive, local_dir, archive_type)
    finally:
        if os.path.exists(tmp_archive):
            os.remove(tmp_archive)

    return find_model_dir(local_dir)


def parse_args(args):
    """
    Parse argv, intercepting --model and positional model args.
    Production stack invokes: python -m vllm.entrypoints.openai.api_server serve <model-url> ...
    The model can appear as:
      - --model <url>
      - --model=<url>
      - A positional arg after the "serve" subcommand
    If the value is a URL, download + extract it and replace it with the local path.
    Returns the modified argv list.
    """
    result = []
    i = 0
    model_replaced = False
    saw_serve = False

    while i < len(args):
        arg = args[i]

        # --model=<value>
        if arg.startswith("--model="):
            value = arg.split("=", 1)[1]
            if is_url(value):
                local_path = download_and_extract_model(value)
                result.append(f"--model={local_path}")
                model_replaced = True
            else:
                result.append(arg)
            i += 1
            continue

        # --model <value>
        if arg == "--model":
            result.append(arg)
            i += 1
            if i < len(args):
                value = args[i]
                if is_url(value):
                    local_path = download_and_extract_model(value)
                    result.append(local_path)
                    model_replaced = True
                else:
                    result.append(value)
                i += 1
            continue

        # "serve" subcommand — next positional is the model
        if arg == "serve":
            result.append(arg)
            saw_serve = True
            i += 1
            # The next non-flag argument is the model
            if i < len(args) and not args[i].startswith("-") and is_url(args[i]):
                local_path = download_and_extract_model(args[i])
                result.append(local_path)
                model_replaced = True
                i += 1
            continue

        # Positional model arg when there's no "serve" subcommand
        # (first non-flag arg if no serve seen)
        if not arg.startswith("-") and not saw_serve and not model_replaced:
            if is_url(arg):
                local_path = download_and_extract_model(arg)
                result.append(local_path)
                model_replaced = True
                i += 1
                continue

        result.append(arg)
        i += 1

    if model_replaced:
        log("Model URL was replaced with local path")

    return result
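
# Illustrative transformation (hypothetical URL; the exact local path depends
# on the cache key and on the directory find_model_dir() locates):
#   ["serve", "https://example.com/weights.zip", "--port", "8000"]
#     -> ["serve", "/data/weights/weights/<model dir>", "--port", "8000"]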


def strip_shim_from_pythonpath():
    """
    Remove the shim directory from PYTHONPATH so that when we exec the
    real vLLM, Python doesn't find our shadow package again (infinite loop).
    """
    pp = os.environ.get("PYTHONPATH", "")
    parts = [p for p in pp.split(":") if p != SHIM_DIR]
    new_pp = ":".join(parts)
    if new_pp != pp:
        os.environ["PYTHONPATH"] = new_pp
        log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {pp!r}, now: {new_pp!r})")
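
# Example: with PYTHONPATH="/opt/vllm-shim:/usr/lib/extra" (hypothetical),
# the variable becomes "/usr/lib/extra" after stripping.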


def invoked_module_path() -> str:
    """
    Derive the dotted module path from this file's location in the shadow package.

    When invoked via `python -m vllm.entrypoints.openai.api_server`, __name__ is
    "__main__" — useless for re-invocation. Instead, figure out the module path
    from the file path relative to the shim root (/opt/vllm-shim).
    """
    # e.g. /opt/vllm-shim/vllm/entrypoints/openai/api_server.py
    filepath = os.path.abspath(__file__)
    # Strip the shim root, drop the trailing .py, and convert separators to dots
    rel = os.path.relpath(filepath, SHIM_DIR)
    if rel.endswith(".py"):
        rel = rel[:-3]
    return rel.replace(os.sep, ".")
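
# Example: /opt/vllm-shim/vllm/entrypoints/openai/api_server.py
#   -> "vllm.entrypoints.openai.api_server"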


def main():
    args = sys.argv[1:]

    # Determine which vllm module was actually invoked so we exec the real one
    # (could be vllm.entrypoints.cli.main, vllm.entrypoints.openai.api_server, etc.)
    invoked_module = invoked_module_path()

    log("=" * 50)
    log("vLLM Custom Weights Shim")
    log(f"  Invoked as: python -m {invoked_module} {' '.join(args)}")
    log("=" * 50)

    # Intercept --model / positional model if it's a URL
    modified_args = parse_args(args)

    # Strip our shim from PYTHONPATH so the real vLLM resolves correctly
    strip_shim_from_pythonpath()

    # Build the real vLLM command using the same module that was invoked
    vllm_cmd = [sys.executable, "-m", invoked_module] + modified_args

    log(f"Launching vLLM: {' '.join(vllm_cmd)}")

    # Exec into vLLM — replace this process so signals flow through cleanly
    os.execvp(vllm_cmd[0], vllm_cmd)


if __name__ == "__main__":
    main()

# Also run if imported as a module (some invocation paths just import the file).
# main() execs into vLLM and does not return, so it cannot run twice when this
# file is executed as a script.
main()