#!/usr/bin/env python3
"""
vLLM shim with custom weights download.

Intercepts `python -m vllm.entrypoints.openai.api_server` so that if --model
or the positional model arg (after "serve") points to a URL, we download +
extract it to a local cache dir, then replace it with the local path before
handing off to the real vLLM server.

Supported archive formats (detected from URL extension):
    .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
"""

import datetime
import os
import shutil
import subprocess
import sys
import time
import urllib.parse
import urllib.request

# Where to cache downloaded+extracted weights.
# Production stack mounts the PVC at /data — use a subdir so it persists
# across pod restarts.
CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")

# The shim dir that shadows the vllm package — must be stripped from
# PYTHONPATH before exec'ing the real vLLM, otherwise we loop forever.
SHIM_DIR = "/opt/vllm-shim"

MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
RETRY_DELAY_SECONDS = 5  # base delay; actual wait is delay * attempt (linear backoff)


def log(msg: str) -> None:
    """Write to both stdout and the shim log file (best-effort append)."""
    log_path = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
    ts = datetime.datetime.now().isoformat()
    line = f"[{ts}] {msg}"
    print(line, flush=True)
    try:
        with open(log_path, "a") as f:
            f.write(line + "\n")
    except Exception:
        # The log file is purely informational — never let it break the shim.
        pass


def is_url(value: str) -> bool:
    """Return True if *value* looks like an http(s) URL."""
    return value.startswith(("http://", "https://"))


def detect_archive_type(url: str) -> str:
    """
    Detect archive type from URL path extension.

    Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip',
    or '' when the extension is not recognized.
    """
    path = urllib.parse.urlparse(url).path
    # Compound extensions first — os.path.splitext would only see '.gz'/'.xz'.
    for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
        if path.endswith(ext):
            return ext.lstrip(".")
    _, ext = os.path.splitext(path)
    mapping = {
        ".tar": "tar",
        ".tgz": "tar.gz",
        ".zip": "zip",
    }
    return mapping.get(ext.lower(), "")


# Last decile reported by _download_progress; reset before each attempt so a
# retry prints a fresh progress trail.
_last_reported_pct = -1


def download_file(url: str, dest: str) -> None:
    """Download *url* to *dest* with retries and a progress indicator.

    Retries up to MAX_DOWNLOAD_RETRIES times with linear backoff; a partially
    written *dest* is removed after every failed attempt. Re-raises the last
    error once all attempts are exhausted.
    """
    global _last_reported_pct
    for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
        _last_reported_pct = -1
        try:
            log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
            urllib.request.urlretrieve(url, dest, reporthook=_download_progress)
            log(f"Download complete: {dest}")
            return
        except Exception as e:
            log(f"Download attempt {attempt} failed: {e}")
            if os.path.exists(dest):
                os.remove(dest)
            if attempt < MAX_DOWNLOAD_RETRIES:
                wait = RETRY_DELAY_SECONDS * attempt
                log(f"Retrying in {wait}s...")
                time.sleep(wait)
            else:
                log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
                raise


def _download_progress(block_num: int, block_size: int, total_size: int) -> None:
    """urlretrieve reporthook: print each 10% milestone exactly once."""
    global _last_reported_pct
    if total_size <= 0:
        # Server sent no Content-Length — no percentage possible.
        return
    downloaded = block_num * block_size
    pct = min(downloaded * 100 // total_size, 100)
    # Only report multiples of 10, and only once per decile (the hook fires
    # for every block, so pct can sit on the same multiple for many calls).
    if pct > 0 and pct % 10 == 0 and pct != _last_reported_pct:
        _last_reported_pct = pct
        mb_down = downloaded / (1024 * 1024)
        mb_total = total_size / (1024 * 1024)
        sys.stdout.write(f"\r  {pct}% ({mb_down:.0f}/{mb_total:.0f} MB)")
        sys.stdout.flush()


def extract_archive(archive_path: str, dest_dir: str, archive_type: str) -> None:
    """Extract *archive_path* into *dest_dir* based on *archive_type*.

    *archive_type* is a value produced by detect_archive_type(). Raises
    ValueError for an unknown type; extraction errors propagate from shutil.

    NOTE(review): shutil.unpack_archive does not sanitize member paths on
    older Pythons — only feed it archives from trusted sources.
    """
    log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
    formats = {
        "tar": "tar",
        "tar.gz": "gztar",
        "tgz": "gztar",
        "tar.bz2": "bztar",
        # stdlib 'xztar' replaces the previous `tar -xJf` subprocess, so the
        # container no longer needs a tar binary with xz support.
        "tar.xz": "xztar",
        "zip": "zip",
    }
    try:
        fmt = formats[archive_type]
    except KeyError:
        raise ValueError(f"Unsupported archive type: {archive_type}") from None
    shutil.unpack_archive(archive_path, dest_dir, fmt)
    log(f"Extraction complete: {dest_dir}")


def find_model_dir(extract_dir: str) -> str:
    """
    After extraction, find the directory containing the actual model weights.

    Walks the tree and returns the first directory that contains a
    .safetensors file — this handles archives with extra parent dirs, nested
    structures, or flat extractions. Falls back to a single-subdirectory
    heuristic, then to *extract_dir* itself.
    """
    for root, _dirs, files in os.walk(extract_dir):
        if any(f.endswith(".safetensors") for f in files):
            return root
    log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
    entries = [
        e for e in os.listdir(extract_dir)
        if not e.startswith(".") and e != "__MACOSX"
    ]
    if len(entries) == 1 and os.path.isdir(os.path.join(extract_dir, entries[0])):
        return os.path.join(extract_dir, entries[0])
    return extract_dir


def download_and_extract_model(url: str) -> str:
    """
    Download a model from URL, extract it, and return the local path.

    Uses a cache keyed by URL filename to avoid re-downloading; a non-empty
    cache dir counts as a hit, so a failed extraction must never leave one
    behind (see the cleanup below).
    """
    url_filename = os.path.basename(urllib.parse.urlparse(url).path)
    cache_key = os.path.splitext(url_filename)[0]
    local_dir = os.path.join(CACHE_DIR, cache_key)

    if os.path.isdir(local_dir) and os.listdir(local_dir):
        model_path = find_model_dir(local_dir)
        log(f"Using cached weights: {model_path}")
        return model_path

    # Validate before creating local_dir so an unsupported URL doesn't leave
    # a stray empty cache directory.
    archive_type = detect_archive_type(url)
    if not archive_type:
        raise ValueError(
            f"Cannot determine archive type from URL: {url}\n"
            f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
        )

    os.makedirs(local_dir, exist_ok=True)
    tmp_archive = os.path.join(CACHE_DIR, url_filename + ".tmp")
    try:
        download_file(url, tmp_archive)
        extract_archive(tmp_archive, local_dir, archive_type)
    except BaseException:
        # A half-extracted dir would be mistaken for a valid cache hit on the
        # next launch — remove it before re-raising.
        shutil.rmtree(local_dir, ignore_errors=True)
        raise
    finally:
        if os.path.exists(tmp_archive):
            os.remove(tmp_archive)

    return find_model_dir(local_dir)


def parse_args(args):
    """
    Parse argv, intercepting --model and positional model args.

    Production stack invokes:
        python -m vllm.entrypoints.openai.api_server serve ...

    The model can appear as:
      - --model <value>
      - --model=<value>
      - the positional arg right after the "serve" subcommand
      - the first non-flag arg when there is no "serve" subcommand

    If the value is a URL, it is downloaded + extracted and replaced with the
    local path. Returns the modified argv list.
    """
    result = []
    i = 0
    model_replaced = False
    saw_serve = False
    while i < len(args):
        arg = args[i]

        # --model=<value>
        if arg.startswith("--model="):
            value = arg.split("=", 1)[1]
            if is_url(value):
                local_path = download_and_extract_model(value)
                result.append(f"--model={local_path}")
                model_replaced = True
            else:
                result.append(arg)
            i += 1
            continue

        # --model <value>
        if arg == "--model":
            result.append(arg)
            i += 1
            if i < len(args):
                value = args[i]
                if is_url(value):
                    local_path = download_and_extract_model(value)
                    result.append(local_path)
                    model_replaced = True
                else:
                    result.append(value)
                i += 1
            continue

        # "serve" subcommand — the next positional arg is the model
        if arg == "serve":
            result.append(arg)
            saw_serve = True
            i += 1
            if i < len(args) and not args[i].startswith("-") and is_url(args[i]):
                local_path = download_and_extract_model(args[i])
                result.append(local_path)
                model_replaced = True
                i += 1
            continue

        # Positional model arg when there's no "serve" subcommand
        # (first non-flag arg if no serve seen).
        if not arg.startswith("-") and not saw_serve and not model_replaced:
            if is_url(arg):
                local_path = download_and_extract_model(arg)
                result.append(local_path)
                model_replaced = True
                i += 1
                continue

        result.append(arg)
        i += 1

    if model_replaced:
        log("Model URL was replaced with local path")
    return result


def strip_shim_from_pythonpath() -> None:
    """
    Remove the shim directory from PYTHONPATH so that when we exec the real
    vLLM, Python doesn't find our shadow package again (infinite loop).
    """
    pp = os.environ.get("PYTHONPATH", "")
    parts = [p for p in pp.split(":") if p != SHIM_DIR]
    new_pp = ":".join(parts)
    if new_pp == pp:
        return
    if new_pp:
        os.environ["PYTHONPATH"] = new_pp
    else:
        # An empty PYTHONPATH entry means "current directory" to CPython —
        # drop the variable entirely rather than leaving it set to "".
        del os.environ["PYTHONPATH"]
    log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {pp!r}, now: {new_pp!r})")


def main():
    args = sys.argv[1:]

    # Determine which vllm module was actually invoked so we exec the real one
    # (could be vllm.entrypoints.cli.main, vllm.entrypoints.openai.api_server,
    # etc.). Under `python -m`, __name__ is "__main__" — the real dotted
    # module path lives in __spec__.name, so prefer that. When the file is
    # imported instead, __name__ already is the dotted module name.
    spec = globals().get("__spec__")
    if __name__ == "__main__" and spec is not None and spec.name:
        invoked_module = spec.name
    else:
        invoked_module = __name__

    log("=" * 50)
    log("vLLM Custom Weights Shim")
    log(f" Invoked as: python -m {invoked_module} {' '.join(args)}")
    log("=" * 50)

    # Intercept --model / positional model if it's a URL
    modified_args = parse_args(args)

    # Strip our shim from PYTHONPATH so the real vLLM resolves correctly
    strip_shim_from_pythonpath()

    # Build the real vLLM command using the same module that was invoked
    vllm_cmd = [sys.executable, "-m", invoked_module] + modified_args
    log(f"Launching vLLM: {' '.join(vllm_cmd)}")

    # Exec into vLLM — replace this process so signals flow through cleanly
    os.execvp(vllm_cmd[0], vllm_cmd)


# Run on both invocation paths:
#   - executed directly / via `python -m`  -> __name__ == "__main__"
#   - imported by a wrapper                -> module import side effect
# os.execvp replaces the process, so main() can only ever run once; the old
# guarded-call-plus-unconditional-call pair collapsed to this single call.
main()