Cherry pick [ROCm] [CI] [Release] Rocm wheel pipeline with sccache #32264
Signed-off-by: Kevin H. Luu <khluu000@gmail.com>
This commit is contained in:
221
tools/vllm-rocm/pin_rocm_dependencies.py
Normal file
221
tools/vllm-rocm/pin_rocm_dependencies.py
Normal file
@@ -0,0 +1,221 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Pin vLLM dependencies to exact versions of custom ROCm wheels.
|
||||
|
||||
This script modifies vLLM's requirements files to replace version constraints
|
||||
with exact versions of custom-built ROCm wheels (torch, triton, torchvision, amdsmi).
|
||||
|
||||
This ensures that 'pip install vllm' automatically installs the correct custom wheels
|
||||
instead of allowing pip to download different versions from PyPI.
|
||||
"""
|
||||
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def extract_version_from_wheel(wheel_name: str) -> str:
|
||||
"""
|
||||
Extract version from wheel filename.
|
||||
|
||||
Example:
|
||||
torch-2.9.0a0+git1c57644-cp312-cp312-linux_x86_64.whl -> 2.9.0a0+git1c57644
|
||||
triton-3.4.0-cp312-cp312-linux_x86_64.whl -> 3.4.0
|
||||
"""
|
||||
# Wheel format:
|
||||
# {distribution}-{version}(-{build tag})?-{python}-{abi}-{platform}.whl
|
||||
parts = wheel_name.replace(".whl", "").split("-")
|
||||
|
||||
if len(parts) < 5:
|
||||
raise ValueError(f"Invalid wheel filename format: {wheel_name}")
|
||||
|
||||
# Version is the second part
|
||||
version = parts[1]
|
||||
return version
|
||||
|
||||
|
||||
def get_custom_wheel_versions(install_dir: str) -> dict[str, str]:
|
||||
"""
|
||||
Read /install directory and extract versions of custom wheels.
|
||||
|
||||
Returns:
|
||||
Dict mapping package names to exact versions
|
||||
"""
|
||||
install_path = Path(install_dir)
|
||||
if not install_path.exists():
|
||||
print(f"ERROR: Install directory not found: {install_dir}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
versions = {}
|
||||
|
||||
# Map wheel prefixes to package names
|
||||
# IMPORTANT: Use dashes to avoid matching substrings
|
||||
# (e.g., 'torch' would match 'torchvision')
|
||||
# ORDER MATTERS: This order is preserved when pinning dependencies
|
||||
# in requirements files
|
||||
package_mapping = [
|
||||
("torch-", "torch"), # Match torch- (not torchvision)
|
||||
("triton-", "triton"), # Match triton- (not triton_kernels)
|
||||
("triton_kernels-", "triton-kernels"), # Match triton_kernels-
|
||||
("torchvision-", "torchvision"), # Match torchvision-
|
||||
("torchaudio-", "torchaudio"), # Match torchaudio-
|
||||
("amdsmi-", "amdsmi"), # Match amdsmi-
|
||||
("flash_attn-", "flash-attn"), # Match flash_attn-
|
||||
("aiter-", "aiter"), # Match aiter-
|
||||
]
|
||||
|
||||
for wheel_file in install_path.glob("*.whl"):
|
||||
wheel_name = wheel_file.name
|
||||
|
||||
for prefix, package_name in package_mapping:
|
||||
if wheel_name.startswith(prefix):
|
||||
try:
|
||||
version = extract_version_from_wheel(wheel_name)
|
||||
versions[package_name] = version
|
||||
print(f"Found {package_name}=={version}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"WARNING: Could not extract version from {wheel_name}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
break
|
||||
|
||||
# Return versions in the order defined by package_mapping
|
||||
ordered_versions = {}
|
||||
for _, package_name in package_mapping:
|
||||
if package_name in versions:
|
||||
ordered_versions[package_name] = versions[package_name]
|
||||
return ordered_versions
|
||||
|
||||
|
||||
def pin_dependencies_in_requirements(requirements_path: str, versions: dict[str, str]):
|
||||
"""
|
||||
Insert custom wheel pins at the TOP of requirements file.
|
||||
|
||||
This ensures that when setup.py processes the file line-by-line,
|
||||
custom wheels (torch, triton, etc.) are encountered FIRST, before
|
||||
any `-r common.txt` includes that might pull in other dependencies.
|
||||
|
||||
Creates:
|
||||
# Custom ROCm wheel pins (auto-generated)
|
||||
torch==2.9.0a0+git1c57644
|
||||
triton==3.4.0
|
||||
torchvision==0.23.0a0+824e8c8
|
||||
amdsmi==26.1.0+5df6c765
|
||||
|
||||
-r common.txt
|
||||
... rest of file ...
|
||||
"""
|
||||
requirements_file = Path(requirements_path)
|
||||
|
||||
if not requirements_file.exists():
|
||||
print(
|
||||
f"ERROR: Requirements file not found: {requirements_path}", file=sys.stderr
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Backup original file
|
||||
backup_file = requirements_file.with_suffix(requirements_file.suffix + ".bak")
|
||||
with open(requirements_file) as f:
|
||||
original_lines = f.readlines()
|
||||
|
||||
# Write backup
|
||||
with open(backup_file, "w") as f:
|
||||
f.writelines(original_lines)
|
||||
|
||||
# Build header with pinned custom wheels
|
||||
header_lines = [
|
||||
"# Custom ROCm wheel pins (auto-generated by pin_rocm_dependencies.py)\n",
|
||||
"# These must come FIRST to ensure correct dependency resolution\n",
|
||||
]
|
||||
|
||||
for package_name, exact_version in versions.items():
|
||||
header_lines.append(f"{package_name}=={exact_version}\n")
|
||||
|
||||
header_lines.append("\n") # Blank line separator
|
||||
|
||||
# Filter out any existing entries for custom packages from original file
|
||||
filtered_lines = []
|
||||
removed_packages = []
|
||||
|
||||
for line in original_lines:
|
||||
stripped = line.strip()
|
||||
should_keep = True
|
||||
|
||||
# Check if this line is for one of our custom packages
|
||||
if stripped and not stripped.startswith("#") and not stripped.startswith("-"):
|
||||
for package_name in versions:
|
||||
# Handle both hyphen and underscore variations
|
||||
pattern_name = package_name.replace("-", "[-_]")
|
||||
pattern = rf"^{pattern_name}\s*[=<>]=?\s*[\d.a-zA-Z+]+"
|
||||
|
||||
if re.match(pattern, stripped, re.IGNORECASE):
|
||||
removed_packages.append(f"{package_name}: {stripped}")
|
||||
should_keep = False
|
||||
break
|
||||
|
||||
if should_keep:
|
||||
filtered_lines.append(line)
|
||||
|
||||
# Combine: header + filtered original content
|
||||
final_lines = header_lines + filtered_lines
|
||||
|
||||
# Write modified content
|
||||
with open(requirements_file, "w") as f:
|
||||
f.writelines(final_lines)
|
||||
|
||||
# Print summary
|
||||
print("\n✓ Inserted custom wheel pins at TOP of requirements:", file=sys.stderr)
|
||||
for package_name, exact_version in versions.items():
|
||||
print(f" - {package_name}=={exact_version}", file=sys.stderr)
|
||||
|
||||
if removed_packages:
|
||||
print("\n✓ Removed old package entries:", file=sys.stderr)
|
||||
for pkg in removed_packages:
|
||||
print(f" - {pkg}", file=sys.stderr)
|
||||
|
||||
print(f"\n✓ Patched requirements file: {requirements_path}", file=sys.stderr)
|
||||
print(f" Backup saved: {backup_file}", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
f"Usage: {sys.argv[0]} <install_dir> <requirements_file>", file=sys.stderr
|
||||
)
|
||||
print(
|
||||
f"Example: {sys.argv[0]} /install /app/vllm/requirements/rocm.txt",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
install_dir = sys.argv[1]
|
||||
requirements_path = sys.argv[2]
|
||||
|
||||
print("=" * 70, file=sys.stderr)
|
||||
print("Pinning vLLM dependencies to custom ROCm wheel versions", file=sys.stderr)
|
||||
print("=" * 70, file=sys.stderr)
|
||||
|
||||
# Get versions from custom wheels
|
||||
print(f"\nScanning {install_dir} for custom wheels...", file=sys.stderr)
|
||||
versions = get_custom_wheel_versions(install_dir)
|
||||
|
||||
if not versions:
|
||||
print("\nERROR: No custom wheels found in /install!", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Pin dependencies in requirements file
|
||||
print(f"\nPatching {requirements_path}...", file=sys.stderr)
|
||||
pin_dependencies_in_requirements(requirements_path, versions)
|
||||
|
||||
print("\n" + "=" * 70, file=sys.stderr)
|
||||
print("✓ Dependency pinning complete!", file=sys.stderr)
|
||||
print("=" * 70, file=sys.stderr)
|
||||
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user