5 Commits

Author SHA1 Message Date
10c71a446c Remove flash-attn GIT_TAG override to main — causes FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL undefined error
v0.19.0 pins a compatible flash-attn commit (2921022). The sed that
forced GIT_TAG to main pulled in newer code that references
FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL which isn't defined in v0.19.0's
build config. Use the pinned commit instead.
2026-04-28 03:07:14 +00:00
550a04a0ca custom weights 2026-04-28 02:10:48 +00:00
e43c8c97f1 custom weights 2026-04-28 02:08:00 +00:00
edf12f7996 Clean up: remove PLAN-triton-kernels.md (merged into main) 2026-04-06 17:25:06 +00:00
e6cc28a942 Add triton_kernels for MoE support (vLLM v0.19.0)
- Add build-triton-kernels stage to fetch triton_kernels from Triton v3.6.0
- Install to site-packages for vLLM to find at runtime
- Resolves: No module named 'triton_kernels.matmul_ogs'
- Image tag: gh200-vllm-tfa:v0.19.0-tfa
2026-04-06 16:39:56 +00:00
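A quick way to confirm this commit's fix landed (a hedged sketch, run inside the built image; not part of the build itself):

# Smoke test: the module vLLM v0.19.0 needs for MoE should now be importable.
import triton_kernels.matmul_ogs  # raises ModuleNotFoundError if the install into site-packages failed
print("triton_kernels.matmul_ogs OK")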
3 changed files with 365 additions and 7 deletions


@@ -1,13 +1,16 @@
# ==============================================================================
# ⚠️⚠️⚠️ WORKING BUILD - DO NOT TOUCH ⚠️⚠️⚠️
# Triton Kernels Build (TFA) - vLLM v0.19.0 + triton_kernels
# ==============================================================================
# Build #43 succeeded on 2026-04-03 with these exact versions:
# - vLLM: v0.18.2rc0
# - flashinfer: v0.6.7
# This branch adds triton_kernels from Triton v3.6.0 for MoE support.
#
# Based on working Build #43 (v0.18.2rc0) with vLLM upgraded to v0.19.0:
# - vLLM: v0.19.0
# - flashinfer: v0.6.6
# - flash-attention: hopper branch
# - lmcache: dev branch
# - infinistore: main
# - triton: 3.6.0 (PyPI wheel)
# - triton_kernels: v3.6.0 (from Triton repo)
# - Base: nvcr.io/nvidia/pytorch:26.03-py3 (PyTorch 2.11.0a0, CUDA 13.2.0)
#
# HARD RULES:
@@ -16,7 +19,7 @@
# 3. CLEAR ALL CHANGES WITH MIKE BEFORE MAKING THEM
# 4. ONE BUILD AT A TIME - Mike reports failure → I assess → I report
#
# If you need to modify this file, ask Mike first.
# Image tag: gh200-vllm-tfa:v0.19.0-tfa
# ==============================================================================
# ---------- Builder Base ----------
@@ -79,6 +82,11 @@ FROM build-base AS build-triton
RUN mkdir -p /wheels && \
pip download triton==3.6.0 --platform manylinux_2_27_aarch64 --only-binary=:all: --no-deps -d /wheels
# Install triton_kernels from Triton repo (v3.6.0) for MoE support
# vLLM v0.19.0 requires this for the triton_kernels.matmul_ogs module
FROM build-base AS build-triton-kernels
RUN pip install --target=/wheels git+https://github.com/triton-lang/triton.git@v3.6.0#subdirectory=python/triton_kernels
# Skip xformers - vLLM has built-in FlashAttention kernels
# xformers requires TORCH_STABLE_ONLY which needs PyTorch headers not in 2.9.0
# FROM build-base AS build-xformers
@@ -158,7 +166,9 @@ RUN cd vllm && \
echo "========================================\n\n" && \
git submodule sync && \
git submodule update --init --recursive -j 8 && \
sed -i 's/GIT_TAG [a-f0-9]\{40\}/GIT_TAG main/' cmake/external_projects/vllm_flash_attn.cmake && \
# NOTE: Removed the sed that forced flash-attn GIT_TAG to main.
# v0.19.0 pins a compatible commit; building from main causes
# FLASHATTENTION_FP8_TWO_LEVEL_INTERVAL undefined errors.
sed -i 's/register_opaque_type(ModuleName, typ="value", hoist=True)/register_opaque_type(ModuleName, typ="value")/' vllm/utils/torch_utils.py && \
export MAX_JOBS=8 && \
export CMAKE_BUILD_PARALLEL_LEVEL=$MAX_JOBS && \
@@ -191,6 +201,7 @@ FROM base AS vllm-openai
COPY --from=build-flash-attention /wheels/* wheels/
COPY --from=build-flashinfer /wheels/* wheels/
COPY --from=build-triton /wheels/* wheels/
COPY --from=build-triton-kernels /wheels/triton_kernels /usr/local/lib/python3.12/dist-packages/triton_kernels
COPY --from=build-vllm /wheels/* wheels/
COPY --from=build-lmcache /wheels/* wheels/
COPY --from=build-infinistore /wheels/* wheels/
@@ -226,6 +237,28 @@ RUN apt install -y --no-install-recommends tmux cmake
# Deprecated cleanup
RUN pip uninstall -y pynvml && pip install nvidia-ml-py
# Copy over the Nemotron reasoning parser
COPY ./super_v3_reasoning_parser.py /opt/super_v3_reasoning_parser.py
# Copy vLLM shim that intercepts --model to download custom weights from URLs
COPY vllm_shim_module.py /opt/vllm-shim/vllm_shim_module.py
# Shadow `python -m vllm.*` invocations via PYTHONPATH
# The shim masquerades as the vllm package, so python -m vllm, vllm.entrypoints.openai.api_server,
# or vllm.entrypoints.cli.main hits our interceptor first, which downloads the weights and then execs the real vLLM
RUN mkdir -p /opt/vllm-shim/vllm/entrypoints/openai \
/opt/vllm-shim/vllm/entrypoints/cli && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/__main__.py && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/openai/api_server.py && \
cp /opt/vllm-shim/vllm_shim_module.py /opt/vllm-shim/vllm/entrypoints/cli/main.py && \
touch /opt/vllm-shim/vllm/__init__.py \
/opt/vllm-shim/vllm/entrypoints/__init__.py \
/opt/vllm-shim/vllm/entrypoints/openai/__init__.py \
/opt/vllm-shim/vllm/entrypoints/cli/__init__.py
ENV PYTHONPATH=/opt/vllm-shim
ENV PYTHONUNBUFFERED=1
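# Optional sanity check (a sketch, commented out; not part of the working build): with
# PYTHONPATH=/opt/vllm-shim set above, the shim package should win module resolution.
# RUN python3 -c "import importlib.util; print(importlib.util.find_spec('vllm').origin)"
# Expected output: /opt/vllm-shim/vllm/__init__.py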
# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
CMD ["/bin/bash"]
#CMD ["/bin/bash"]


@@ -0,0 +1,28 @@
from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
@ReasoningParserManager.register_module("super_v3")
class SuperV3ReasoningParser(DeepSeekR1ReasoningParser):
def extract_reasoning(self, model_output, request):
reasoning_content, final_content = super().extract_reasoning(
model_output, request
)
if (
hasattr(request, "chat_template_kwargs")
and request.chat_template_kwargs
and (
request.chat_template_kwargs.get("enable_thinking") is False
or request.chat_template_kwargs.get("force_nonempty_content") is True
)
and final_content is None
):
"""
The original `deepseek_r1` reasoning parser this inherits from automatically puts everything into the reasoning content when it cannot parse out reasoning. That was fine for the DeepSeek R1 model, which was not intended to be used without reasoning.
1. Since Nemotron 3 Nano and Super both have thinking-off modes toggled by "enable_thinking=false" in the chat template kwargs, this change instead places the output in the content field when thinking is disabled via config.
2. There are rare cases where the model outputs only reasoning with no end-think token `</think>` (e.g. the reasoning exceeds max length), which results in empty content being returned. End users may want to avoid such cases unilaterally and always get a content response even if the model does not finish its reasoning.
"""
# Put all nonempty content into the content, rather than return content
reasoning_content, final_content = None, reasoning_content
return reasoning_content, final_content
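A minimal sketch of a request that exercises the override above (model name and values are illustrative; the parser itself only inspects request.chat_template_kwargs):

# Hypothetical chat-completions payload sent to the vLLM OpenAI-compatible server.
payload = {
    "model": "nemotron-super-v3",  # illustrative model name
    "messages": [{"role": "user", "content": "What is 2+2?"}],
    "chat_template_kwargs": {"enable_thinking": False},
}
# With enable_thinking=False and no </think> in the model output, the parser returns the
# text as content instead of reasoning_content.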

vllm/vllm_shim_module.py

@@ -0,0 +1,297 @@
#!/usr/bin/env python3
"""
vLLM shim with custom weights download.
Intercepts `python -m vllm.entrypoints.openai.api_server` so that
if --model or the positional model arg (after "serve") points to a URL,
we download + extract it to a local cache dir, then replace it with
the local path before handing off to the real vLLM server.
Supported archive formats (detected from URL extension):
.tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip
"""
import os
import sys
import subprocess
import datetime
import shutil
import time
import urllib.parse
import urllib.request
# Where to cache downloaded+extracted weights
# Production stack mounts the PVC at /data — use a subdir so it persists across pod restarts
CACHE_DIR = os.environ.get("VLLM_WEIGHTS_CACHE", "/data/weights")
# The shim dir that shadows the vllm package — must be stripped from PYTHONPATH
# before exec'ing the real vLLM, otherwise we loop forever.
SHIM_DIR = "/opt/vllm-shim"
def log(msg: str):
"""Write to both stdout and the shim log file."""
log_path = os.environ.get("VLLM_SHIM_LOG", "/tmp/vllm-shim.log")
ts = datetime.datetime.now().isoformat()
line = f"[{ts}] {msg}"
print(line, flush=True)
try:
with open(log_path, "a") as f:
f.write(line + "\n")
except Exception:
pass
def is_url(value: str) -> bool:
return value.startswith("http://") or value.startswith("https://")
def detect_archive_type(url: str) -> str:
"""
Detect archive type from URL path extension.
Returns one of: 'tar', 'tar.gz', 'tar.bz2', 'tar.xz', 'zip', or '' (unknown).
"""
path = urllib.parse.urlparse(url).path
for ext in (".tar.gz", ".tar.bz2", ".tar.xz"):
if path.endswith(ext):
return ext.lstrip(".")
_, ext = os.path.splitext(path)
mapping = {
".tar": "tar",
".tgz": "tar.gz",
".zip": "zip",
}
return mapping.get(ext.lower(), "")
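# Illustrative results for the detector above (paths are examples):
#   .../weights.tar.gz -> "tar.gz"   .../weights.tgz -> "tar.gz"
#   .../weights.zip    -> "zip"      .../weights.bin -> ""  (unsupported)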
MAX_DOWNLOAD_RETRIES = int(os.environ.get("VLLM_SHIM_MAX_RETRIES", "5"))
RETRY_DELAY_SECONDS = 5
def download_file(url: str, dest: str):
"""Download url to dest with retries and a progress indicator."""
for attempt in range(1, MAX_DOWNLOAD_RETRIES + 1):
try:
log(f"Downloading {url} -> {dest} (attempt {attempt}/{MAX_DOWNLOAD_RETRIES})")
urllib.request.urlretrieve(url, dest, reporthook=_download_progress)
log(f"Download complete: {dest}")
return
except Exception as e:
log(f"Download attempt {attempt} failed: {e}")
if os.path.exists(dest):
os.remove(dest)
if attempt < MAX_DOWNLOAD_RETRIES:
wait = RETRY_DELAY_SECONDS * attempt
log(f"Retrying in {wait}s...")
time.sleep(wait)
else:
log(f"All {MAX_DOWNLOAD_RETRIES} download attempts failed")
raise
def _download_progress(block_num, block_size, total_size):
"""Simple download progress callback."""
if total_size <= 0:
return
downloaded = block_num * block_size
pct = min(downloaded * 100 // total_size, 100)
if pct % 10 == 0 and pct > 0:
mb_down = downloaded / (1024 * 1024)
mb_total = total_size / (1024 * 1024)
sys.stdout.write(f"\r {pct}% ({mb_down:.0f}/{mb_total:.0f} MB)")
sys.stdout.flush()
def extract_archive(archive_path: str, dest_dir: str, archive_type: str):
"""Extract archive to dest_dir based on archive_type."""
log(f"Extracting {archive_path} ({archive_type}) -> {dest_dir}")
if archive_type == "tar.gz" or archive_type == "tgz":
shutil.unpack_archive(archive_path, dest_dir, "gztar")
elif archive_type == "tar.bz2":
shutil.unpack_archive(archive_path, dest_dir, "bztar")
elif archive_type == "tar.xz":
subprocess.run(
["tar", "-xJf", archive_path, "-C", dest_dir],
check=True,
)
elif archive_type == "tar":
shutil.unpack_archive(archive_path, dest_dir, "tar")
elif archive_type == "zip":
shutil.unpack_archive(archive_path, dest_dir, "zip")
else:
raise ValueError(f"Unsupported archive type: {archive_type}")
log(f"Extraction complete: {dest_dir}")
def find_model_dir(extract_dir: str) -> str:
"""
After extraction, find the directory containing the actual model weights.
Walks the tree looking for .safetensors files and returns the directory
that contains one. This handles archives with extra parent dirs,
nested structures, or flat extractions.
"""
for root, dirs, files in os.walk(extract_dir):
if any(f.endswith(".safetensors") for f in files):
return root
log("WARNING: No .safetensors files found in extracted archive, falling back to single-dir heuristic")
entries = [e for e in os.listdir(extract_dir)
if not e.startswith(".") and e != "__MACOSX"]
if len(entries) == 1 and os.path.isdir(os.path.join(extract_dir, entries[0])):
return os.path.join(extract_dir, entries[0])
return extract_dir
def download_and_extract_model(url: str) -> str:
"""
Download a model from URL, extract it, and return the local path.
Uses a cache keyed by URL filename to avoid re-downloading.
"""
url_filename = os.path.basename(urllib.parse.urlparse(url).path)
cache_key = os.path.splitext(url_filename)[0]
local_dir = os.path.join(CACHE_DIR, cache_key)
if os.path.isdir(local_dir) and os.listdir(local_dir):
model_path = find_model_dir(local_dir)
log(f"Using cached weights: {model_path}")
return model_path
os.makedirs(local_dir, exist_ok=True)
archive_type = detect_archive_type(url)
if not archive_type:
raise ValueError(
f"Cannot determine archive type from URL: {url}\n"
f"Supported extensions: .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .zip"
)
tmp_archive = os.path.join(CACHE_DIR, url_filename + ".tmp")
try:
download_file(url, tmp_archive)
extract_archive(tmp_archive, local_dir, archive_type)
finally:
if os.path.exists(tmp_archive):
os.remove(tmp_archive)
return find_model_dir(local_dir)
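# Illustrative flow (paths are examples): a URL ending in nemo-49b.tar.gz downloads to
# /data/weights/nemo-49b.tar.gz.tmp, extracts into /data/weights/nemo-49b.tar/, and the
# function returns whichever directory inside it contains the .safetensors files.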
def parse_args(args):
"""
Parse argv, intercepting --model and positional model args.
Production stack invokes: python -m vllm.entrypoints.openai.api_server serve <model-url> ...
The model can appear as:
- --model <url>
- --model=<url>
- A positional arg after "serve" subcommand
If the value is a URL, download+extract and replace with local path.
Returns the modified argv list.
"""
result = []
i = 0
model_replaced = False
saw_serve = False
while i < len(args):
arg = args[i]
# --model=<value>
if arg.startswith("--model="):
value = arg.split("=", 1)[1]
if is_url(value):
local_path = download_and_extract_model(value)
result.append(f"--model={local_path}")
model_replaced = True
else:
result.append(arg)
i += 1
continue
# --model <value>
if arg == "--model":
result.append(arg)
i += 1
if i < len(args):
value = args[i]
if is_url(value):
local_path = download_and_extract_model(value)
result.append(local_path)
model_replaced = True
else:
result.append(value)
i += 1
continue
# "serve" subcommand — next positional is the model
if arg == "serve":
result.append(arg)
saw_serve = True
i += 1
# The next non-flag argument is the model
if i < len(args) and not args[i].startswith("-") and is_url(args[i]):
local_path = download_and_extract_model(args[i])
result.append(local_path)
model_replaced = True
i += 1
continue
# Positional model arg when there's no "serve" subcommand
# (first non-flag arg if no serve seen)
if not arg.startswith("-") and not saw_serve and not model_replaced:
if is_url(arg):
local_path = download_and_extract_model(arg)
result.append(local_path)
model_replaced = True
i += 1
continue
result.append(arg)
i += 1
if model_replaced:
log("Model URL was replaced with local path")
return result
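# Illustrative rewrite performed by parse_args (URL and local path are examples):
#   ["serve", "https://example.com/m.tar.gz", "--port", "8000"]
#     -> ["serve", "/data/weights/m.tar/<extracted model dir>", "--port", "8000"]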
def strip_shim_from_pythonpath():
"""
Remove the shim directory from PYTHONPATH so that when we exec the
real vLLM, Python doesn't find our shadow package again (infinite loop).
"""
pp = os.environ.get("PYTHONPATH", "")
parts = [p for p in pp.split(":") if p != SHIM_DIR]
new_pp = ":".join(parts)
if new_pp != pp:
os.environ["PYTHONPATH"] = new_pp
log(f"Stripped {SHIM_DIR} from PYTHONPATH (was: {pp!r}, now: {new_pp!r})")
def main():
args = sys.argv[1:]
# Determine which vllm module was actually invoked so we exec the real one
# (could be vllm.entrypoints.cli.main, vllm.entrypoints.openai.api_server, etc.).
# Under `python -m ...` __name__ is "__main__", so prefer __spec__.name, which runpy
# sets to the real dotted module path; fall back to __name__ when that is unavailable.
invoked_module = __spec__.name if __spec__ is not None else __name__
log("=" * 50)
log("vLLM Custom Weights Shim")
log(f" Invoked as: python -m {invoked_module} {' '.join(args)}")
log("=" * 50)
# Intercept --model / positional model if it's a URL
modified_args = parse_args(args)
# Strip our shim from PYTHONPATH so the real vLLM resolves correctly
strip_shim_from_pythonpath()
# Build the real vLLM command using the same module that was invoked
vllm_cmd = [sys.executable, "-m", invoked_module] + modified_args
log(f"Launching vLLM: {' '.join(vllm_cmd)}")
# Exec into vLLM — replace this process so signals flow through cleanly
os.execvp(vllm_cmd[0], vllm_cmd)
if __name__ == "__main__":
main()
# Also run if imported as a module (some invocation paths just import the file)
main()