custom weights
This commit is contained in:
@@ -235,6 +235,25 @@ RUN apt install -y --no-install-recommends tmux cmake
|
||||
# Deprecated cleanup
RUN pip uninstall -y pynvml && pip install nvidia-ml-py

# Copy vLLM shim that intercepts --model to download custom weights from URLs
COPY vllm_shim_module.py /opt/vllm-shim/vllm_shim_module.py

# Shadow `python -m vllm.*` invocations via PYTHONPATH
# The shim masquerades as the vllm package so python -m vllm/entrypoints/openai/api_server
# hits our interceptor first, which downloads weights then execs the real vLLM
RUN mkdir -p /opt/vllm-shim/vllm/entrypoints/openai \
             /opt/vllm-shim/vllm/entrypoints/cli && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/__main__.py && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/openai/api_server.py && \
    cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/cli/main.py && \
    touch /opt/vllm-shim/vllm/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/openai/__init__.py \
          /opt/vllm-shim/vllm/entrypoints/cli/__init__.py

ENV PYTHONPATH=/opt/vllm-shim
ENV PYTHONUNBUFFERED=1

# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
CMD ["/bin/bash"]
#CMD ["/bin/bash"]
|
||||
297
vllm/vllm_shim_module.py
Normal file
297
vllm/vllm_shim_module.py
Normal file
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python3
"""
vLLM shim with custom weights download.

Intercepts `python -m vllm.entrypoints.openai.api_server` so that
if --model or the positional model arg (after "serve") points to a URL,
we download + extract it to a local cache dir, then replace it with
the local path before handing off to the real vLLM server.

Supported archive formats (detected from URL extension):
  .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
"""
import os
import sys
import subprocess
import datetime
import shutil
import time
import urllib.parse
import urllib.request

# Where to cache downloaded+extracted weights.
# Production stack mounts the PVC at /data — use a subdir so it persists across pod restarts.
CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")

# The shim dir that shadows the vllm package — must be stripped from PYTHONPATH
# before exec'ing the real vLLM, otherwise we loop forever.
SHIM_DIR = "/opt/vllm-shim"
def log(msg: str):
    """Emit a timestamped message to stdout and append it to the shim log file."""
    stamp = datetime.datetime.now().isoformat()
    entry = f"[{stamp}] {msg}"
    print(entry, flush=True)
    target = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
    try:
        with open(target, "a") as handle:
            handle.write(entry + "\n")
    except Exception:
        # Best-effort: a broken log file must never take down the shim.
        pass
def is_url(value: str) -> bool:
    """Return True when *value* looks like an http(s) URL."""
    return value.startswith(("http://", "https://"))
def detect_archive_type(url: str) -> str:
    """
    Detect archive type from URL path extension (case-insensitive).

    Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip', or '' (unknown).
    """
    # Normalize once so ".TAR.GZ" and ".Zip" behave the same. The original
    # lowercased only the single-extension lookup, so uppercase multi-part
    # extensions (.TAR.GZ etc.) were silently rejected.
    path = urllib.parse.urlparse(url).path.lower()
    # Multi-part extensions must be checked before os.path.splitext, which
    # would otherwise report only the trailing ".gz"/".bz2"/".xz".
    for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
        if path.endswith(ext):
            return ext.lstrip(".")
    _, ext = os.path.splitext(path)
    mapping = {
        ".tar": "tar",
        ".tgz": "tar.gz",  # .tgz is the conventional alias for gzipped tar
        ".zip": "zip",
    }
    return mapping.get(ext, "")
# Download retry policy: attempt count is env-overridable; the delay is the
# base for a linear backoff (delay * attempt_number).
MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
RETRY_DELAY_SECONDS = 5
def download_file(url: str, dest: str):
    """
    Download *url* to *dest*, retrying with linear backoff.

    Raises the last download exception after MAX_DOWNLOAD_RETRIES failures.
    A partially written *dest* is removed after every failed attempt.
    """
    attempt = 0
    while True:
        attempt += 1
        try:
            log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
            urllib.request.urlretrieve(url, dest, reporthook=_download_progress)
        except Exception as exc:
            log(f"Download attempt {attempt} failed: {exc}")
            if os.path.exists(dest):
                # Never leave a truncated file behind for the next attempt.
                os.remove(dest)
            if attempt >= MAX_DOWNLOAD_RETRIES:
                log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
                raise
            pause = RETRY_DELAY_SECONDS * attempt
            log(f"Retrying in {pause}s...")
            time.sleep(pause)
        else:
            log(f"Download complete: {dest}")
            return
def _download_progress(block_num, block_size, total_size):
|
||||
"""Simple download progress callback."""
|
||||
if total_size <= 0:
|
||||
return
|
||||
downloaded = block_num * block_size
|
||||
pct = min(downloaded * 100 // total_size, 100)
|
||||
if pct % 10 == 0 and pct > 0:
|
||||
mb_down = downloaded / (1024 * 1024)
|
||||
mb_total = total_size / (1024 * 1024)
|
||||
sys.stdout.write(f"\r {pct}% ({mb_down:.0f}/{mb_total:.0f} MB)")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def extract_archive(archive_path: str, dest_dir: str, archive_type: str):
    """
    Extract *archive_path* into *dest_dir*.

    *archive_type* is one of the strings produced by detect_archive_type()
    ('tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip'); 'tgz' is also accepted as
    an alias for 'tar.gz'. Raises ValueError for anything else.
    """
    log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
    # shutil.unpack_archive handles all supported formats, including xz
    # ("xztar") — the original shelled out to `tar -xJf` for .tar.xz, which
    # required a tar binary in the image for no benefit.
    formats = {
        "tar": "tar",
        "tar.gz": "gztar",
        "tgz": "gztar",
        "tar.bz2": "bztar",
        "tar.xz": "xztar",
        "zip": "zip",
    }
    fmt = formats.get(archive_type)
    if fmt is None:
        raise ValueError(f"Unsupported archive type: {archive_type}")
    shutil.unpack_archive(archive_path, dest_dir, fmt)
    log(f"Extraction complete: {dest_dir}")
def find_model_dir(extract_dir: str) -> str:
    """
    Locate the directory that actually holds the model weights.

    Walks *extract_dir* top-down and returns the first directory containing
    a .safetensors file — this copes with archives that wrap the weights in
    extra parent dirs or nest them. If none is found, fall back to unwrapping
    a single visible top-level directory, else return *extract_dir* itself.
    """
    for current, _subdirs, filenames in os.walk(extract_dir):
        for name in filenames:
            if name.endswith(".safetensors"):
                return current
    log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
    visible = [
        entry for entry in os.listdir(extract_dir)
        if not entry.startswith(".") and entry != "__MACOSX"
    ]
    if len(visible) == 1:
        candidate = os.path.join(extract_dir, visible[0])
        if os.path.isdir(candidate):
            return candidate
    return extract_dir
def download_and_extract_model(url: str) -> str:
    """
    Ensure the model archive at *url* is available locally; return its path.

    The cache is keyed by the URL's filename (minus extension) under
    CACHE_DIR, so repeat launches reuse an already-extracted copy instead of
    re-downloading.
    """
    filename = os.path.basename(urllib.parse.urlparse(url).path)
    cache_key = os.path.splitext(filename)[0]
    target_dir = os.path.join(CACHE_DIR, cache_key)

    # Cache hit: any non-empty extraction under the key is reused as-is.
    if os.path.isdir(target_dir) and os.listdir(target_dir):
        cached = find_model_dir(target_dir)
        log(f"Using cached weights: {cached}")
        return cached

    os.makedirs(target_dir, exist_ok=True)
    archive_type = detect_archive_type(url)
    if not archive_type:
        raise ValueError(
            f"Cannot determine archive type from URL: {url}\n"
            f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
        )

    scratch = os.path.join(CACHE_DIR, filename + ".tmp")
    try:
        download_file(url, scratch)
        extract_archive(scratch, target_dir, archive_type)
    finally:
        # Drop the raw archive whether or not extraction succeeded.
        if os.path.exists(scratch):
            os.remove(scratch)

    return find_model_dir(target_dir)
def parse_args(args):
    """
    Rewrite argv so a model given as a URL becomes a local path.

    The model may appear as:
      - --model <url>
      - --model=<url>
      - the positional argument right after a "serve" subcommand
      - (absent "serve") the first non-flag positional argument
    URL values are materialized via download_and_extract_model(); every
    other token passes through untouched. Returns the modified argv list.
    """
    out = []
    idx = 0
    replaced = False
    serve_seen = False
    total = len(args)

    while idx < total:
        token = args[idx]

        # --model=<value>
        if token.startswith("--model="):
            payload = token.split("=", 1)[1]
            if is_url(payload):
                out.append(f"--model={download_and_extract_model(payload)}")
                replaced = True
            else:
                out.append(token)
            idx += 1
            continue

        # --model <value>
        if token == "--model":
            out.append(token)
            idx += 1
            if idx < total:
                payload = args[idx]
                if is_url(payload):
                    out.append(download_and_extract_model(payload))
                    replaced = True
                else:
                    out.append(payload)
                idx += 1
            continue

        # "serve" subcommand — the next non-flag token is the model.
        if token == "serve":
            out.append(token)
            serve_seen = True
            idx += 1
            if idx < total and not args[idx].startswith("-") and is_url(args[idx]):
                out.append(download_and_extract_model(args[idx]))
                replaced = True
                idx += 1
            continue

        # Bare positional model URL with no "serve" subcommand.
        if (not token.startswith("-") and not serve_seen and not replaced
                and is_url(token)):
            out.append(download_and_extract_model(token))
            replaced = True
            idx += 1
            continue

        out.append(token)
        idx += 1

    if replaced:
        log("Model URL was replaced with local path")

    return out
def strip_shim_from_pythonpath():
    """
    Drop SHIM_DIR from PYTHONPATH before exec'ing the real vLLM.

    Without this, the re-exec'd interpreter would resolve `vllm` to our
    shadow package again and the shim would intercept itself forever.
    """
    original = os.environ.get("PYTHONPATH", "")
    kept = ":".join(p for p in original.split(":") if p != SHIM_DIR)
    if kept != original:
        os.environ["PYTHONPATH"] = kept
        log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {original!r}, now: {kept!r})")
def main():
    """Shim entry point: rewrite the model argument, then exec the real vLLM."""
    args = sys.argv[1:]

    # __name__ identifies which vllm module was invoked (e.g.
    # "vllm.entrypoints.cli.main" or "vllm.entrypoints.openai.api_server")
    # so we can re-exec exactly that module from the real package.
    invoked_module = __name__

    banner = "=" * 50
    log(banner)
    log("vLLM Custom Weights Shim")
    log(f" Invoked as: python -m {invoked_module} {' '.join(args)}")
    log(banner)

    # Swap a URL model for a local path before handing off.
    modified_args = parse_args(args)

    # Make sure the re-exec'd interpreter can't import the shim again.
    strip_shim_from_pythonpath()

    # Re-launch the same module name, now resolved against the real vLLM.
    command = [sys.executable, "-m", invoked_module] + modified_args
    log(f"Launching vLLM: {' '.join(command)}")

    # execvp replaces this process, so signals reach vLLM directly.
    os.execvp(command[0], command)
if __name__ == "__main__":
    main()

# Some invocation paths import this file as a module rather than running it
# as a script, so also call main() unconditionally. When run as a script the
# call above never returns (main() exec's into vLLM or raises), so this
# second call is only reached on plain import.
main()
Reference in New Issue
Block a user