Files
vllm/tests/compile/conftest.py
Jason Li 9d37941017 [torch.compile] Sequence Parallelism threshold compile ranges (#28672)
Signed-off-by: jasonlizhengjian <jasonlizhengjian@gmail.com>
Signed-off-by: Jason Li <jasonlizhengjian@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
2026-02-26 05:00:12 +00:00

35 lines
1.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from vllm.platforms.interface import DeviceCapability
@pytest.fixture
def mock_cuda_platform():
    """
    Provide a factory for context managers that swap in a mocked platform.

    The fixture value is a callable. Calling it with ``is_cuda`` and an
    optional ``capability`` tuple returns a context manager; while that
    context is active, ``vllm.platforms.current_platform`` is patched with a
    ``MagicMock`` and the same mock is yielded to the ``with`` body.

    Example:
        def test_something(mock_cuda_platform):
            with mock_cuda_platform(is_cuda=True, capability=(9, 0)):
                # test code
    """

    @contextmanager
    def _mock_platform(is_cuda: bool = True, capability: tuple[int, int] | None = None):
        fake_platform = MagicMock()
        fake_platform.is_cuda.return_value = is_cuda

        # Only pin a device capability when the caller supplied one;
        # otherwise get_device_capability keeps MagicMock's default behavior.
        if capability is not None:
            device_capability = DeviceCapability(*capability)
            fake_platform.get_device_capability.return_value = device_capability

        # The patch is undone automatically when the caller's with-block exits.
        with patch("vllm.platforms.current_platform", new=fake_platform):
            yield fake_platform

    return _mock_platform