diff --git a/vllm/Dockerfile b/vllm/Dockerfile
index 398bdfe..c1216d0 100644
--- a/vllm/Dockerfile
+++ b/vllm/Dockerfile
@@ -235,6 +235,25 @@ RUN apt install -y --no-install-recommends tmux cmake
 
 # Deprecated cleanup
 RUN pip uninstall -y pynvml && pip install nvidia-ml-py
+# Copy vLLM shim that intercepts --model to download custom weights from URLs
+COPY vllm_shim_module.py /opt/vllm-shim/vllm_shim_module.py
+
+# Shadow `python -m vllm.*` invocations via PYTHONPATH
+# The shim masquerades as the vllm package so python -m vllm.entrypoints.openai.api_server
+# hits our interceptor first, which downloads weights then execs the real vLLM
+RUN mkdir -p /opt/vllm-shim/vllm/entrypoints/openai \
+        /opt/vllm-shim/vllm/entrypoints/cli && \
+    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/__main__.py && \
+    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/openai/api_server.py && \
+    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/cli/main.py && \
+    touch /opt/vllm-shim/vllm/__init__.py \
+        /opt/vllm-shim/vllm/entrypoints/__init__.py \
+        /opt/vllm-shim/vllm/entrypoints/openai/__init__.py \
+        /opt/vllm-shim/vllm/entrypoints/cli/__init__.py
+
+ENV PYTHONPATH=/opt/vllm-shim
+ENV PYTHONUNBUFFERED=1
+
 # API server entrypoint
 # ENTRYPOINT ["vllm", "serve"]
-CMD ["/bin/bash"]
+#CMD ["/bin/bash"]
\ No newline at end of file
diff --git a/vllm/vllm_shim_module.py b/vllm/vllm_shim_module.py
new file mode 100644
index 0000000..c440782
--- /dev/null
+++ b/vllm/vllm_shim_module.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python3
+"""
+vLLM shim with custom weights download.
+
+Intercepts `python -m vllm.entrypoints.openai.api_server` so that
+if --model or the positional model arg (after "serve") points to a URL,
+we download + extract it to a local cache dir, then replace it with
+the local path before handing off to the real vLLM server.
+
+Supported archive formats (detected from URL extension):
+    .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
+"""
+import os
+import sys
+import subprocess
+import datetime
+import shutil
+import time
+import urllib.parse
+import urllib.request
+
+# Where to cache downloaded+extracted weights
+# Production stack mounts the PVC at /data — use a subdir so it persists across pod restarts
+CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")
+
+# The shim dir that shadows the vllm package — must be stripped from PYTHONPATH
+# before exec'ing the real vLLM, otherwise we loop forever.
+SHIM_DIR = "/opt/vllm-shim"
+
+
+def log(msg: str):
+    """Write to both stdout and the shim log file."""
+    log_path = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
+    ts = datetime.datetime.now().isoformat()
+    line = f"[{ts}] {msg}"
+    print(line, flush=True)
+    try:
+        with open(log_path, "a") as f:
+            f.write(line + "\n")
+    except Exception:
+        pass
+
+
+def is_url(value: str) -> bool:
+    return value.startswith("http://") or value.startswith("https://")
+
+
+def detect_archive_type(url: str) -> str:
+    """
+    Detect archive type from URL path extension.
+    Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip', or '' (unknown).
+    """
+    path = urllib.parse.urlparse(url).path
+    for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
+        if path.endswith(ext):
+            return ext.lstrip(".")
+    _, ext = os.path.splitext(path)
+    mapping = {
+        ".tar": "tar",
+        ".tgz": "tar.gz",
+        ".zip": "zip",
+    }
+    return mapping.get(ext.lower(), "")
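+
+# Illustrative mapping (hypothetical URLs, documentation only; not in the original module):
+#   detect_archive_type("https://example.com/weights/model.tar.gz") -> "tar.gz"
+#   detect_archive_type("https://example.com/weights/model.tgz")    -> "tar.gz"
+#   detect_archive_type("https://example.com/weights/model.zip")    -> "zip"
+#   detect_archive_type("https://example.com/weights/model.bin")    -> ""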
+
+
+MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
+RETRY_DELAY_SECONDS = 5
+
+
+def download_file(url: str, dest: str):
+    """Download url to dest with retries and a progress indicator."""
+    for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
+        try:
+            log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
+            urllib.request.urlretrieve(url, dest, reporthook=_download_progress)
+            log(f"Download complete: {dest}")
+            return
+        except Exception as e:
+            log(f"Download attempt {attempt} failed: {e}")
+            if os.path.exists(dest):
+                os.remove(dest)
+            if attempt < MAX_DOWNLOAD_RETRIES:
+                wait = RETRY_DELAY_SECONDS * attempt
+                log(f"Retrying in {wait}s...")
+                time.sleep(wait)
+            else:
+                log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
+                raise
+
+
+def _download_progress(block_num, block_size, total_size):
+    """Simple download progress callback."""
+    if total_size <= 0:
+        return
+    downloaded = block_num * block_size
+    pct = min(downloaded * 100 // total_size, 100)
+    if pct % 10 == 0 and pct > 0:
+        mb_down = downloaded / (1024 * 1024)
+        mb_total = total_size / (1024 * 1024)
+        sys.stdout.write(f"\r  {pct}% ({mb_down:.0f}/{mb_total:.0f} MB)")
+        sys.stdout.flush()
+
+
+def extract_archive(archive_path: str, dest_dir: str, archive_type: str):
+    """Extract archive to dest_dir based on archive_type."""
+    log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
+    if archive_type in ("tar.gz", "tgz"):
+        shutil.unpack_archive(archive_path, dest_dir, "gztar")
+    elif archive_type == "tar.bz2":
+        shutil.unpack_archive(archive_path, dest_dir, "bztar")
+    elif archive_type == "tar.xz":
+        subprocess.run(
+            ["tar", "-xJf", archive_path, "-C", dest_dir],
+            check=True,
+        )
+    elif archive_type == "tar":
+        shutil.unpack_archive(archive_path, dest_dir, "tar")
+    elif archive_type == "zip":
+        shutil.unpack_archive(archive_path, dest_dir, "zip")
+    else:
+        raise ValueError(f"Unsupported archive type: {archive_type}")
+    log(f"Extraction complete: {dest_dir}")
+
+
+def find_model_dir(extract_dir: str) -> str:
+    """
+    After extraction, find the directory containing the actual model weights.
+    Walks the tree looking for .safetensors files and returns the directory
+    that contains one. This handles archives with extra parent dirs,
+    nested structures, or flat extractions.
+    """
+    for root, dirs, files in os.walk(extract_dir):
+        if any(f.endswith(".safetensors") for f in files):
+            return root
+    log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
+    entries = [e for e in os.listdir(extract_dir)
+               if not e.startswith(".") and e != "__MACOSX"]
+    if len(entries) == 1 and os.path.isdir(os.path.join(extract_dir, entries[0])):
+        return os.path.join(extract_dir, entries[0])
+    return extract_dir
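+
+# Cache layout sketch (hypothetical URL; the cache root comes from VLLM_WEIGHTS_CACHE):
+#   URL:       https://example.com/models/my-model.tar.gz
+#   archive:   /data/weights/my-model.tar.gz.tmp   (removed once extraction finishes)
+#   extracted: /data/weights/my-model.tar/         (cache key = filename minus last extension)
+#   model dir: whichever subdir find_model_dir() reports as holding *.safetensors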
+
+
+def download_and_extract_model(url: str) -> str:
+    """
+    Download a model from URL, extract it, and return the local path.
+    Uses a cache keyed by URL filename to avoid re-downloading.
+    """
+    url_filename = os.path.basename(urllib.parse.urlparse(url).path)
+    cache_key = os.path.splitext(url_filename)[0]
+    local_dir = os.path.join(CACHE_DIR, cache_key)
+
+    if os.path.isdir(local_dir) and os.listdir(local_dir):
+        model_path = find_model_dir(local_dir)
+        log(f"Using cached weights: {model_path}")
+        return model_path
+
+    os.makedirs(local_dir, exist_ok=True)
+    archive_type = detect_archive_type(url)
+    if not archive_type:
+        raise ValueError(
+            f"Cannot determine archive type from URL: {url}\n"
+            f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
+        )
+
+    tmp_archive = os.path.join(CACHE_DIR, url_filename + ".tmp")
+    try:
+        download_file(url, tmp_archive)
+        extract_archive(tmp_archive, local_dir, archive_type)
+    except Exception:
+        # Don't leave a half-extracted dir behind; the cache check above would
+        # mistake it for valid weights on the next start.
+        shutil.rmtree(local_dir, ignore_errors=True)
+        raise
+    finally:
+        if os.path.exists(tmp_archive):
+            os.remove(tmp_archive)
+
+    return find_model_dir(local_dir)
+
+
+def parse_args(args):
+    """
+    Parse argv, intercepting --model and positional model args.
+    Production stack invokes: python -m vllm.entrypoints.openai.api_server serve ...
+    The model can appear as:
+      - --model <url>
+      - --model=<url>
+      - A positional arg after the "serve" subcommand
+    If the value is a URL, download+extract and replace it with the local path.
+    Returns the modified argv list.
+    """
+    result = []
+    i = 0
+    model_replaced = False
+    saw_serve = False
+
+    while i < len(args):
+        arg = args[i]
+
+        # --model=<url>
+        if arg.startswith("--model="):
+            value = arg.split("=", 1)[1]
+            if is_url(value):
+                local_path = download_and_extract_model(value)
+                result.append(f"--model={local_path}")
+                model_replaced = True
+            else:
+                result.append(arg)
+            i += 1
+            continue
+
+        # --model <url>
+        if arg == "--model":
+            result.append(arg)
+            i += 1
+            if i < len(args):
+                value = args[i]
+                if is_url(value):
+                    local_path = download_and_extract_model(value)
+                    result.append(local_path)
+                    model_replaced = True
+                else:
+                    result.append(value)
+            i += 1
+            continue
+
+        # "serve" subcommand — next positional is the model
+        if arg == "serve":
+            result.append(arg)
+            saw_serve = True
+            i += 1
+            # The next non-flag argument is the model
+            if i < len(args) and not args[i].startswith("-") and is_url(args[i]):
+                local_path = download_and_extract_model(args[i])
+                result.append(local_path)
+                model_replaced = True
+                i += 1
+            continue
+
+        # Positional model arg when there's no "serve" subcommand
+        # (first non-flag arg if no serve seen)
+        if not arg.startswith("-") and not saw_serve and not model_replaced:
+            if is_url(arg):
+                local_path = download_and_extract_model(arg)
+                result.append(local_path)
+                model_replaced = True
+                i += 1
+                continue
+
+        result.append(arg)
+        i += 1
+
+    if model_replaced:
+        log("Model URL was replaced with local path")
+
+    return result
+
+
+def strip_shim_from_pythonpath():
+    """
+    Remove the shim directory from PYTHONPATH so that when we exec the
+    real vLLM, Python doesn't find our shadow package again (infinite loop).
+    """
+    pp = os.environ.get("PYTHONPATH", "")
+    parts = [p for p in pp.split(":") if p != SHIM_DIR]
+    new_pp = ":".join(parts)
+    if new_pp != pp:
+        os.environ["PYTHONPATH"] = new_pp
+        log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {pp!r}, now: {new_pp!r})")
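+
+# Argv rewrite sketch (hypothetical URL; the output path is whatever
+# download_and_extract_model() returns, possibly a nested *.safetensors dir):
+#   in:  ["serve", "https://example.com/models/my-model.tar.gz", "--port", "8000"]
+#   out: ["serve", "/data/weights/my-model.tar", "--port", "8000"]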
"vllm.entrypoints.cli.main" or "vllm.entrypoints.openai.api_server" + log("=" * 50) + log("vLLM Custom Weights Shim") + log(f" Invoked as: python -m {invoked_module} {' '.join(args)}") + log("=" * 50) + + # Intercept --model / positional model if it's a URL + modified_args = parse_args(args) + + # Strip our shim from PYTHONPATH so the real vLLM resolves correctly + strip_shim_from_pythonpath() + + # Build the real vLLM command using the same module that was invoked + vllm_cmd = [sys.executable, "-m", invoked_module] + modified_args + + log(f"Launching vLLM: {' '.join(vllm_cmd)}") + + # Exec into vLLM — replace this process so signals flow through cleanly + os.execvp(vllm_cmd[0], vllm_cmd) + + +if __name__ == "__main__": + main() + +# Also run if imported as a module (some invocation paths just import the file) +main()