replace cuda_device_count_stateless() to current_platform.device_count() (#37841)

Signed-off-by: Liao, Wei <wei.liao@intel.com>
Signed-off-by: wliao2 <wei.liao@intel.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
wliao2
2026-03-31 07:32:54 -07:00
committed by GitHub
parent e8057c00bc
commit 4dfad17ed1
20 changed files with 96 additions and 92 deletions

View File

@@ -6,7 +6,6 @@ import pytest
from vllm.config import CompilationMode
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
from ...utils import compare_all_settings
@@ -109,10 +108,10 @@ def test_compile_correctness(
tp_size = test_setting.tp_size
attn_backend = test_setting.attn_backend
method = test_setting.method
if cuda_device_count_stateless() < pp_size * tp_size:
if current_platform.device_count() < pp_size * tp_size:
pytest.skip(
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}"
f"{current_platform.device_count()}"
)
final_args = [

View File

@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init(
with (
ctx,
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
patch.object(current_platform, "device_count", return_value=tp_size),
):
kwargs = {}
if cudagraph_capture_sizes is not None:

View File

@@ -13,7 +13,6 @@ from vllm.distributed.utils import StatelessProcessGroup
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
from ..utils import multi_gpu_test
@@ -21,7 +20,7 @@ from ..utils import multi_gpu_test
@ray.remote
class _CUDADeviceCountStatelessTestActor:
def get_count(self):
return cuda_device_count_stateless()
return current_platform.device_count()
def set_cuda_visible_devices(self, cuda_visible_devices: str):
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})

View File

@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager
from .modular_kernel_tools.common import (
@@ -310,10 +310,10 @@ def test_modular_kernel_combinations_multigpu(
world_size: int,
pytestconfig,
):
if cuda_device_count_stateless() < world_size:
if current_platform.device_count() < world_size:
pytest.skip(
f"Not enough GPUs available to run, got "
f"{cuda_device_count_stateless()} expected "
f"{current_platform.device_count()} expected "
f"{world_size}."
)

View File

@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.reload.meta import (
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
def test_move_metatensors():
@@ -140,7 +139,7 @@ def test_get_numel_loaded():
],
)
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
if cuda_device_count_stateless() < tp_size:
if current_platform.device_count() < tp_size:
pytest.skip(reason="Not enough CUDA devices")
if "FP8" in base_model and not current_platform.supports_fp8():
@@ -206,8 +205,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
def test_online_quantize_reload(
base_model, mul_model, add_model, quantization, tp_size, vllm_runner
):
if cuda_device_count_stateless() < tp_size:
pytest.skip(reason="Not enough CUDA devices")
if current_platform.device_count() < tp_size:
pytest.skip(reason="Not enough GPU devices")
if quantization == "fp8" and not current_platform.supports_fp8():
pytest.skip(reason="Requires FP8 support")

View File

@@ -21,8 +21,8 @@ import lm_eval
import pytest
from packaging import version
from vllm.platforms import current_platform
from vllm.platforms.rocm import on_gfx950
from vllm.utils.torch_utils import cuda_device_count_stateless
MODEL_ACCURACIES = {
# Full quantization: attention linears and MoE linears
@@ -89,7 +89,7 @@ def test_gpt_oss_attention_quantization(
expected_accuracy: float,
monkeypatch: pytest.MonkeyPatch,
):
if tp_size > cuda_device_count_stateless():
if tp_size > current_platform.device_count():
pytest.skip("Not enough GPUs to run this test case")
if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950():

View File

@@ -58,7 +58,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GB_bytes
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import (
cuda_device_count_stateless,
set_random_seed, # noqa: F401 - re-exported for use in test files
)
@@ -384,7 +383,7 @@ class RemoteVLLMServer:
elif current_platform.is_cuda():
with _nvml():
total_used = 0
device_count = cuda_device_count_stateless()
device_count = current_platform.device_count()
for i in range(device_count):
handle = nvmlDeviceGetHandleByIndex(i)
mem_info = nvmlDeviceGetMemoryInfo(handle)
@@ -1497,7 +1496,7 @@ def multi_gpu_marks(*, num_gpus: int):
"""Get a collection of pytest marks to apply for `@multi_gpu_test`."""
test_selector = pytest.mark.distributed(num_gpus=num_gpus)
test_skipif = pytest.mark.skipif(
cuda_device_count_stateless() < num_gpus,
current_platform.device_count() < num_gpus,
reason=f"Need at least {num_gpus} GPUs to run the test.",
)
@@ -1529,7 +1528,7 @@ def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None):
@gpu_tier_mark(max_gpus=1) # only on single-GPU
@gpu_tier_mark(min_gpus=2, max_gpus=4) # 2-4 GPUs only
"""
gpu_count = cuda_device_count_stateless()
gpu_count = current_platform.device_count()
marks = []
if min_gpus > 1:

View File

@@ -11,8 +11,8 @@ from tests.v1.shutdown.utils import (
)
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
@@ -34,7 +34,7 @@ async def test_async_llm_delete(
tensor_parallel_size: degree of tensor parallelism
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
engine_args = AsyncEngineArgs(
@@ -83,7 +83,7 @@ def test_llm_delete(
enable_multiprocessing: enable workers in separate process(es)
send_one_request: send one request to engine before deleting
"""
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:

View File

@@ -15,7 +15,7 @@ from tests.v1.shutdown.utils import (
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
@@ -60,7 +60,7 @@ async def test_async_llm_model_error(
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
@@ -126,7 +126,7 @@ def test_llm_model_error(
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
and >1 rank
"""
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:

View File

@@ -15,7 +15,7 @@ from vllm import LLM
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.platforms import current_platform
from vllm.v1.engine.async_llm import AsyncLLM
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
@@ -57,7 +57,7 @@ def test_async_llm_startup_error(
Test profiling (forward()) and load weights failures.
AsyncLLM always uses an MP client.
"""
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
# Monkeypatch an error in the model.
@@ -99,7 +99,7 @@ def test_llm_startup_error(
# If MODELS list grows, each architecture needs its own test variant.
if model != "JackFram/llama-68m":
pytest.skip(reason="Only test JackFram/llama-68m")
if cuda_device_count_stateless() < tensor_parallel_size:
if current_platform.device_count() < tensor_parallel_size:
pytest.skip(reason="Not enough CUDA devices")
with monkeypatch.context() as m:

View File

@@ -10,6 +10,8 @@ import regex as re
_TORCH_CUDA_PATTERNS = [
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b",
r"\bwith\storch\.cuda\.device\b",
# Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml} internally
r"\bcuda_device_count_stateless\(\)\b",
]
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}

View File

@@ -16,7 +16,6 @@ from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_ports_list
from vllm.utils.torch_utils import cuda_device_count_stateless
if TYPE_CHECKING:
from ray.runtime_env import RuntimeEnv
@@ -726,9 +725,9 @@ class ParallelConfig:
backend = "mp"
elif (
current_platform.is_cuda()
and cuda_device_count_stateless() < self.world_size
and current_platform.device_count() < self.world_size
):
gpu_count = cuda_device_count_stateless()
gpu_count = current_platform.device_count()
raise ValueError(
f"World size ({self.world_size}) is larger than the number of "
f"available GPUs ({gpu_count}) in this node. If this is "

View File

@@ -19,8 +19,8 @@ import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
@@ -320,7 +320,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
is_distributed = dist.is_initialized()
num_dev = cuda_device_count_stateless()
num_dev = current_platform.device_count()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))

View File

@@ -17,7 +17,6 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
try:
ops.meta_size()
@@ -135,7 +134,7 @@ class CustomAllreduce:
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
device_ids = list(range(current_platform.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")

View File

@@ -13,7 +13,6 @@ from vllm.config import get_current_vllm_config_or_none
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
@@ -137,7 +136,7 @@ class QuickAllReduce:
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(cuda_device_count_stateless()))
device_ids = list(range(current_platform.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
gather_list = [

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
import os
from collections.abc import Callable
from datetime import timedelta
from functools import cache, wraps
from functools import cache, lru_cache, wraps
from typing import TYPE_CHECKING, TypeVar
import torch
@@ -20,9 +20,9 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm._C_stable_libtorch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
@@ -47,6 +47,32 @@ pynvml = import_pynvml()
torch.backends.cuda.enable_cudnn_sdp(False)
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
"""Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES
at the time of call.
This should be used instead of torch.accelerator.device_count() unless
CUDA_VISIBLE_DEVICES has already been set to the desired value.
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
if not torch.cuda._is_compiled():
return 0
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
@cache
def _get_backend_priorities(
use_mla: bool,
@@ -456,7 +482,7 @@ class CudaPlatformBase(Platform):
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):

View File

@@ -13,7 +13,6 @@ from torch.distributed.distributed_c10d import is_nccl_available
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
@@ -67,6 +66,38 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
}
@lru_cache(maxsize=8)
def _rocm_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
"""Get number of ROCm devices, caching based on the value of CUDA_VISIBLE_DEVICES
at the time of call.
This should be used instead of torch.accelerator.device_count() unless
CUDA_VISIBLE_DEVICES has already been set to the desired value.
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
if not torch.cuda._is_compiled():
return 0
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = (
torch.cuda._device_count_amdsmi()
if (hasattr(torch.cuda, "_device_count_amdsmi"))
else -1
)
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
def _sync_hip_cuda_env_vars():
"""Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent.
Treats empty string as unset. Raises on genuine conflicts."""
@@ -810,7 +841,7 @@ class RocmPlatform(Platform):
@classmethod
def device_count(cls) -> int:
return cuda_device_count_stateless()
return _rocm_device_count_stateless(getattr(envs, cls.device_control_env_var))
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):

View File

@@ -22,7 +22,6 @@ import vllm.envs as envs
from vllm.connections import global_http_connection
from vllm.logger import init_logger
from vllm.utils.platform_utils import cuda_get_device_properties
from vllm.utils.torch_utils import cuda_device_count_stateless
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -196,7 +195,7 @@ class UsageMessage:
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
self.gpu_count = cuda_device_count_stateless()
self.gpu_count = current_platform.device_count()
self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties(
0, ("name", "total_memory")
)

View File

@@ -6,7 +6,6 @@ import os
import random
import threading
from collections.abc import Callable, Collection
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar
import numpy as np
@@ -16,7 +15,6 @@ from packaging import version
from packaging.version import Version
from torch.library import Library, infer_schema
import vllm.envs as envs
from vllm.logger import init_logger
if TYPE_CHECKING:
@@ -590,49 +588,6 @@ def aux_stream() -> torch.cuda.Stream | None:
return _aux_stream
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
# Note: cuda_visible_devices is not used, but we keep it as an argument for
# LRU Cache purposes.
# Code below is based on
# https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
import torch.version
from vllm.platforms import current_platform
if not torch.cuda._is_compiled():
return 0
if current_platform.is_rocm():
# ROCm uses amdsmi instead of nvml for stateless device count
# This requires a sufficiently modern version of Torch 2.4.0
raw_count = (
torch.cuda._device_count_amdsmi()
if (hasattr(torch.cuda, "_device_count_amdsmi"))
else -1
)
else:
raw_count = torch.cuda._device_count_nvml()
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
return r
def cuda_device_count_stateless() -> int:
"""Get number of CUDA devices, caching based on the value of
CUDA_VISIBLE_DEVICES at the time of call.
This should be used instead of torch.accelerator.device_count()
unless CUDA_VISIBLE_DEVICES has already been set to the desired
value."""
# This can be removed and simply replaced with torch.cuda.get_device_count
# after https://github.com/pytorch/pytorch/pull/122815 is released.
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.

View File

@@ -369,9 +369,7 @@ def initialize_ray_cluster(
# Prevalidate GPU requirements before Ray processing
if current_platform.is_cuda() and parallel_config.world_size > 1:
from vllm.utils.torch_utils import cuda_device_count_stateless
available_gpus = cuda_device_count_stateless()
available_gpus = current_platform.device_count()
if parallel_config.world_size > available_gpus:
logger.warning(
"Tensor parallel size (%d) exceeds available GPUs (%d). "