[ROCm] Derive device capability from GCN arch string without CUDA init (#35069)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -6,6 +6,7 @@ from datetime import timedelta
|
|||||||
from functools import cache, lru_cache, wraps
|
from functools import cache, lru_cache, wraps
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import PrefixStore, ProcessGroup
|
from torch.distributed import PrefixStore, ProcessGroup
|
||||||
from torch.distributed.distributed_c10d import is_nccl_available
|
from torch.distributed.distributed_c10d import is_nccl_available
|
||||||
@@ -64,13 +65,29 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
|||||||
"0x744c": "AMD_Radeon_RX7900XTX",
|
"0x744c": "AMD_Radeon_RX7900XTX",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`
|
|
||||||
if "HIP_VISIBLE_DEVICES" in os.environ:
|
def _sync_hip_cuda_env_vars():
    """Mirror HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES onto each other.

    An empty value is handled exactly like a missing variable. When only
    one of the two variables carries a value, that value is copied to the
    other. When both carry the same value nothing is changed.

    Raises:
        ValueError: if both variables are non-empty and their values differ.
    """
    hip_devices = os.environ.get("HIP_VISIBLE_DEVICES", "")
    cuda_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")

    if hip_devices and cuda_devices:
        # Both are set: identical values need no work, anything else is a
        # configuration error that must be surfaced to the user.
        if hip_devices == cuda_devices:
            return
        raise ValueError(
            f"Inconsistent GPU visibility env vars: "
            f"HIP_VISIBLE_DEVICES='{hip_devices}' vs "
            f"CUDA_VISIBLE_DEVICES='{cuda_devices}'. "
            f"Please set only one, or ensure they match."
        )

    # At most one side is set from here on: copy it across.
    if hip_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = hip_devices
    elif cuda_devices:
        os.environ["HIP_VISIBLE_DEVICES"] = cuda_devices


# Run once when this module is imported so that a conflicting pair of
# variables is reported right away.
_sync_hip_cuda_env_vars()
|
||||||
|
|
||||||
# AMDSMI utils
|
# AMDSMI utils
|
||||||
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
|
# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`,
|
||||||
@@ -134,6 +151,77 @@ _ON_GFX942 = "gfx942" in _GCN_ARCH
|
|||||||
_ON_GFX950 = "gfx950" in _GCN_ARCH
|
_ON_GFX950 = "gfx950" in _GCN_ARCH
|
||||||
|
|
||||||
|
|
||||||
|
def _capability_from_gcn_arch(gcn_arch: str) -> tuple[int, int] | None:
|
||||||
|
"""
|
||||||
|
Parse (major, minor) from a GCN arch string, mirroring how
|
||||||
|
HIP derives hipDeviceProp_t.major / .minor.
|
||||||
|
|
||||||
|
Format: gfx<MAJOR><MINOR><STEPPING>
|
||||||
|
- 1-digit major (gfx9xx): "gfx" + M + m + stepping
|
||||||
|
- 2-digit major (gfx1xxx): "gfx" + MM + m + stepping
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
gfx90a -> (9, 0) gfx942 -> (9, 4) gfx950 -> (9, 5)
|
||||||
|
gfx1100 -> (11, 0) gfx1101 -> (11, 0) gfx1200 -> (12, 0)
|
||||||
|
|
||||||
|
Returns None only when the string is not gfx-prefixed at all
|
||||||
|
(i.e. not a ROCm arch string). Raises on any string that looks
|
||||||
|
like a GCN arch but does not match a known layout.
|
||||||
|
"""
|
||||||
|
m = re.match(r"gfx(\d+)", gcn_arch)
|
||||||
|
if not m:
|
||||||
|
# Not a gfx string at all — caller should fall back to torch.cuda
|
||||||
|
return None
|
||||||
|
|
||||||
|
digits = m.group(1)
|
||||||
|
n = len(digits)
|
||||||
|
|
||||||
|
if n < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"GCN arch '{gcn_arch}' has too few digits ({n}) after 'gfx' "
|
||||||
|
f"to derive a (major, minor) capability. "
|
||||||
|
f"Please file a vLLM issue with your GPU model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if n in (2, 3):
|
||||||
|
# 1-digit major: gfx9 family
|
||||||
|
# len 2: major + minor (e.g. gfx90 from gfx90a)
|
||||||
|
# len 3: major + minor + step (e.g. gfx942)
|
||||||
|
major = int(digits[0])
|
||||||
|
minor = int(digits[1])
|
||||||
|
elif n == 4:
|
||||||
|
# 2-digit major: gfx10xx, gfx11xx, gfx12xx
|
||||||
|
# major(2) + minor(1) + stepping(1)
|
||||||
|
major = int(digits[:2])
|
||||||
|
minor = int(digits[2])
|
||||||
|
elif n >= 5:
|
||||||
|
raise ValueError(
|
||||||
|
f"GCN arch '{gcn_arch}' has {n} digits after 'gfx', which "
|
||||||
|
f"exceeds the known 4-digit layout (MMms). Cannot determine "
|
||||||
|
f"major/minor split unambiguously. "
|
||||||
|
f"Please file a vLLM issue with your GPU model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if major < 9:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
|
||||||
|
f"major={major}, minor={minor}. "
|
||||||
|
f"Major version < 9 is not expected for any supported AMD GPU. "
|
||||||
|
f"Please file a vLLM issue with your GPU model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if major > 12:
|
||||||
|
raise ValueError(
|
||||||
|
f"Parsed unknown ROCm architecture from GCN arch '{gcn_arch}': "
|
||||||
|
f"major={major}, minor={minor}. "
|
||||||
|
f"Major version > 12 is beyond currently known AMD generations. "
|
||||||
|
f"Please file a vLLM issue with your GPU model so support "
|
||||||
|
f"can be added."
|
||||||
|
)
|
||||||
|
|
||||||
|
return (major, minor)
|
||||||
|
|
||||||
|
|
||||||
def on_gfx1x() -> bool:
    """Return the module-level ``_ON_GFX1X`` flag.

    NOTE(review): ``_ON_GFX1X`` is defined outside this view; presumably it
    is derived from ``_GCN_ARCH`` like the neighbouring ``_ON_GFX950`` flag
    -- confirm against its definition.
    """
    return _ON_GFX1X
|
||||||
|
|
||||||
@@ -444,6 +532,15 @@ class RocmPlatform(Platform):
|
|||||||
@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
    """Return the device capability, preferring the GCN arch string.

    The capability is first parsed from the module-level ``_GCN_ARCH``
    via ``_capability_from_gcn_arch``. Only when that yields ``None`` is
    ``torch.cuda.get_device_capability`` queried, after logging a one-time
    warning that this fallback initializes CUDA.

    Args:
        device_id: Device index; only consulted on the torch.cuda
            fallback path.

    Returns:
        A ``DeviceCapability`` built from the (major, minor) pair. Results
        are memoized by the surrounding ``lru_cache``.
    """
    parsed = _capability_from_gcn_arch(_GCN_ARCH)
    if parsed is None:
        logger.warning_once(
            "Could not derive device capability from GCN arch '%s', "
            "falling back to torch.cuda (this will initialize CUDA).",
            _GCN_ARCH,
        )
        parsed = torch.cuda.get_device_capability(device_id)

    major, minor = parsed
    return DeviceCapability(major=major, minor=minor)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import psutil
|
|||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import in_wsl
|
from vllm.platforms.interface import in_wsl
|
||||||
from vllm.ray.lazy_utils import is_in_ray_actor
|
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||||
|
|
||||||
@@ -111,6 +112,17 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
|
|||||||
# Process management utilities
|
# Process management utilities
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_visible_devices_env_vars():
    """On ROCm, reconcile the HIP/CUDA visibility env vars before spawning.

    Does nothing when the current platform is not ROCm.
    """
    if current_platform.is_rocm():
        # The ROCm platform module is imported only on this branch, so it
        # is never loaded here for other platforms.
        from vllm.platforms.rocm import _sync_hip_cuda_env_vars as sync_env

        sync_env()
|
||||||
|
|
||||||
|
|
||||||
def _maybe_force_spawn():
|
def _maybe_force_spawn():
|
||||||
"""Check if we need to force the use of the `spawn` multiprocessing start
|
"""Check if we need to force the use of the `spawn` multiprocessing start
|
||||||
method.
|
method.
|
||||||
@@ -156,6 +168,10 @@ def get_mp_context():
|
|||||||
VLLM_WORKER_MULTIPROC_METHOD.
|
VLLM_WORKER_MULTIPROC_METHOD.
|
||||||
"""
|
"""
|
||||||
_maybe_force_spawn()
|
_maybe_force_spawn()
|
||||||
|
# (ROCm): Sync GPU visibility env vars so spawned children inherit
|
||||||
|
# consistent values. Must run after _maybe_force_spawn and regardless
|
||||||
|
# of whether spawn was already set.
|
||||||
|
_sync_visible_devices_env_vars()
|
||||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||||
return multiprocessing.get_context(mp_method)
|
return multiprocessing.get_context(mp_method)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user