223 lines
7.5 KiB
Python
223 lines
7.5 KiB
Python
#!/usr/bin/env python3
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
Pin vLLM dependencies to exact versions of custom ROCm wheels.

This script modifies vLLM's requirements files to replace version constraints
with exact versions of custom-built ROCm wheels (torch, triton, torchvision, amdsmi).

This ensures that 'pip install vllm' automatically installs the correct custom wheels
instead of allowing pip to download different versions from PyPI.
"""
|
|
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import regex as re
|
|
|
|
|
|
def extract_version_from_wheel(wheel_name: str) -> str:
    """
    Return the version component of a wheel filename.

    Wheel filenames are laid out as
    ``{distribution}-{version}(-{build tag})?-{python}-{abi}-{platform}.whl``,
    so after dropping the ``.whl`` text and splitting on dashes the version is
    always the second field.

    Example:
        torch-2.9.0a0+git1c57644-cp312-cp312-linux_x86_64.whl -> 2.9.0a0+git1c57644
        triton-3.4.0-cp312-cp312-linux_x86_64.whl -> 3.4.0

    Raises:
        ValueError: If the name has fewer than five dash-separated fields.
    """
    fields = wheel_name.replace(".whl", "").split("-")

    # Five fields is the minimum (a wheel without a build tag); anything
    # shorter cannot be a well-formed wheel name.
    if len(fields) >= 5:
        return fields[1]

    raise ValueError(f"Invalid wheel filename format: {wheel_name}")
|
|
|
|
|
|
def get_custom_wheel_versions(install_dir: str) -> dict[str, str]:
    """
    Scan a directory for custom wheels and extract their versions.

    Only ``*.whl`` files whose name starts with one of the known prefixes are
    considered. Wheels are visited in sorted filename order so the result is
    deterministic; if several wheels map to the same package, the one that
    sorts last wins and a warning is printed when their versions differ.

    Args:
        install_dir: Directory containing the custom-built wheel files.

    Returns:
        Dict mapping package names to exact versions, ordered as in
        ``package_mapping`` (this order is reused when writing the pins).

    Exits the process with status 1 if ``install_dir`` does not exist.
    """
    install_path = Path(install_dir)
    if not install_path.exists():
        print(f"ERROR: Install directory not found: {install_dir}", file=sys.stderr)
        sys.exit(1)

    # Map wheel prefixes to package names
    # IMPORTANT: Use dashes to avoid matching substrings
    # (e.g., 'torch' would match 'torchvision')
    # ORDER MATTERS: This order is preserved when pinning dependencies
    # in requirements files
    package_mapping = [
        ("torch-", "torch"),  # Match torch- (not torchvision)
        ("triton-", "triton"),  # Match triton- (not triton_kernels)
        ("triton_kernels-", "triton-kernels"),  # Match triton_kernels-
        ("torchvision-", "torchvision"),  # Match torchvision-
        ("torchaudio-", "torchaudio"),  # Match torchaudio-
        ("amdsmi-", "amdsmi"),  # Match amdsmi-
        ("flash_attn-", "flash-attn"),  # Match flash_attn-
        ("amd_aiter-", "amd-aiter"),  # Match amd_aiter-
    ]

    versions: dict[str, str] = {}

    # glob() yields entries in filesystem-dependent order; sorting keeps the
    # outcome reproducible when more than one wheel matches a package.
    for wheel_file in sorted(install_path.glob("*.whl")):
        wheel_name = wheel_file.name

        # First matching prefix wins, mirroring the order of package_mapping.
        package_name = next(
            (name for prefix, name in package_mapping if wheel_name.startswith(prefix)),
            None,
        )
        if package_name is None:
            continue

        try:
            version = extract_version_from_wheel(wheel_name)
        except ValueError as e:
            # A malformed wheel name is reported but does not abort the scan.
            print(
                f"WARNING: Could not extract version from {wheel_name}: {e}",
                file=sys.stderr,
            )
            continue

        previous = versions.get(package_name)
        if previous is not None and previous != version:
            print(
                f"WARNING: Multiple wheels found for {package_name}: "
                f"using {version} instead of {previous}",
                file=sys.stderr,
            )

        versions[package_name] = version
        print(f"Found {package_name}=={version}", file=sys.stderr)

    # Return versions in the order defined by package_mapping
    return {name: versions[name] for _, name in package_mapping if name in versions}
|
|
|
|
|
|
def _requirement_line_pattern(package_name: str):
    """Compile a regex matching any requirement line for ``package_name``."""
    # Runs of "-", "_" and "." are interchangeable in project names, so accept
    # any of them wherever the canonical name has a separator.
    name_regex = "[-_.]+".join(
        re.escape(piece) for piece in re.split(r"[-_.]+", package_name)
    )
    # The lookahead requires the project name to END here, so "torch" matches
    # "torch", "torch==1", "torch~=1", "torch[extra]" and "torch @ url" but
    # not "torchvision" or "torch-foo".
    return re.compile(rf"^{name_regex}(?![A-Za-z0-9._-])", re.IGNORECASE)


def _drop_generated_header(lines: list[str], marker: str) -> list[str]:
    """Strip a pin header left by an earlier run when the file starts with one."""
    if not lines or lines[0].strip() != marker.strip():
        return lines
    # The generated header always ends with a blank separator line.
    for index, line in enumerate(lines):
        if not line.strip():
            return lines[index + 1 :]
    # No separator found: the header was edited by hand, so leave it alone.
    return lines


def pin_dependencies_in_requirements(requirements_path: str, versions: dict[str, str]):
    """
    Insert custom wheel pins at the TOP of requirements file.

    This ensures that when setup.py processes the file line-by-line,
    custom wheels (torch, triton, etc.) are encountered FIRST, before
    any `-r common.txt` includes that might pull in other dependencies.

    Any existing requirement line for a pinned package is removed, whatever
    its form (``==``, ``>=``, ``~=``, ``!=``, extras, URL or no specifier).
    Comment lines and option lines starting with ``-`` are left untouched.
    A header generated by a previous run is replaced rather than duplicated.

    The file content read at the start of the call is written to
    ``<requirements_path>.bak`` before the file is modified.

    Args:
        requirements_path: Path of the requirements file to patch in place.
        versions: Mapping of package name to exact version; pins are written
            in this mapping's iteration order.

    Exits the process with status 1 if the requirements file does not exist.

    Creates:
        # Custom ROCm wheel pins (auto-generated)
        torch==2.9.0a0+git1c57644
        triton==3.4.0
        torchvision==0.23.0a0+824e8c8
        amdsmi==26.1.0+5df6c765

        -r common.txt
        ... rest of file ...
    """
    requirements_file = Path(requirements_path)

    if not requirements_file.exists():
        print(
            f"ERROR: Requirements file not found: {requirements_path}", file=sys.stderr
        )
        sys.exit(1)

    # Backup original file
    backup_file = requirements_file.with_suffix(requirements_file.suffix + ".bak")
    with open(requirements_file) as f:
        original_lines = f.readlines()

    # Write backup
    with open(backup_file, "w") as f:
        f.writelines(original_lines)

    # Build header with pinned custom wheels
    header_lines = [
        "# Custom ROCm wheel pins (auto-generated by pin_rocm_dependencies.py)\n",
        "# These must come FIRST to ensure correct dependency resolution\n",
    ]
    header_marker = header_lines[0]
    header_lines.extend(
        f"{package_name}=={exact_version}\n"
        for package_name, exact_version in versions.items()
    )
    header_lines.append("\n")  # Blank line separator

    # Avoid stacking headers when the file was already patched once.
    body_lines = _drop_generated_header(original_lines, header_marker)

    # Compile one matcher per package up front instead of once per line.
    patterns = {name: _requirement_line_pattern(name) for name in versions}

    # Filter out any existing entries for custom packages from original file
    filtered_lines = []
    removed_packages = []

    for line in body_lines:
        stripped = line.strip()
        matched_name = None

        # Only plain requirement lines are candidates: skip blanks, comments
        # and pip options such as "-r common.txt" or "--extra-index-url".
        if stripped and not stripped.startswith(("#", "-")):
            matched_name = next(
                (
                    name
                    for name, pattern in patterns.items()
                    if pattern.match(stripped)
                ),
                None,
            )

        if matched_name is None:
            filtered_lines.append(line)
        else:
            removed_packages.append(f"{matched_name}: {stripped}")

    # Write modified content: header + filtered original content
    with open(requirements_file, "w") as f:
        f.writelines(header_lines + filtered_lines)

    # Print summary
    print("\n✓ Inserted custom wheel pins at TOP of requirements:", file=sys.stderr)
    for package_name, exact_version in versions.items():
        print(f"  - {package_name}=={exact_version}", file=sys.stderr)

    if removed_packages:
        print("\n✓ Removed old package entries:", file=sys.stderr)
        for pkg in removed_packages:
            print(f"  - {pkg}", file=sys.stderr)

    print(f"\n✓ Patched requirements file: {requirements_path}", file=sys.stderr)
    print(f"  Backup saved: {backup_file}", file=sys.stderr)
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Expects exactly two arguments: the directory holding the custom wheels and
    the requirements file to patch. Progress and errors are written to stderr.
    Exits with status 0 on success and 1 on a usage error or when no custom
    wheels are found in the given directory.
    """
    if len(sys.argv) != 3:
        print(
            f"Usage: {sys.argv[0]} <install_dir> <requirements_file>", file=sys.stderr
        )
        print(
            f"Example: {sys.argv[0]} /install /app/vllm/requirements/rocm.txt",
            file=sys.stderr,
        )
        sys.exit(1)

    install_dir, requirements_path = sys.argv[1:3]
    banner = "=" * 70

    print(banner, file=sys.stderr)
    print("Pinning vLLM dependencies to custom ROCm wheel versions", file=sys.stderr)
    print(banner, file=sys.stderr)

    # Get versions from custom wheels
    print(f"\nScanning {install_dir} for custom wheels...", file=sys.stderr)
    versions = get_custom_wheel_versions(install_dir)

    if not versions:
        # Report the directory that was actually scanned rather than a
        # hard-coded path, since install_dir is supplied by the caller.
        print(f"\nERROR: No custom wheels found in {install_dir}!", file=sys.stderr)
        sys.exit(1)

    # Pin dependencies in requirements file
    print(f"\nPatching {requirements_path}...", file=sys.stderr)
    pin_dependencies_in_requirements(requirements_path, versions)

    print("\n" + banner, file=sys.stderr)
    print("✓ Dependency pinning complete!", file=sys.stderr)
    print(banner, file=sys.stderr)

    sys.exit(0)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|