Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com> Signed-off-by: Jason Li <jasonlizhengjian@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from contextlib import contextmanager
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from vllm.platforms.interface import DeviceCapability
|
|
|
|
|
|
@pytest.fixture
def mock_cuda_platform():
    """
    Provide a factory whose calls yield a mocked current platform.

    The fixture value is a callable. Calling it gives a context manager
    that replaces ``vllm.platforms.current_platform`` with a ``MagicMock``
    for the duration of the ``with`` block and yields that mock.

    Factory arguments:
        is_cuda: value the mock's ``is_cuda()`` call returns (default True).
        capability: optional ``(major, minor)`` pair; when given, the mock's
            ``get_device_capability()`` returns ``DeviceCapability(*capability)``.
            When omitted, that method is left as an unconfigured mock.

    Usage:
        def test_something(mock_cuda_platform):
            with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
                # test code
    """

    @contextmanager
    def _patched_platform(
        is_cuda: bool = True,
        capability: tuple[int, int] | None = None,
    ):
        # Gather the return values to install, keyed by dotted attribute path.
        overrides = {"is_cuda.return_value": is_cuda}
        if capability is not None:
            overrides["get_device_capability.return_value"] = DeviceCapability(
                *capability
            )

        fake_platform = MagicMock()
        fake_platform.configure_mock(**overrides)

        # The patch is undone automatically when the caller's block exits.
        with patch("vllm.platforms.current_platform", fake_platform):
            yield fake_platform

    return _patched_platform
|