[bugfix] [ROCm] Fix premature CUDA initialization in platform detection (#33941)
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
This commit is contained in:
committed by
GitHub
parent
207c3a0c20
commit
4a2d00eafd
@@ -534,6 +534,7 @@ steps:
|
||||
- tests/cuda
|
||||
commands:
|
||||
- pytest -v -s cuda/test_cuda_context.py
|
||||
- pytest -v -s cuda/test_platform_no_cuda_init.py
|
||||
|
||||
- label: Samplers Test # 56min
|
||||
timeout_in_minutes: 75
|
||||
|
||||
@@ -9,6 +9,7 @@ steps:
|
||||
- tests/cuda
|
||||
commands:
|
||||
- pytest -v -s cuda/test_cuda_context.py
|
||||
- pytest -v -s cuda/test_platform_no_cuda_init.py
|
||||
|
||||
- label: Cudagraph
|
||||
timeout_in_minutes: 20
|
||||
|
||||
23
tests/cuda/scripts/check_device_count_respects_env.py
Normal file
23
tests/cuda/scripts/check_device_count_respects_env.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Check that device_count respects CUDA_VISIBLE_DEVICES after platform import."""

import os
import sys

# Clear all device-visibility env vars BEFORE importing torch / vLLM so the
# platform import below runs with no device restriction in place.
for key in ["CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]:
    os.environ.pop(key, None)

import torch  # noqa: E402

# Importing the platform must NOT initialize CUDA; if it did, the env var set
# below would be ignored because device_count() gets locked at init time.
from vllm.platforms import current_platform  # noqa: F401, E402

# Restrict visibility AFTER the import — this mirrors what Ray workers do.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
count = torch.cuda.device_count()

if count == 0:
    sys.exit(0)  # Skip: no GPUs available

# If CUDA had been initialized during import, count would report all GPUs
# instead of the single one made visible above.
assert count == 1, f"device_count()={count}, expected 1"
print("OK")
|
||||
20
tests/cuda/scripts/check_platform_no_cuda_init.py
Normal file
20
tests/cuda/scripts/check_platform_no_cuda_init.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Check that vllm.platforms import does not initialize CUDA."""

import os

# Clear all device-visibility env vars so platform detection cannot be skewed
# by whatever the parent process had set.
for key in ["CUDA_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "ROCR_VISIBLE_DEVICES"]:
    os.environ.pop(key, None)

import torch  # noqa: E402

# Sanity check: merely importing torch must not have initialized CUDA.
assert not torch.cuda.is_initialized(), "CUDA initialized before import"

from vllm.platforms import current_platform  # noqa: E402

# The actual check: platform detection must stay lazy with respect to CUDA.
assert not torch.cuda.is_initialized(), (
    f"CUDA was initialized during vllm.platforms import on {current_platform}"
)
print("OK")
|
||||
48
tests/cuda/test_platform_no_cuda_init.py
Normal file
48
tests/cuda/test_platform_no_cuda_init.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Test that platform imports do not prematurely initialize CUDA.
|
||||
|
||||
This is critical for Ray-based multi-GPU setups where workers need to
|
||||
set CUDA_VISIBLE_DEVICES after importing vLLM but before CUDA is initialized.
|
||||
If CUDA is initialized during import, device_count() gets locked and ignores
|
||||
subsequent env var changes.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Directory holding the standalone check scripts run by the tests below.
SCRIPTS_DIR = Path(__file__).parent / "scripts"


def run_script(script_name: str) -> subprocess.CompletedProcess:
    """Run a test script in a subprocess with clean CUDA state.

    A fresh interpreter is required because CUDA initialization is
    process-global: once initialized in the pytest process, it cannot be
    undone, so the check must happen in an isolated child process.
    """
    script_path = SCRIPTS_DIR / script_name
    return subprocess.run(
        [sys.executable, str(script_path)],
        capture_output=True,
        text=True,
    )
|
||||
|
||||
|
||||
def test_platform_import_does_not_init_cuda():
    """Test that importing vllm.platforms does not initialize CUDA."""
    result = run_script("check_platform_no_cuda_init.py")
    # Non-zero exit means an assert in the child script fired; surface its
    # stderr so the failing assertion message is visible in the pytest report.
    if result.returncode != 0:
        pytest.fail(f"Platform import initialized CUDA:\n{result.stderr}")
|
||||
|
||||
|
||||
def test_device_count_respects_env_after_platform_import():
    """Test that device_count respects CUDA_VISIBLE_DEVICES after import."""
    result = run_script("check_device_count_respects_env.py")
    # The child script exits non-zero only when its device-count assertion
    # fails (it exits 0 to skip when no GPUs are available).
    if result.returncode != 0:
        pytest.fail(
            f"device_count does not respect env var after import:\n{result.stderr}"
        )
|
||||
|
||||
|
||||
# Allow running this file directly without invoking pytest from the CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
@@ -87,33 +87,67 @@ def with_amdsmi_context(fn):
|
||||
return wrapper
|
||||
|
||||
|
||||
@with_amdsmi_context
def _query_gcn_arch_from_amdsmi() -> str:
    """Query GCN arch from amdsmi. Raises if not available.

    Raises:
        RuntimeError: when no processor handles are found or amdsmi does
            not report a target graphics version.
    """
    handles = amdsmi_get_processor_handles()
    if handles:
        # Only the first device is queried; assumes a homogeneous node
        # (all GPUs share one arch) — TODO confirm for mixed-GPU hosts.
        asic_info = amdsmi_get_gpu_asic_info(handles[0])
        # Use target_graphics_version which contains the gfx name
        # e.g., 'gfx942' for MI300X/MI325X
        target_gfx = asic_info.get("target_graphics_version", "")
        if target_gfx:
            return target_gfx
    raise RuntimeError("amdsmi did not return valid GCN arch")
|
||||
|
||||
|
||||
@cache
def _get_gcn_arch_via_amdsmi() -> str:
    """
    Get the GCN architecture name using amdsmi instead of torch.cuda.
    This avoids initializing CUDA, which is important for Ray workers
    that need to set CUDA_VISIBLE_DEVICES after importing vLLM.

    Cached: the arch cannot change within a process, so amdsmi is queried
    at most once.
    """
    try:
        return _query_gcn_arch_from_amdsmi()
    except Exception as e:
        # Keep the root cause at debug level; warn once about the fallback.
        logger.debug("Failed to get GCN arch via amdsmi: %s", e)
        logger.warning_once(
            "Failed to get GCN arch via amdsmi, falling back to torch.cuda. "
            "This will initialize CUDA and may cause "
            "issues if CUDA_VISIBLE_DEVICES is not set yet."
        )
        # Ultimate fallback: use torch.cuda (will initialize CUDA)
        return torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
|
||||
|
||||
@cache
def on_gfx1x() -> bool:
    """Return True if the GCN arch string contains 'gfx11' or 'gfx12'."""
    # Query via amdsmi so this check does not initialize CUDA.
    gpu_arch = _get_gcn_arch_via_amdsmi()
    return any(arch in gpu_arch for arch in ["gfx11", "gfx12"])
|
||||
|
||||
|
||||
@cache
def on_mi3xx() -> bool:
    """Return True if the GCN arch string contains 'gfx942' or 'gfx950'."""
    # Query via amdsmi so this check does not initialize CUDA.
    gpu_arch = _get_gcn_arch_via_amdsmi()
    return any(arch in gpu_arch for arch in ["gfx942", "gfx950"])
|
||||
|
||||
|
||||
@cache
def on_gfx9() -> bool:
    """Return True if the GCN arch string contains 'gfx90a', 'gfx942' or 'gfx950'."""
    # Query via amdsmi so this check does not initialize CUDA.
    gpu_arch = _get_gcn_arch_via_amdsmi()
    return any(arch in gpu_arch for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
|
||||
|
||||
@cache
def on_gfx942() -> bool:
    """Return True if the GCN arch string contains 'gfx942'."""
    # Query via amdsmi so this check does not initialize CUDA.
    gpu_arch = _get_gcn_arch_via_amdsmi()
    return any(arch in gpu_arch for arch in ["gfx942"])
|
||||
|
||||
|
||||
@cache
def on_gfx950() -> bool:
    """Return True if the GCN arch string contains 'gfx950'."""
    # Query via amdsmi so this check does not initialize CUDA.
    gpu_arch = _get_gcn_arch_via_amdsmi()
    return any(arch in gpu_arch for arch in ["gfx950"])
|
||||
|
||||
|
||||
@@ -129,7 +163,7 @@ def use_rocm_custom_paged_attention(
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
sinks: torch.Tensor | None = None,
|
||||
) -> bool:
|
||||
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
GPU_ARCH = _get_gcn_arch_via_amdsmi()
|
||||
ON_GFX9 = any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942", "gfx950"])
|
||||
ON_GFX11_GFX12 = any(arch in GPU_ARCH for arch in ["gfx11", "gfx12"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user