[bugfix] [ROCm] Fix premature CUDA initialization in platform detection (#33941)

Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
kourosh hakhamaneshi
2026-02-06 14:17:55 -08:00
committed by GitHub
parent 207c3a0c20
commit 4a2d00eafd
6 changed files with 133 additions and 6 deletions

View File

@@ -534,6 +534,7 @@ steps:
- tests/cuda
commands:
- pytest -v -s cuda/test_cuda_context.py
- pytest -v -s cuda/test_platform_no_cuda_init.py
- label: Samplers Test # 56min
timeout_in_minutes: 75

View File

@@ -9,6 +9,7 @@ steps:
- tests/cuda
commands:
- pytest -v -s cuda/test_cuda_context.py
- pytest -v -s cuda/test_platform_no_cuda_init.py
- label: Cudagraph
timeout_in_minutes: 20

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Check that device_count respects CUDA_VISIBLE_DEVICES after platform import."""
import os
import sys
# Scrub every device-visibility override so the imports below run against a
# clean environment. The scenario under test is a Ray-style worker that sets
# CUDA_VISIBLE_DEVICES only *after* importing vLLM.
for key in ["CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]:
    os.environ.pop(key, None)

# Imports are deliberately placed after the env scrub (hence noqa: E402).
import torch  # noqa: E402
from vllm.platforms import current_platform  # noqa: F401, E402

# Restrict visibility *after* the platform import. If that import had
# initialized CUDA, torch would already have locked in the full device
# count and this setting would be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
count = torch.cuda.device_count()
if count == 0:
    sys.exit(0)  # Skip: no GPUs available
assert count == 1, f"device_count()={count}, expected 1"
print("OK")

View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Check that vllm.platforms import does not initialize CUDA."""
import os
# Scrub device-visibility overrides so CUDA state is pristine before import.
for key in ["CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]:
    os.environ.pop(key, None)

# Imported after the env scrub on purpose (hence noqa: E402).
import torch  # noqa: E402

# Baseline sanity check: importing torch alone must not initialize CUDA.
assert not torch.cuda.is_initialized(), "CUDA initialized before import"

# The import under test: platform detection must not touch the CUDA runtime.
from vllm.platforms import current_platform  # noqa: E402

assert not torch.cuda.is_initialized(), (
    f"CUDA was initialized during vllm.platforms import on {current_platform}"
)
print("OK")

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that platform imports do not prematurely initialize CUDA.
This is critical for Ray-based multi-GPU setups where workers need to
set CUDA_VISIBLE_DEVICES after importing vLLM but before CUDA is initialized.
If CUDA is initialized during import, device_count() gets locked and ignores
subsequent env var changes.
"""
import subprocess
import sys
from pathlib import Path
import pytest
# Directory holding the standalone checker scripts executed below.
SCRIPTS_DIR = Path(__file__).parent / "scripts"


def run_script(script_name: str) -> subprocess.CompletedProcess:
    """Execute *script_name* from SCRIPTS_DIR in a fresh interpreter.

    A subprocess guarantees a clean CUDA state: nothing imported by this
    test process can have pre-initialized CUDA for the child.
    """
    cmd = [sys.executable, str(SCRIPTS_DIR / script_name)]
    return subprocess.run(cmd, capture_output=True, text=True)
def test_platform_import_does_not_init_cuda():
    """Importing vllm.platforms must leave the CUDA runtime uninitialized."""
    proc = run_script("check_platform_no_cuda_init.py")
    if proc.returncode == 0:
        return
    pytest.fail(f"Platform import initialized CUDA:\n{proc.stderr}")
def test_device_count_respects_env_after_platform_import():
    """CUDA_VISIBLE_DEVICES set after import must still limit device_count."""
    proc = run_script("check_device_count_respects_env.py")
    if proc.returncode == 0:
        return
    pytest.fail(
        f"device_count does not respect env var after import:\n{proc.stderr}"
    )
# Allow running this file directly: python test_platform_no_cuda_init.py
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -87,33 +87,67 @@ def with_amdsmi_context(fn):
return wrapper
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Query GCN arch from amdsmi. Raises if not available."""
    handles = amdsmi_get_processor_handles()
    if not handles:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    asic_info = amdsmi_get_gpu_asic_info(handles[0])
    # 'target_graphics_version' carries the gfx name, e.g. 'gfx942'
    # for MI300X/MI325X.
    gfx_name = asic_info.get("target_graphics_version", "")
    if not gfx_name:
        raise RuntimeError("amdsmi did not return valid GCN arch")
    return gfx_name
@cache
def _get_gcn_arch_via_amdsmi() -> str:
    """
    Get the GCN architecture name using amdsmi instead of torch.cuda.
    This avoids initializing CUDA, which is important for Ray workers
    that need to set CUDA_VISIBLE_DEVICES after importing vLLM.
    """
    try:
        arch = _query_gcn_arch_from_amdsmi()
    except Exception as e:
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
        # Ultimate fallback: use torch.cuda (will initialize CUDA)
        return torch.cuda.get_device_properties("cuda").gcnArchName
    return arch
@cache
def on_gfx1x() -> bool:
    """Return True if the GPU arch contains gfx11 or gfx12."""
    # Query via amdsmi only; the leftover torch.cuda dead store was removed
    # because it initialized CUDA, defeating the purpose of the amdsmi path.
    GPU_ARCH = _get_gcn_arch_via_amdsmi()
    return any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
@cache
def on_mi3xx() -> bool:
    """Return True if the GPU arch contains gfx942 or gfx950."""
    # Query via amdsmi only; the leftover torch.cuda dead store was removed
    # because it initialized CUDA, defeating the purpose of the amdsmi path.
    GPU_ARCH = _get_gcn_arch_via_amdsmi()
    return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
@cache
def on_gfx9() -> bool:
    """Return True if the GPU arch contains gfx90a, gfx942 or gfx950."""
    # Query via amdsmi only; the leftover torch.cuda dead store was removed
    # because it initialized CUDA, defeating the purpose of the amdsmi path.
    GPU_ARCH = _get_gcn_arch_via_amdsmi()
    return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
@cache
def on_gfx942() -> bool:
    """Return True if the GPU arch contains gfx942."""
    # Query via amdsmi only; the leftover torch.cuda dead store was removed
    # because it initialized CUDA, defeating the purpose of the amdsmi path.
    # Single-pattern check: a plain substring test replaces any() over a
    # one-element list.
    return "gfx942" in _get_gcn_arch_via_amdsmi()
@cache
def on_gfx950() -> bool:
    """Return True if the GPU arch contains gfx950."""
    # Query via amdsmi only; the leftover torch.cuda dead store was removed
    # because it initialized CUDA, defeating the purpose of the amdsmi path.
    # Single-pattern check: a plain substring test replaces any() over a
    # one-element list.
    return "gfx950" in _get_gcn_arch_via_amdsmi()
@@ -129,7 +163,7 @@ def use_rocm_custom_paged_attention(
alibi_slopes: torch.Tensor | None = None,
sinks: torch.Tensor | None = None,
) -> bool:
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
GPU_ARCH = _get_gcn_arch_via_amdsmi()
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])