Files
vllm/tests/compile/conftest.py

35 lines
1.0 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from vllm.platforms.interface import DeviceCapability
@pytest.fixture
def mock_cuda_platform():
    """
    Provide a factory that temporarily swaps in a mocked platform object.

    The fixture itself does no patching; it hands back a context-manager
    factory so each test chooses the platform configuration it needs.

    Usage:
        def test_something(mock_cuda_platform):
            with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
                # test code
    """

    @contextmanager
    def _mock_platform(
        is_cuda: bool = True,
        capability: tuple[int, int] | None = None,
    ):
        """
        Patch ``vllm.platforms.current_platform`` with a ``MagicMock``.

        Args:
            is_cuda: Value the mock's ``is_cuda()`` call returns.
            capability: When given, unpacked into ``DeviceCapability`` and
                used as the return value of ``get_device_capability()``.
                When ``None``, that method is left unconfigured on the mock.

        Yields:
            The ``MagicMock`` that replaces the platform while the context
            is active.
        """
        fake_platform = MagicMock()
        fake_platform.is_cuda.return_value = is_cuda

        # Only pin the capability when the caller asked for one.
        if capability is not None:
            device_capability = DeviceCapability(*capability)
            fake_platform.get_device_capability.return_value = device_capability

        # The patch is undone automatically when the caller's `with` exits.
        with patch("vllm.platforms.current_platform", fake_platform):
            yield fake_platform

    return _mock_platform