diff --git a/docs/getting_started/installation/gpu.cuda.inc.md b/docs/getting_started/installation/gpu.cuda.inc.md index d51cbcfab..7895a4c65 100644 --- a/docs/getting_started/installation/gpu.cuda.inc.md +++ b/docs/getting_started/installation/gpu.cuda.inc.md @@ -118,7 +118,7 @@ There are more environment variables to control the behavior of Python-only buil * `VLLM_PRECOMPILED_WHEEL_LOCATION`: specify the exact wheel URL or local file path of a pre-compiled wheel to use. All other logic to find the wheel will be skipped. * `VLLM_PRECOMPILED_WHEEL_COMMIT`: override the commit hash to download the pre-compiled wheel. It can be `nightly` to use the last **already built** commit on the main branch. -* `VLLM_PRECOMPILED_WHEEL_VARIANT`: specify the variant subdirectory to use on the nightly index, e.g., `cu129`, `cpu`. If not specified, the CUDA variant with `VLLM_MAIN_CUDA_VERSION` will be tried, then fallback to the default variant on the remote index. +* `VLLM_PRECOMPILED_WHEEL_VARIANT`: specify the variant subdirectory to use on the nightly index, e.g., `cu129`, `cu130`, `cpu`. If not specified, the variant is auto-detected based on your system's CUDA version (from PyTorch or nvidia-smi). You can also set `VLLM_MAIN_CUDA_VERSION` to override auto-detection. You can find more information about vLLM's wheels in [Install the latest code](#install-the-latest-code). 
@staticmethod
def detect_system_cuda_variant() -> str:
    """Pick the pre-compiled wheel variant matching the system's CUDA version.

    The CUDA version is taken from the first source that yields one:

    1. ``VLLM_MAIN_CUDA_VERSION`` when explicitly set. In this case the
       variant is built directly as ``cu<major><minor>`` and the table of
       supported variants is not consulted.
    2. ``torch.version.cuda`` of an importable PyTorch.
    3. The ``CUDA Version:`` field printed by ``nvidia-smi``.
    4. The default value of ``envs.VLLM_MAIN_CUDA_VERSION``.

    For sources 2-4 only the major version is used. It is mapped to one of
    the hosted variants; the newest hosted variant is used when the major
    version is not in the table or the version string cannot be parsed.

    Returns:
        The variant name, e.g. ``"cu129"`` or ``"cu130"``.
    """

    # Map CUDA major version to hosted wheel variants on wheels.vllm.ai
    supported = {12: "cu129", 13: "cu130"}
    newest_variant = supported[max(supported)]

    # Respect explicitly set VLLM_MAIN_CUDA_VERSION
    if envs.is_set("VLLM_MAIN_CUDA_VERSION"):
        v = envs.VLLM_MAIN_CUDA_VERSION
        print(f"Using VLLM_MAIN_CUDA_VERSION={v}")
        # Keep only the major and minor components ("12.9.1" -> "cu129").
        # Unlike slicing the joined digits, this does not truncate a
        # two-digit minor version ("12.10" -> "cu1210", not "cu121").
        return "cu" + "".join(str(v).strip().split(".")[:2])

    # Try torch.version.cuda
    cuda_version = None
    try:
        import torch

        cuda_version = torch.version.cuda
    except Exception as exc:
        # Best-effort probe: torch may be missing or fail to load for many
        # reasons, none of which should abort the build. Report and move on.
        print(f"Could not query CUDA version from torch: {exc!r}")

    # Try nvidia-smi
    if not cuda_version:
        try:
            out = subprocess.run(
                ["nvidia-smi"], capture_output=True, text=True, timeout=10
            )
        except (OSError, subprocess.SubprocessError, UnicodeError) as exc:
            # Missing binary, timeout, or undecodable output: keep going.
            print(f"Could not query CUDA version from nvidia-smi: {exc!r}")
        else:
            if m := re.search(r"CUDA Version:\s*(\d+\.\d+)", out.stdout):
                cuda_version = m.group(1)

    # Fall back to default
    if not cuda_version:
        cuda_version = envs.VLLM_MAIN_CUDA_VERSION

    # Map to supported variant. A malformed or empty version string must not
    # crash the build, so it resolves to the newest hosted variant instead.
    major_match = re.match(r"\s*(\d+)", str(cuda_version))
    if major_match is None:
        print(
            f"Could not parse CUDA version {cuda_version!r}, "
            f"using variant {newest_variant}"
        )
        return newest_variant

    variant = supported.get(int(major_match.group(1)), newest_variant)
    print(f"Detected CUDA {cuda_version}, using variant {variant}")
    return variant
the default variant from nightly repo If downloading from the nightly repo, the commit can be specified via VLLM_PRECOMPILED_WHEEL_COMMIT; otherwise, the head commit in the main branch @@ -533,9 +576,11 @@ class precompiled_wheel_utils: import platform arch = platform.machine() - # try to fetch the wheel metadata from the nightly wheel repo - main_variant = "cu" + envs.VLLM_MAIN_CUDA_VERSION.replace(".", "") - variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", main_variant) + # try to fetch the wheel metadata from the nightly wheel repo, + # detecting CUDA variant from system if not specified + variant = os.getenv("VLLM_PRECOMPILED_WHEEL_VARIANT", None) + if variant is None: + variant = precompiled_wheel_utils.detect_system_cuda_variant() commit = os.getenv("VLLM_PRECOMPILED_WHEEL_COMMIT", "").lower() if not commit or len(commit) != 40: print(