From 94029ffaf02f0b73e296e11cab721c23fd5a5f97 Mon Sep 17 00:00:00 2001
From: Andreas Karatzas
Date: Fri, 27 Feb 2026 23:55:28 -0600
Subject: [PATCH] [ROCm] Derive device capability from GCN arch string without CUDA init (#35069)

Signed-off-by: Andreas Karatzas
---
 vllm/platforms/rocm.py     | 111 ++++++++++++++++++++++++++++++++++---
 vllm/utils/system_utils.py |  16 ++++++
 2 files changed, 120 insertions(+), 7 deletions(-)

diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index e867ebbd6..ab4c3e074 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -6,6 +6,7 @@ from datetime import timedelta
 from functools import cache, lru_cache, wraps
 from typing import TYPE_CHECKING
 
+import regex as re
 import torch
 from torch.distributed import PrefixStore, ProcessGroup
 from torch.distributed.distributed_c10d import is_nccl_available
@@ -64,13 +65,29 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
     "0x744c": "AMD_Radeon_RX7900XTX",
 }
 
-# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
-if "HIP_VISIBLE_DEVICES" in os.environ:
-    val = os.environ["HIP_VISIBLE_DEVICES"]
-    if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None):
-        assert val == cuda_val
-    else:
-        os.environ["CUDA_VISIBLE_DEVICES"] = val
+
+def _sync_hip_cuda_env_vars():
+    """Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent.
+    Treats empty string as unset. Raises on genuine conflicts."""
+    hip_val = os.environ.get("HIP_VISIBLE_DEVICES") or None
+    cuda_val = os.environ.get("CUDA_VISIBLE_DEVICES") or None
+
+    if hip_val is not None and cuda_val is not None:
+        if hip_val != cuda_val:
+            raise ValueError(
+                f"Inconsistent GPU visibility env vars: "
+                f"HIP_VISIBLE_DEVICES='{hip_val}' vs "
+                f"CUDA_VISIBLE_DEVICES='{cuda_val}'. "
+                f"Please set only one, or ensure they match."
+            )
+    elif hip_val is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = hip_val
+    elif cuda_val is not None:
+        os.environ["HIP_VISIBLE_DEVICES"] = cuda_val
+
+
+# Sync at import time - catches misconfigurations from process start.
+_sync_hip_cuda_env_vars()
 
 # AMDSMI utils
 # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
@@ -134,6 +151,77 @@ _ON_GFX942 = "gfx942" in _GCN_ARCH
 _ON_GFX950 = "gfx950" in _GCN_ARCH
 
 
+def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None:
+    """
+    Parse (major, minor) from a GCN arch string, mirroring how
+    HIP derives hipDeviceProp_t.major / .minor.
+
+    Format: "gfx" + major + minor + stepping
+      - 1-digit major (gfx9xx):  "gfx" + M + m + stepping
+      - 2-digit major (gfx1xxx): "gfx" + MM + m + stepping
+
+    Examples:
+      gfx90a  -> (9, 0)    gfx942  -> (9, 4)    gfx950  -> (9, 5)
+      gfx1100 -> (11, 0)   gfx1101 -> (11, 0)   gfx1200 -> (12, 0)
+
+    Returns None when the string does not start with "gfx" followed by
+    at least one digit (i.e. not a ROCm arch string). Raises ValueError
+    on any such string that does not match a known layout.
+    """
+    m = re.match(r"gfx(\d+)", gcn_arch)
+    if not m:
+        # Not a gfx string at all — caller should fall back to torch.cuda
+        return None
+
+    digits = m.group(1)
+    n = len(digits)
+
+    if n < 2:
+        raise ValueError(
+            f"GCN arch '{gcn_arch}' has too few digits ({n}) after 'gfx' "
+            f"to derive a (major, minor) capability. "
+            f"Please file a vLLM issue with your GPU model."
+        )
+
+    if n in (2, 3):
+        # 1-digit major: gfx9 family
+        #   len 2: major + minor        (e.g. gfx90 from gfx90a)
+        #   len 3: major + minor + step (e.g. gfx942)
+        major = int(digits[0])
+        minor = int(digits[1])
+    elif n == 4:
+        # 2-digit major: gfx10xx, gfx11xx, gfx12xx
+        #   major(2) + minor(1) + stepping(1)
+        major = int(digits[:2])
+        minor = int(digits[2])
+    elif n >= 5:
+        raise ValueError(
+            f"GCN arch '{gcn_arch}' has {n} digits after 'gfx', which "
+            f"exceeds the known 4-digit layout (MMms). Cannot determine "
+            f"major/minor split unambiguously. "
+            f"Please file a vLLM issue with your GPU model."
+        )
+
+    if major < 9:
+        raise ValueError(
+            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
+            f"major={major}, minor={minor}. "
+            f"Major version < 9 is not expected for any supported AMD GPU. "
+            f"Please file a vLLM issue with your GPU model."
+        )
+
+    if major > 12:
+        raise ValueError(
+            f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
+            f"major={major}, minor={minor}. "
+            f"Major version > 12 is beyond currently known AMD generations. "
+            f"Please file a vLLM issue with your GPU model so support "
+            f"can be added."
+        )
+
+    return (major, minor)
+
+
 def on_gfx1x() -> bool:
     return _ON_GFX1X
 
@@ -444,6 +532,15 @@ class RocmPlatform(Platform):
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
+        cap = _capability_from_gcn_arch(_GCN_ARCH)
+        if cap is not None:
+            return DeviceCapability(major=cap[0], minor=cap[1])
+
+        logger.warning_once(
+            "Could not derive device capability from GCN arch '%s', "
+            "falling back to torch.cuda (this will initialize CUDA).",
+            _GCN_ARCH,
+        )
         major, minor = torch.cuda.get_device_capability(device_id)
         return DeviceCapability(major=major, minor=minor)
 
diff --git a/vllm/utils/system_utils.py b/vllm/utils/system_utils.py
index 840056e8b..4bd538879 100644
--- a/vllm/utils/system_utils.py
+++ b/vllm/utils/system_utils.py
@@ -16,6 +16,7 @@ import psutil
 
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.platforms.interface import in_wsl
 from vllm.ray.lazy_utils import is_in_ray_actor
 
@@ -111,6 +112,17 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
 # Process management utilities
 
 
+def _sync_visible_devices_env_vars():
+    """Sync HIP/CUDA visibility env vars before spawning (ROCm only)."""
+
+    if not current_platform.is_rocm():
+        return
+
+    from vllm.platforms.rocm import _sync_hip_cuda_env_vars
+
+    _sync_hip_cuda_env_vars()
+
+
 def _maybe_force_spawn():
     """Check if we need to force the use of the `spawn` multiprocessing start
     method.
@@ -156,6 +168,10 @@ def get_mp_context():
     VLLM_WORKER_MULTIPROC_METHOD.
     """
     _maybe_force_spawn()
+    # (ROCm): Sync GPU visibility env vars so spawned children inherit
+    # consistent values. Must run after _maybe_force_spawn and regardless
+    # of whether spawn was already set.
+    _sync_visible_devices_env_vars()
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
 