Replace cuda_device_count_stateless() with current_platform.device_count() (#37841)
Signed-off-by: Liao, Wei <wei.liao@intel.com> Signed-off-by: wliao2 <wei.liao@intel.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -19,7 +19,6 @@ from vllm.model_executor.model_loader.reload.meta import (
|
||||
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
|
||||
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
|
||||
def test_move_metatensors():
|
||||
@@ -140,7 +139,7 @@ def test_get_numel_loaded():
|
||||
],
|
||||
)
|
||||
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
if cuda_device_count_stateless() < tp_size:
|
||||
if current_platform.device_count() < tp_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
|
||||
if "FP8" in base_model and not current_platform.supports_fp8():
|
||||
@@ -206,8 +205,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
def test_online_quantize_reload(
|
||||
base_model, mul_model, add_model, quantization, tp_size, vllm_runner
|
||||
):
|
||||
if cuda_device_count_stateless() < tp_size:
|
||||
pytest.skip(reason="Not enough CUDA devices")
|
||||
if current_platform.device_count() < tp_size:
|
||||
pytest.skip(reason="Not enough GPU devices")
|
||||
|
||||
if quantization == "fp8" and not current_platform.supports_fp8():
|
||||
pytest.skip(reason="Requires FP8 support")
|
||||
|
||||
Reference in New Issue
Block a user