refactor hard-coded device string in test files under tests/v1 and tests/lora (#37566)

Signed-off-by: Liao, Wei <wei.liao@intel.com>
Author: wliao2
Date: 2026-04-02 20:21:47 -07:00
Committed by: GitHub
parent 4a06e1246e
commit 32e0c0bfa2
28 changed files with 239 additions and 146 deletions


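Every hunk below applies the same refactor: the hard-coded "cuda" device string
is replaced by a module-level DEVICE_TYPE constant derived from
current_platform.device_type. A minimal sketch of the pattern, assuming
current_platform is vLLM's vllm.platforms.current_platform (the import path is
an assumption; the hunks themselves only show the attribute access):

    # Minimal sketch of the refactor pattern, not the literal file contents.
    # Assumption: current_platform comes from vllm.platforms.
    import torch
    from vllm.platforms import current_platform

    # Resolves to the detected accelerator's device string (e.g. "cuda"),
    # so the tests no longer pin themselves to CUDA explicitly.
    DEVICE_TYPE = current_platform.device_type

    # Before: device = torch.device("cuda")
    # After:
    device = torch.device(DEVICE_TYPE)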
@@ -64,6 +64,8 @@ SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
     seq_lens=[256] * 2, query_lens=[256] * 2
 )
 
+DEVICE_TYPE = current_platform.device_type
+
 
 def _float_to_e8m0_truncate(f: float) -> float:
     """Simulate SM100's float -> e8m0 -> bf16 scale conversion.
@@ -222,7 +224,7 @@ def test_sparse_backend_decode_correctness(
     batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
     use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
-    device = torch.device("cuda")
+    device = torch.device(DEVICE_TYPE)
     dtype = torch.bfloat16
 
     # Model hyper-parameters (kept intentionally small for the unit test)
@@ -586,7 +588,7 @@ def _triton_convert_reference_impl(
 def test_triton_convert_req_index_to_global_index_decode_only(
     block_size, num_topk_tokens
 ):
-    device = torch.device("cuda")
+    device = torch.device(DEVICE_TYPE)
     num_tokens = 8
     num_requests = 4
     max_blocks_per_req = 10
@@ -639,7 +641,7 @@ def test_triton_convert_req_index_to_global_index_decode_only(
     reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
 )
 def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
-    device = torch.device("cuda")
+    device = torch.device(DEVICE_TYPE)
     num_requests = 4
     max_blocks_per_req = 8
     num_topk_tokens = 128
@@ -794,7 +796,7 @@ def test_split_indexer_prefill_chunks_single_request_overflow():
 
 def test_triton_convert_returns_valid_counts():
     """Test that return_valid_counts correctly counts non-negative indices."""
-    device = torch.device("cuda")
+    device = torch.device(DEVICE_TYPE)
     num_tokens = 8
     num_requests = 2
     max_blocks_per_req = 10
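
As a usage illustration, a device-agnostic test written in this style could
look like the following; the test name and assertions are hypothetical, only
the DEVICE_TYPE pattern comes from this commit:

    # Hypothetical test illustrating the device-agnostic pattern; the name
    # and body are illustrative and not part of the commit.
    import torch
    from vllm.platforms import current_platform

    DEVICE_TYPE = current_platform.device_type


    def test_allocation_is_device_agnostic():
        # torch.device() accepts whatever device string the platform reports,
        # so the same test runs unmodified on CUDA or another backend.
        device = torch.device(DEVICE_TYPE)
        x = torch.zeros(4, dtype=torch.bfloat16, device=device)
        assert x.device.type == DEVICE_TYPE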