replace cuda_device_count_stateless() to current_platform.device_count() (#37841)
Signed-off-by: Liao, Wei <wei.liao@intel.com> Signed-off-by: wliao2 <wei.liao@intel.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -6,7 +6,6 @@ import pytest
|
||||
|
||||
from vllm.config import CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ...utils import compare_all_settings
|
||||
|
||||
@@ -109,10 +108,10 @@ def test_compile_correctness(
|
||||
tp_size = test_setting.tp_size
|
||||
attn_backend = test_setting.attn_backend
|
||||
method = test_setting.method
|
||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||
if current_platform.device_count() < pp_size * tp_size:
|
||||
pytest.skip(
|
||||
f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{cuda_device_count_stateless()}"
|
||||
f"{current_platform.device_count()}"
|
||||
)
|
||||
|
||||
final_args = [
|
||||
|
||||
@@ -412,7 +412,7 @@ def test_cudagraph_sizes_post_init(
|
||||
|
||||
with (
|
||||
ctx,
|
||||
patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
|
||||
patch.object(current_platform, "device_count", return_value=tp_size),
|
||||
):
|
||||
kwargs = {}
|
||||
if cudagraph_capture_sizes is not None:
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
@@ -21,7 +20,7 @@ from ..utils import multi_gpu_test
|
||||
@ray.remote
|
||||
class _CUDADeviceCountStatelessTestActor:
|
||||
def get_count(self):
|
||||
return cuda_device_count_stateless()
|
||||
return current_platform.device_count()
|
||||
|
||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless, set_random_seed
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
@@ -310,10 +310,10 @@ def test_modular_kernel_combinations_multigpu(
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
):
|
||||
if cuda_device_count_stateless() < world_size:
|
||||
if current_platform.device_count() < world_size:
|
||||
pytest.skip(
|
||||
f"Not enough GPUs available to run, got "
|
||||
f"{cuda_device_count_stateless()} expected "
|
||||
f"{current_platform.device_count()} expected "
|
||||
f"{world_size}."
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.reload.meta import (
|
||||
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
|
||||
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
|
||||
def test_move_metatensors():
|
||||
@@ -140,7 +139,7 @@ def test_get_numel_loaded():
|
||||
],
|
||||
)
|
||||
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
if cuda_device_count_stateless() < tp_size:
|
||||
if current_platform.device_count() < tp_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
if "FP8" in base_model and not current_platform.supports_fp8():
|
||||
@@ -206,8 +205,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
def test_online_quantize_reload(
|
||||
base_model, mul_model, add_model, quantization, tp_size, vllm_runner
|
||||
):
|
||||
if cuda_device_count_stateless() < tp_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
if current_platform.device_count() < tp_size:
|
||||
pytest.skip(reason="Not enough GPU devices")
|
||||
|
||||
if quantization == "fp8" and not current_platform.supports_fp8():
|
||||
pytest.skip(reason="Requires FP8 support")
|
||||
|
||||
@@ -21,8 +21,8 @@ import lm_eval
|
||||
import pytest
|
||||
from packaging import version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.rocm import on_gfx950
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
MODEL_ACCURACIES = {
|
||||
# Full quantization: attention linears and MoE linears
|
||||
@@ -89,7 +89,7 @@ def test_gpt_oss_attention_quantization(
|
||||
expected_accuracy: float,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
if tp_size > cuda_device_count_stateless():
|
||||
if tp_size > current_platform.device_count():
|
||||
pytest.skip("Not enough GPUs to run this test case")
|
||||
|
||||
if "amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8" in model_name and on_gfx950():
|
||||
|
||||
@@ -58,7 +58,6 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.mem_constants import GB_bytes
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import (
|
||||
cuda_device_count_stateless,
|
||||
set_random_seed, # noqa: F401 - re-exported for use in test files
|
||||
)
|
||||
|
||||
@@ -384,7 +383,7 @@ class RemoteVLLMServer:
|
||||
elif current_platform.is_cuda():
|
||||
with _nvml():
|
||||
total_used = 0
|
||||
device_count = cuda_device_count_stateless()
|
||||
device_count = current_platform.device_count()
|
||||
for i in range(device_count):
|
||||
handle = nvmlDeviceGetHandleByIndex(i)
|
||||
mem_info = nvmlDeviceGetMemoryInfo(handle)
|
||||
@@ -1497,7 +1496,7 @@ def multi_gpu_marks(*, num_gpus: int):
|
||||
"""Get a collection of pytest marks to apply for `@multi_gpu_test`."""
|
||||
test_selector = pytest.mark.distributed(num_gpus=num_gpus)
|
||||
test_skipif = pytest.mark.skipif(
|
||||
cuda_device_count_stateless() < num_gpus,
|
||||
current_platform.device_count() < num_gpus,
|
||||
reason=f"Need at least {num_gpus} GPUs to run the test.",
|
||||
)
|
||||
|
||||
@@ -1529,7 +1528,7 @@ def gpu_tier_mark(*, min_gpus: int = 1, max_gpus: int | None = None):
|
||||
@gpu_tier_mark(max_gpus=1) # only on single-GPU
|
||||
@gpu_tier_mark(min_gpus=2, max_gpus=4) # 2-4 GPUs only
|
||||
"""
|
||||
gpu_count = cuda_device_count_stateless()
|
||||
gpu_count = current_platform.device_count()
|
||||
marks = []
|
||||
|
||||
if min_gpus > 1:
|
||||
|
||||
@@ -11,8 +11,8 @@ from tests.v1.shutdown.utils import (
|
||||
)
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
|
||||
@@ -34,7 +34,7 @@ async def test_async_llm_delete(
|
||||
tensor_parallel_size: degree of tensor parallelism
|
||||
send_one_request: send one request to engine before deleting
|
||||
"""
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
@@ -83,7 +83,7 @@ def test_llm_delete(
|
||||
enable_multiprocessing: enable workers in separate process(es)
|
||||
send_one_request: send one request to engine before deleting
|
||||
"""
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
@@ -15,7 +15,7 @@ from tests.v1.shutdown.utils import (
|
||||
from vllm import LLM, AsyncEngineArgs, SamplingParams
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
|
||||
@@ -60,7 +60,7 @@ async def test_async_llm_model_error(
|
||||
|
||||
AsyncLLM always uses an MP client.
|
||||
"""
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
# Monkeypatch an error in the model.
|
||||
@@ -126,7 +126,7 @@ def test_llm_model_error(
|
||||
TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
|
||||
and >1 rank
|
||||
"""
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm import LLM
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
|
||||
MODELS = ["hmellor/tiny-random-LlamaForCausalLM"]
|
||||
@@ -57,7 +57,7 @@ def test_async_llm_startup_error(
|
||||
Test profiling (forward()) and load weights failures.
|
||||
AsyncLLM always uses an MP client.
|
||||
"""
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
# Monkeypatch an error in the model.
|
||||
@@ -99,7 +99,7 @@ def test_llm_startup_error(
|
||||
# If MODELS list grows, each architecture needs its own test variant.
|
||||
if model != "JackFram/llama-68m":
|
||||
pytest.skip(reason="Only test JackFram/llama-68m")
|
||||
if cuda_device_count_stateless() < tensor_parallel_size:
|
||||
if current_platform.device_count() < tensor_parallel_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
|
||||
@@ -10,6 +10,8 @@ import regex as re
|
||||
_TORCH_CUDA_PATTERNS = [
|
||||
r"\btorch\.cuda\.(empty_cache|synchronize|device_count|current_device|memory_reserved|memory_allocated|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats|memory_stats|set_device|device\()\b",
|
||||
r"\bwith\storch\.cuda\.device\b",
|
||||
# Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml} internally
|
||||
r"\bcuda_device_count_stateless\(\)\b",
|
||||
]
|
||||
|
||||
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}
|
||||
|
||||
@@ -16,7 +16,6 @@ from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_ports_list
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
@@ -726,9 +725,9 @@ class ParallelConfig:
|
||||
backend = "mp"
|
||||
elif (
|
||||
current_platform.is_cuda()
|
||||
and cuda_device_count_stateless() < self.world_size
|
||||
and current_platform.device_count() < self.world_size
|
||||
):
|
||||
gpu_count = cuda_device_count_stateless()
|
||||
gpu_count = current_platform.device_count()
|
||||
raise ValueError(
|
||||
f"World size ({self.world_size}) is larger than the number of "
|
||||
f"available GPUs ({gpu_count}) in this node. If this is "
|
||||
|
||||
@@ -19,8 +19,8 @@ import torch.multiprocessing as mp
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -320,7 +320,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
num_dev = current_platform.device_count()
|
||||
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
@@ -17,7 +17,6 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
try:
|
||||
ops.meta_size()
|
||||
@@ -135,7 +134,7 @@ class CustomAllreduce:
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
device_ids = list(range(current_platform.device_count()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
|
||||
@@ -13,7 +13,6 @@ from vllm.config import get_current_vllm_config_or_none
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -137,7 +136,7 @@ class QuickAllReduce:
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
device_ids = list(range(current_platform.device_count()))
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
gather_list = [
|
||||
|
||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from functools import cache, wraps
|
||||
from functools import cache, lru_cache, wraps
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import torch
|
||||
@@ -20,9 +20,9 @@ from typing_extensions import ParamSpec
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm._C_stable_libtorch # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import import_pynvml
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
@@ -47,6 +47,32 @@ pynvml = import_pynvml()
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of CUDA_VISIBLE_DEVICES
|
||||
at the time of call.
|
||||
|
||||
This should be used instead of torch.accelerator.device_count() unless
|
||||
CUDA_VISIBLE_DEVICES has already been set to the desired value.
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
raw_count = torch.cuda._device_count_nvml()
|
||||
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
|
||||
return r
|
||||
|
||||
|
||||
@cache
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
@@ -456,7 +482,7 @@ class CudaPlatformBase(Platform):
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
|
||||
@@ -13,7 +13,6 @@ from torch.distributed.distributed_c10d import is_nccl_available
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
@@ -67,6 +66,38 @@ _ROCM_DEVICE_ID_NAME_MAP: dict[str, str] = {
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _rocm_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
||||
"""Get number of ROCm devices, caching based on the value of CUDA_VISIBLE_DEVICES
|
||||
at the time of call.
|
||||
|
||||
This should be used instead of torch.accelerator.device_count() unless
|
||||
CUDA_VISIBLE_DEVICES has already been set to the desired value.
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released."""
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = (
|
||||
torch.cuda._device_count_amdsmi()
|
||||
if (hasattr(torch.cuda, "_device_count_amdsmi"))
|
||||
else -1
|
||||
)
|
||||
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
|
||||
return r
|
||||
|
||||
|
||||
def _sync_hip_cuda_env_vars():
|
||||
"""Ensure HIP_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES are consistent.
|
||||
Treats empty string as unset. Raises on genuine conflicts."""
|
||||
@@ -810,7 +841,7 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
return _rocm_device_count_stateless(getattr(envs, cls.device_control_env_var))
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
|
||||
@@ -22,7 +22,6 @@ import vllm.envs as envs
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import cuda_get_device_properties
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -196,7 +195,7 @@ class UsageMessage:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
self.gpu_count = cuda_device_count_stateless()
|
||||
self.gpu_count = current_platform.device_count()
|
||||
self.gpu_type, self.gpu_memory_per_device = cuda_get_device_properties(
|
||||
0, ("name", "total_memory")
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import os
|
||||
import random
|
||||
import threading
|
||||
from collections.abc import Callable, Collection
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import numpy as np
|
||||
@@ -16,7 +15,6 @@ from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library, infer_schema
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -590,49 +588,6 @@ def aux_stream() -> torch.cuda.Stream | None:
|
||||
return _aux_stream
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(cuda_visible_devices: str | None = None) -> int:
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.cuda
|
||||
import torch.version
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if current_platform.is_rocm():
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = (
|
||||
torch.cuda._device_count_amdsmi()
|
||||
if (hasattr(torch.cuda, "_device_count_amdsmi"))
|
||||
else -1
|
||||
)
|
||||
else:
|
||||
raw_count = torch.cuda._device_count_nvml()
|
||||
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
|
||||
return r
|
||||
|
||||
|
||||
def cuda_device_count_stateless() -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of
|
||||
CUDA_VISIBLE_DEVICES at the time of call.
|
||||
|
||||
This should be used instead of torch.accelerator.device_count()
|
||||
unless CUDA_VISIBLE_DEVICES has already been set to the desired
|
||||
value."""
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
||||
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: Any) -> Any:
|
||||
"""
|
||||
Create a weak reference to a tensor.
|
||||
|
||||
@@ -369,9 +369,7 @@ def initialize_ray_cluster(
|
||||
|
||||
# Prevalidate GPU requirements before Ray processing
|
||||
if current_platform.is_cuda() and parallel_config.world_size > 1:
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
available_gpus = cuda_device_count_stateless()
|
||||
available_gpus = current_platform.device_count()
|
||||
if parallel_config.world_size > available_gpus:
|
||||
logger.warning(
|
||||
"Tensor parallel size (%d) exceeds available GPUs (%d). "
|
||||
|
||||
Reference in New Issue
Block a user