refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -7,8 +7,7 @@ from torch import Generator
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
|
||||
CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None
|
||||
DEVICE = current_platform.device_type
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
VOCAB_SIZE = 128 * 1024
|
||||
@@ -26,8 +25,8 @@ def reset_default_device():
|
||||
|
||||
|
||||
def test_topk_impl_equivalence():
|
||||
torch.set_default_device(DEVICE)
|
||||
generator = Generator(device=DEVICE).manual_seed(33)
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
generator = Generator(device=DEVICE_TYPE).manual_seed(33)
|
||||
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||
|
||||
@@ -76,8 +75,8 @@ def test_flashinfer_sampler():
|
||||
if not FLASHINFER_ENABLED:
|
||||
pytest.skip("FlashInfer not installed or not available on this platform.")
|
||||
|
||||
torch.set_default_device(DEVICE)
|
||||
generator = Generator(device=DEVICE).manual_seed(42)
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
generator = Generator(device=DEVICE_TYPE).manual_seed(42)
|
||||
|
||||
# Generate random logits
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
|
||||
@@ -128,15 +127,15 @@ def test_flashinfer_sampler():
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available")
|
||||
@pytest.mark.skipif("CPU" in DEVICE_TYPE, reason="CUDA/XPU not available")
|
||||
class TestTritonTopkTopp:
|
||||
"""Tests for the Triton top-k/top-p kernel."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
"""Set up test fixtures."""
|
||||
torch.set_default_device(CUDA_DEVICE)
|
||||
self.generator = Generator(device=CUDA_DEVICE).manual_seed(42)
|
||||
torch.set_default_device(DEVICE_TYPE)
|
||||
self.generator = Generator(device=DEVICE_TYPE).manual_seed(42)
|
||||
|
||||
def _compare_results(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user