custom weights

This commit is contained in:
2026-04-28 02:08:00 +00:00
parent edf12f7996
commit e43c8c97f1
2 changed files with 317 additions and 1 deletions

View File

@@ -235,6 +235,25 @@ RUN apt install -y --no-install-recommends tmux cmake
# Deprecated cleanup
RUN pip uninstall -y pynvml && pip install nvidia-ml-py
# Copy vLLM shim that intercepts --model to download custom weights from URLs
COPY vllm_shim_module.py /opt/vllm-shim/vllm_shim_module.py
# Shadow `python -m vllm.*` invocations via PYTHONPATH
# The shim masquerades as the vllm package so python -m vllm/entrypoints/openai/api_server
# hits our interceptor first, which downloads weights then execs the real vLLM
RUN mkdir -p /opt/vllm-shim/vllm/entrypoints/openai \
/opt/vllm-shim/vllm/entrypoints/cli && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/__main__.py && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/openai/api_server.py && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/cli/main.py && \
touch /opt/vllm-shim/vllm/__init__.py \
/opt/vllm-shim/vllm/entrypoints/__init__.py \
/opt/vllm-shim/vllm/entrypoints/openai/__init__.py \
/opt/vllm-shim/vllm/entrypoints/cli/__init__.py
ENV PYTHONPATH=/opt/vllm-shim
ENV PYTHONUNBUFFERED=1
# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
CMD ["/bin/bash"]
#CMD ["/bin/bash"]

297
vllm/vllm_shim_module.py Normal file
View File

@@ -0,0 +1,297 @@
#!/usr/bin/env python3
"""
vLLM shim with custom weights download.
Intercepts `python -m vllm.entrypoints.openai.api_server` so that
if --model or the positional model arg (after "serve") points to a URL,
we download + extract it to a local cache dir, then replace it with
the local path before handing off to the real vLLM server.
Supported archive formats (detected from URL extension):
.tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
"""
import os
import sys
import subprocess
import datetime
import shutil
import time
import urllib.parse
import urllib.request
# Where to cache downloaded+extracted weights.
# Production stack mounts the PVC at /data — use a subdir so it persists across pod restarts.
CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")
# The shim dir that shadows the vllm package — must be stripped from PYTHONPATH
# before exec'ing the real vLLM, otherwise we loop forever.
SHIM_DIR = "/opt/vllm-shim"
def log(msg: str):
    """Print *msg* with an ISO timestamp and best-effort append it to the shim log."""
    stamp = datetime.datetime.now().isoformat()
    entry = f"[{stamp}] {msg}"
    print(entry, flush=True)
    target = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
    try:
        with open(target, "a") as fh:
            fh.write(entry + "\n")
    except Exception:
        # Logging must never break the shim — ignore file-write failures.
        pass
def is_url(value: str) -> bool:
    """Return True when *value* looks like an http(s) URL."""
    return value.startswith(("http://", "https://"))
def detect_archive_type(url: str) -> str:
    """
    Detect archive type from the URL path extension (case-insensitive).

    Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip',
    or '' when the extension is unknown.
    """
    # FIX: lowercase once up front — the original lowered only the simple
    # extension, so ".TAR.GZ" was rejected while ".ZIP" was accepted.
    path = urllib.parse.urlparse(url).path.lower()
    # Compound extensions must be checked before os.path.splitext, which
    # would only see the final component (".gz" of ".tar.gz").
    for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
        if path.endswith(ext):
            return ext.lstrip(".")
    _, ext = os.path.splitext(path)
    mapping = {
        ".tar": "tar",
        ".tgz": "tar.gz",
        ".zip": "zip",
    }
    return mapping.get(ext, "")
# Maximum download attempts (override with VLLM_SHIM_MAX_RETRIES).
MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
# Base backoff; actual wait is RETRY_DELAY_SECONDS * attempt (linear backoff).
RETRY_DELAY_SECONDS = 5
def download_file(url: str, dest: str):
    """
    Download *url* to *dest*, retrying up to MAX_DOWNLOAD_RETRIES times
    with linearly increasing backoff. Re-raises the last error on exhaustion.
    """
    for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
        try:
            log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
            urllib.request.urlretrieve(url, dest, reporthook=_download_progress)
        except Exception as e:
            log(f"Download attempt {attempt} failed: {e}")
            # Drop any partial file so the next attempt starts clean.
            if os.path.exists(dest):
                os.remove(dest)
            if attempt == MAX_DOWNLOAD_RETRIES:
                log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
                raise
            wait = RETRY_DELAY_SECONDS * attempt
            log(f"Retrying in {wait}s...")
            time.sleep(wait)
        else:
            log(f"Download complete: {dest}")
            return
def _download_progress(block_num, block_size, total_size):
"""Simple download progress callback."""
if total_size <= 0:
return
downloaded = block_num * block_size
pct = min(downloaded * 100 // total_size, 100)
if pct % 10 == 0 and pct > 0:
mb_down = downloaded / (1024 * 1024)
mb_total = total_size / (1024 * 1024)
sys.stdout.write(f"\r {pct}% ({mb_down:.0f}/{mb_total:.0f} MB)")
sys.stdout.flush()
def extract_archive(archive_path: str, dest_dir: str, archive_type: str):
    """Unpack *archive_path* into *dest_dir* according to *archive_type*."""
    log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
    shutil_formats = {
        "tar.gz": "gztar",
        "tgz": "gztar",
        "tar.bz2": "bztar",
        "tar": "tar",
        "zip": "zip",
    }
    if archive_type in shutil_formats:
        shutil.unpack_archive(archive_path, dest_dir, shutil_formats[archive_type])
    elif archive_type == "tar.xz":
        # .tar.xz is handed to the system tar binary (matches original behavior).
        subprocess.run(
            ["tar", "-xJf", archive_path, "-C", dest_dir],
            check=True,
        )
    else:
        raise ValueError(f"Unsupported archive type: {archive_type}")
    log(f"Extraction complete: {dest_dir}")
def find_model_dir(extract_dir: str) -> str:
    """
    Locate the directory actually holding the model weights after extraction.

    Walks the tree and returns the first directory containing a .safetensors
    file; this copes with archives that add a parent dir, nest the weights,
    or extract flat. Falls back to a single-visible-subdirectory heuristic,
    and finally to *extract_dir* itself.
    """
    for root, _dirs, files in os.walk(extract_dir):
        for name in files:
            if name.endswith(".safetensors"):
                return root
    log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
    # Ignore hidden entries and macOS archive cruft when applying the heuristic.
    visible = [e for e in os.listdir(extract_dir)
               if not e.startswith(".") and e != "__MACOSX"]
    if len(visible) == 1:
        candidate = os.path.join(extract_dir, visible[0])
        if os.path.isdir(candidate):
            return candidate
    return extract_dir
def download_and_extract_model(url: str) -> str:
    """
    Download a model archive from *url*, extract it, and return the local
    weights directory.

    Uses a cache under CACHE_DIR keyed by the URL filename (minus its last
    extension) to avoid re-downloading across restarts.

    Raises ValueError if the archive type can't be inferred from the URL.
    """
    url_filename = os.path.basename(urllib.parse.urlparse(url).path)
    # NOTE: splitext strips only the last extension, so "m.tar.gz" caches as "m.tar".
    cache_key = os.path.splitext(url_filename)[0]
    local_dir = os.path.join(CACHE_DIR, cache_key)
    # Cache hit: any non-empty dir for this key is treated as a completed extraction.
    # NOTE(review): a partially-extracted dir from a failed run would also count
    # as a hit on the next start — confirm this is acceptable.
    if os.path.isdir(local_dir) and os.listdir(local_dir):
        model_path = find_model_dir(local_dir)
        log(f"Using cached weights: {model_path}")
        return model_path
    # Creating local_dir also creates CACHE_DIR (makedirs builds intermediates),
    # so the tmp archive path below is guaranteed a parent directory.
    os.makedirs(local_dir, exist_ok=True)
    archive_type = detect_archive_type(url)
    if not archive_type:
        raise ValueError(
            f"Cannot determine archive type from URL: {url}\n"
            f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
        )
    tmp_archive = os.path.join(CACHE_DIR, url_filename + ".tmp")
    try:
        download_file(url, tmp_archive)
        extract_archive(tmp_archive, local_dir, archive_type)
    finally:
        # Always discard the archive — only the extracted weights are cached.
        if os.path.exists(tmp_archive):
            os.remove(tmp_archive)
    return find_model_dir(local_dir)
def parse_args(args):
    """
    Parse argv, intercepting --model and positional model args.

    Production stack invokes:
        python -m vllm.entrypoints.openai.api_server serve <model-url> ...

    The model can appear as:
      - --model <url>
      - --model=<url>
      - A positional arg after the "serve" subcommand

    If the value is a URL, it is downloaded+extracted and replaced with the
    local path; every other argument passes through unchanged and in order.
    Returns the modified argv list.
    """
    result = []
    i = 0
    model_replaced = False  # only the first positional candidate is treated as the model
    saw_serve = False  # suppresses the no-subcommand positional heuristic below
    while i < len(args):
        arg = args[i]
        # --model=<value>
        if arg.startswith("--model="):
            value = arg.split("=", 1)[1]
            if is_url(value):
                local_path = download_and_extract_model(value)
                result.append(f"--model={local_path}")
                model_replaced = True
            else:
                result.append(arg)
            i += 1
            continue
        # --model <value>  (flag and value are separate argv entries)
        if arg == "--model":
            result.append(arg)
            i += 1
            if i < len(args):
                value = args[i]
                if is_url(value):
                    local_path = download_and_extract_model(value)
                    result.append(local_path)
                    model_replaced = True
                else:
                    result.append(value)
                i += 1
            continue
        # "serve" subcommand — next positional is the model
        if arg == "serve":
            result.append(arg)
            saw_serve = True
            i += 1
            # The next non-flag argument is the model; only rewritten when a URL,
            # otherwise it falls through the loop and is copied verbatim.
            if i < len(args) and not args[i].startswith("-") and is_url(args[i]):
                local_path = download_and_extract_model(args[i])
                result.append(local_path)
                model_replaced = True
                i += 1
            continue
        # Positional model arg when there's no "serve" subcommand
        # (first non-flag arg if no serve seen)
        if not arg.startswith("-") and not saw_serve and not model_replaced:
            if is_url(arg):
                local_path = download_and_extract_model(arg)
                result.append(local_path)
                model_replaced = True
                i += 1
                continue
        # Default: pass the argument through untouched.
        result.append(arg)
        i += 1
    if model_replaced:
        log("Model URL was replaced with local path")
    return result
def strip_shim_from_pythonpath():
    """
    Drop SHIM_DIR from PYTHONPATH so the exec'd real vLLM does not resolve
    our shadow package again (which would re-enter this shim forever).
    """
    original = os.environ.get("PYTHONPATH", "")
    kept = ":".join(p for p in original.split(":") if p != SHIM_DIR)
    if kept == original:
        # Shim dir wasn't present — leave the environment untouched.
        return
    os.environ["PYTHONPATH"] = kept
    log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {original!r}, now: {kept!r})")
def main():
    """
    Shim entry point: rewrite any model-URL arguments, strip the shim from
    PYTHONPATH, then exec the real vLLM module in place of this process.
    """
    args = sys.argv[1:]
    # Determine which vllm module was actually invoked so we exec the real one
    # (could be vllm.entrypoints.cli.main, vllm.entrypoints.openai.api_server, etc.)
    # FIX: under `python -m pkg.mod` the module executes with
    # __name__ == "__main__", so using __name__ directly would make us exec
    # `python -m __main__` (broken). The real dotted name is in __spec__.name.
    spec = globals().get("__spec__")
    if __name__ != "__main__":
        invoked_module = __name__
    elif spec is not None and getattr(spec, "name", None) not in (None, "", "__main__"):
        invoked_module = spec.name
    else:
        # Last resort: the production entrypoint documented in the module docstring.
        invoked_module = "vllm.entrypoints.openai.api_server"
    log("=" * 50)
    log("vLLM Custom Weights Shim")
    log(f" Invoked as: python -m {invoked_module} {' '.join(args)}")
    log("=" * 50)
    # Intercept --model / positional model if it's a URL
    modified_args = parse_args(args)
    # Strip our shim from PYTHONPATH so the real vLLM resolves correctly
    strip_shim_from_pythonpath()
    # Build the real vLLM command using the same module that was invoked
    vllm_cmd = [sys.executable, "-m", invoked_module] + modified_args
    log(f"Launching vLLM: {' '.join(vllm_cmd)}")
    # Exec into vLLM — replace this process so signals flow through cleanly
    os.execvp(vllm_cmd[0], vllm_cmd)
if __name__ == "__main__":
    main()
# Also run if imported as a module (some invocation paths just import the file)
# NOTE: when executed as a script, main() above os.execvp's into vLLM and
# never returns, so this second call is only ever reached on plain import.
main()