refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -5,11 +5,14 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.spec_decode.utils import (
|
||||
PADDING_SLOT_ID,
|
||||
eagle_step_update_slot_mapping_and_metadata,
|
||||
)
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
# Skip if no CUDA - Triton kernel requires GPU
|
||||
pytest.importorskip("triton")
|
||||
if not torch.cuda.is_available():
|
||||
@@ -47,7 +50,7 @@ def _reference_eagle_step_slot_mapping(
|
||||
|
||||
def test_eagle_step_slot_mapping_kernel():
|
||||
"""Test fused kernel matches Python reference for slot mapping and metadata."""
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
batch_size = 32
|
||||
block_size = 16
|
||||
max_model_len = 4096
|
||||
@@ -93,7 +96,7 @@ def test_eagle_step_slot_mapping_kernel():
|
||||
|
||||
def test_eagle_step_slot_mapping_kernel_exceeds_max():
|
||||
"""Test fused kernel when position exceeds max_model_len."""
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
batch_size = 4
|
||||
block_size = 16
|
||||
max_model_len = 100
|
||||
@@ -130,7 +133,7 @@ def test_eagle_step_slot_mapping_kernel_exceeds_max():
|
||||
def test_eagle_step_slot_mapping_kernel_cudagraph_padding():
|
||||
"""Test that padding threads write PADDING_SLOT_ID when
|
||||
input_batch_size > batch_size (cudagraph padding)."""
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
batch_size = 4
|
||||
input_batch_size = 8
|
||||
block_size = 16
|
||||
|
||||
Reference in New Issue
Block a user