refactor hard-coded device strings in test files under tests/v1 and tests/lora (#37566)

Signed-off-by: Liao, Wei <wei.liao@intel.com>

Author: wliao2
Date: 2026-04-02 20:21:47 -07:00 (committed by GitHub)
Parent: 4a06e1246e
Commit: 32e0c0bfa2
28 changed files with 239 additions and 146 deletions


@@ -7,8 +7,7 @@ from torch import Generator
 from vllm.platforms import current_platform
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
-CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None
-DEVICE = current_platform.device_type
+DEVICE_TYPE = current_platform.device_type
 BATCH_SIZE = 1024
 VOCAB_SIZE = 128 * 1024
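
The hunk above is the core of the refactor: the hard-coded CUDA_DEVICE constant gives way to the platform-reported device type. A minimal sketch of the pattern (make_seeded_generator is an illustrative helper, not part of the commit), assuming current_platform.device_type returns a backend string such as "cuda", "xpu", or "cpu":

import torch
from vllm.platforms import current_platform

# Platform-reported backend string, e.g. "cuda", "xpu", or "cpu".
DEVICE_TYPE = current_platform.device_type

def make_seeded_generator(seed: int) -> torch.Generator:
    # Illustrative helper: a seeded generator bound to whatever device
    # the current platform reports, with no "cuda" literal in sight.
    return torch.Generator(device=DEVICE_TYPE).manual_seed(seed)

With this constant in place, each test only ever references DEVICE_TYPE, so the same file runs unmodified on CUDA, XPU, or CPU hosts.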
@@ -26,8 +25,8 @@ def reset_default_device():
 def test_topk_impl_equivalence():
-    torch.set_default_device(DEVICE)
-    generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE_TYPE)
+    generator = Generator(device=DEVICE_TYPE).manual_seed(33)
     logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
@@ -76,8 +75,8 @@ def test_flashinfer_sampler():
     if not FLASHINFER_ENABLED:
         pytest.skip("FlashInfer not installed or not available on this platform.")
-    torch.set_default_device(DEVICE)
-    generator = Generator(device=DEVICE).manual_seed(42)
+    torch.set_default_device(DEVICE_TYPE)
+    generator = Generator(device=DEVICE_TYPE).manual_seed(42)
     # Generate random logits
     logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
@@ -128,15 +127,15 @@ def test_flashinfer_sampler():
 # =============================================================================
-@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available")
+@pytest.mark.skipif("CPU" in DEVICE_TYPE, reason="CUDA/XPU not available")
 class TestTritonTopkTopp:
     """Tests for the Triton top-k/top-p kernel."""

     @pytest.fixture(autouse=True)
     def setup(self):
         """Set up test fixtures."""
-        torch.set_default_device(CUDA_DEVICE)
-        self.generator = Generator(device=CUDA_DEVICE).manual_seed(42)
+        torch.set_default_device(DEVICE_TYPE)
+        self.generator = Generator(device=DEVICE_TYPE).manual_seed(42)

     def _compare_results(
         self,
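
The last hunk changes the skip condition itself: the Triton top-k/top-p suite is now gated on the platform's device type rather than on CUDA availability alone, so it can also run on XPU. A hedged sketch of the same guard in isolation (the test name and body are illustrative, not from the commit; it assumes device_type is a plain string, and note the guard compares against upper-case "CPU" exactly as written in the diff):

import pytest
import torch
from vllm.platforms import current_platform

DEVICE_TYPE = current_platform.device_type

# Same guard as the diff above: skip when the platform reports a CPU
# device type, instead of the old CUDA-only check.
@pytest.mark.skipif("CPU" in DEVICE_TYPE, reason="CUDA/XPU not available")
def test_default_device_matches_platform():
    # Illustrative body: tensors created after set_default_device should
    # land on the platform-reported device.
    torch.set_default_device(DEVICE_TYPE)
    assert torch.empty(1).device.type == DEVICE_TYPE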