#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Pin vLLM dependencies to exact versions of custom ROCm wheels.

This script modifies vLLM's requirements files to replace version constraints
with exact versions of custom-built ROCm wheels (torch, triton, torchvision,
amdsmi). This ensures that 'pip install vllm' automatically installs the
correct custom wheels instead of allowing pip to download different versions
from PyPI.

Usage:
    pin_rocm_dependencies.py <install_dir> <requirements_file>
"""

import sys
from pathlib import Path

import regex as re

# Comment lines written at the top of the patched requirements file. The first
# one is also used to recognise (and replace) a header left by a previous run,
# so re-running the script does not stack duplicate headers.
_HEADER_LINES = (
    "# Custom ROCm wheel pins (auto-generated by pin_rocm_dependencies.py)\n",
    "# These must come FIRST to ensure correct dependency resolution\n",
)

# A bare exact pin such as "torch==2.9.0a0+git1c57644", i.e. the form this
# script writes into the generated header.
_PIN_LINE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*==\S+$")


def extract_version_from_wheel(wheel_name: str) -> str:
    """
    Extract version from wheel filename.

    Example:
        torch-2.9.0a0+git1c57644-cp312-cp312-linux_x86_64.whl -> 2.9.0a0+git1c57644
        triton-3.4.0-cp312-cp312-linux_x86_64.whl -> 3.4.0

    Raises:
        ValueError: if the name has fewer than five dash-separated fields.
    """
    # Wheel format:
    # {distribution}-{version}(-{build tag})?-{python}-{abi}-{platform}.whl
    # Strip only the trailing extension; str.replace() would also remove a
    # ".whl" substring appearing anywhere else in the name.
    parts = wheel_name.removesuffix(".whl").split("-")
    if len(parts) < 5:
        raise ValueError(f"Invalid wheel filename format: {wheel_name}")

    # Version is the second part
    return parts[1]


def get_custom_wheel_versions(install_dir: str) -> dict[str, str]:
    """
    Read the install directory and extract versions of custom wheels.

    Args:
        install_dir: Directory containing the custom-built ``*.whl`` files.

    Returns:
        Dict mapping package names to exact versions, ordered as in
        ``package_mapping`` below.

    Exits the process with status 1 if ``install_dir`` is not a directory.
    """
    install_path = Path(install_dir)
    if not install_path.is_dir():
        print(f"ERROR: Install directory not found: {install_dir}", file=sys.stderr)
        sys.exit(1)

    # Map wheel prefixes to package names
    # IMPORTANT: Use dashes to avoid matching substrings
    # (e.g., 'torch' would match 'torchvision')
    # ORDER MATTERS: This order is preserved when pinning dependencies
    # in requirements files
    package_mapping = [
        ("torch-", "torch"),  # Match torch- (not torchvision)
        ("triton-", "triton"),  # Match triton- (not triton_kernels)
        ("triton_kernels-", "triton-kernels"),  # Match triton_kernels-
        ("torchvision-", "torchvision"),  # Match torchvision-
        ("torchaudio-", "torchaudio"),  # Match torchaudio-
        ("amdsmi-", "amdsmi"),  # Match amdsmi-
        ("flash_attn-", "flash-attn"),  # Match flash_attn-
        ("aiter-", "aiter"),  # Match aiter-
    ]

    versions: dict[str, str] = {}
    # Sort so the outcome does not depend on filesystem listing order when
    # several wheels of the same package are present (last one sorted wins).
    for wheel_file in sorted(install_path.glob("*.whl")):
        wheel_name = wheel_file.name
        for prefix, package_name in package_mapping:
            if not wheel_name.startswith(prefix):
                continue
            try:
                version = extract_version_from_wheel(wheel_name)
            except ValueError as e:
                print(
                    f"WARNING: Could not extract version from {wheel_name}: {e}",
                    file=sys.stderr,
                )
                break
            previous = versions.get(package_name)
            if previous is not None and previous != version:
                print(
                    f"WARNING: Multiple {package_name} wheels found; "
                    f"using {version} instead of {previous}",
                    file=sys.stderr,
                )
            versions[package_name] = version
            print(f"Found {package_name}=={version}", file=sys.stderr)
            break

    # Return versions in the order defined by package_mapping
    return {
        package_name: versions[package_name]
        for _, package_name in package_mapping
        if package_name in versions
    }


def _requirement_pattern(package_name: str):
    """Compile a regex that matches a requirement line for ``package_name``.

    The name is compared case-insensitively with ``-``, ``_`` and ``.``
    treated as interchangeable (PEP 503 normalisation). After the name and
    optional ``[extras]`` the line must continue with a version operator
    (``==``, ``>=``, ``~=``, ``!=``, ...), a direct reference (``@``), an
    environment marker (``;``), a comment (``#``) or the end of the line.
    This keeps ``torch`` from matching ``torchvision`` while still catching
    every requirement form that could conflict with the exact pin.
    """
    name_re = "[-_.]".join(
        re.escape(part) for part in re.split(r"[-_.]", package_name)
    )
    return re.compile(
        rf"^{name_re}\s*(?:\[[^\]]*\])?\s*(?:[=<>!~@;#]|$)", re.IGNORECASE
    )


def _strip_generated_header(lines: list[str]) -> tuple[list[str], bool]:
    """Remove a header block written by a previous run of this script.

    The block is recognised by its first comment line and consists of the
    header comments, the exact ``name==version`` pins that follow them and
    one blank separator line.

    Returns:
        The remaining lines, and whether such a header was found.
    """
    if not lines or lines[0].strip() != _HEADER_LINES[0].strip():
        return lines, False

    header_comments = {line.strip() for line in _HEADER_LINES}
    idx = 0
    while idx < len(lines) and lines[idx].strip() in header_comments:
        idx += 1
    while idx < len(lines) and _PIN_LINE_RE.match(lines[idx].strip()):
        idx += 1
    if idx < len(lines) and not lines[idx].strip():
        idx += 1  # blank separator written after the pins
    return lines[idx:], True


def pin_dependencies_in_requirements(requirements_path: str, versions: dict[str, str]):
    """
    Insert custom wheel pins at the TOP of requirements file.

    This ensures that when setup.py processes the file line-by-line,
    custom wheels (torch, triton, etc.) are encountered FIRST, before
    any `-r common.txt` includes that might pull in other dependencies.

    Creates:
        # Custom ROCm wheel pins (auto-generated)
        torch==2.9.0a0+git1c57644
        triton==3.4.0
        torchvision==0.23.0a0+824e8c8
        amdsmi==26.1.0+5df6c765

        -r common.txt
        ... rest of file ...

    Any other requirement line for a pinned package is removed so it cannot
    conflict with the exact pin. A header left by a previous run is replaced
    rather than duplicated. The file's previous content is saved next to it
    with an extra ``.bak`` suffix; when the file was already patched and a
    backup exists, that backup is kept instead of being overwritten.

    Exits the process with status 1 if the requirements file does not exist.
    """
    requirements_file = Path(requirements_path)
    if not requirements_file.exists():
        print(
            f"ERROR: Requirements file not found: {requirements_path}", file=sys.stderr
        )
        sys.exit(1)

    backup_file = requirements_file.with_suffix(requirements_file.suffix + ".bak")
    with open(requirements_file, encoding="utf-8") as f:
        original_lines = f.readlines()

    body_lines, already_patched = _strip_generated_header(original_lines)

    # Write backup. On a re-run the file already carries our header, so an
    # existing backup is the only copy of the unpatched original: keep it.
    keep_backup = already_patched and backup_file.exists()
    if not keep_backup:
        with open(backup_file, "w", encoding="utf-8") as f:
            f.writelines(original_lines)

    # Build header with pinned custom wheels
    header_lines = list(_HEADER_LINES)
    header_lines.extend(
        f"{package_name}=={exact_version}\n"
        for package_name, exact_version in versions.items()
    )
    header_lines.append("\n")  # Blank line separator

    # Filter out any existing entries for custom packages from original file
    patterns = {name: _requirement_pattern(name) for name in versions}
    filtered_lines = []
    removed_packages = []
    for line in body_lines:
        stripped = line.strip()
        # Only requirement lines are candidates: blanks, comments and pip
        # options/includes (e.g. "-r common.txt") are always kept.
        if stripped and not stripped.startswith(("#", "-")):
            matched = next(
                (name for name, pat in patterns.items() if pat.match(stripped)),
                None,
            )
            if matched is not None:
                removed_packages.append(f"{matched}: {stripped}")
                continue
        filtered_lines.append(line)

    # Write modified content: header + filtered original content
    with open(requirements_file, "w", encoding="utf-8") as f:
        f.writelines(header_lines + filtered_lines)

    # Print summary
    print("\n✓ Inserted custom wheel pins at TOP of requirements:", file=sys.stderr)
    for package_name, exact_version in versions.items():
        print(f"  - {package_name}=={exact_version}", file=sys.stderr)

    if removed_packages:
        print("\n✓ Removed old package entries:", file=sys.stderr)
        for pkg in removed_packages:
            print(f"  - {pkg}", file=sys.stderr)

    print(f"\n✓ Patched requirements file: {requirements_path}", file=sys.stderr)
    backup_note = "Backup kept (unpatched original)" if keep_backup else "Backup saved"
    print(f"  {backup_note}: {backup_file}", file=sys.stderr)


def main():
    """Command-line entry point: ``<install_dir> <requirements_file>``."""
    if len(sys.argv) != 3:
        print(
            f"Usage: {sys.argv[0]} <install_dir> <requirements_file>",
            file=sys.stderr,
        )
        print(
            f"Example: {sys.argv[0]} /install /app/vllm/requirements/rocm.txt",
            file=sys.stderr,
        )
        sys.exit(1)

    install_dir = sys.argv[1]
    requirements_path = sys.argv[2]

    print("=" * 70, file=sys.stderr)
    print("Pinning vLLM dependencies to custom ROCm wheel versions", file=sys.stderr)
    print("=" * 70, file=sys.stderr)

    # Get versions from custom wheels
    print(f"\nScanning {install_dir} for custom wheels...", file=sys.stderr)
    versions = get_custom_wheel_versions(install_dir)

    if not versions:
        print(f"\nERROR: No custom wheels found in {install_dir}!", file=sys.stderr)
        sys.exit(1)

    # Pin dependencies in requirements file
    print(f"\nPatching {requirements_path}...", file=sys.stderr)
    pin_dependencies_in_requirements(requirements_path, versions)

    print("\n" + "=" * 70, file=sys.stderr)
    print("✓ Dependency pinning complete!", file=sys.stderr)
    print("=" * 70, file=sys.stderr)

    sys.exit(0)


if __name__ == "__main__":
    main()