refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from vllm.v1.spec_decode.extract_hidden_states import ExtractHiddenStatesPropose
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
model_dir = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
def _create_proposer(
|
||||
@@ -51,7 +52,7 @@ def _create_proposer(
|
||||
},
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
device = DEVICE_TYPE
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
@@ -101,7 +102,7 @@ def test_proposer_initialization_missing_layer_ids():
|
||||
},
|
||||
)
|
||||
|
||||
device = current_platform.device_type
|
||||
device = DEVICE_TYPE
|
||||
vllm_config = VllmConfig(
|
||||
model_config=model_config,
|
||||
cache_config=CacheConfig(),
|
||||
@@ -130,7 +131,7 @@ def test_prepare_next_token_ids_padded():
|
||||
For each request we either use the sampled token (if valid and not discarded)
|
||||
or a backup token from the request state.
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
|
||||
num_requests = 4
|
||||
req_ids = [f"req_{i + 1}" for i in range(num_requests)]
|
||||
@@ -197,7 +198,7 @@ def test_propose():
|
||||
2. Return the sampled tokens as "draft" tokens (shape [batch_size, 1])
|
||||
3. Cache the hidden states in the model's KV cache
|
||||
"""
|
||||
device = torch.device(current_platform.device_type)
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
|
||||
# Setup test parameters
|
||||
batch_size = 2
|
||||
@@ -273,7 +274,7 @@ def test_propose():
|
||||
@pytest.mark.parametrize("num_hidden_layers", [1, 4, 8])
|
||||
def test_propose_different_layer_counts(num_hidden_layers):
|
||||
"""Test that propose works correctly with different numbers of hidden layers."""
|
||||
device = torch.device(current_platform.device_type)
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
|
||||
batch_size = 2
|
||||
num_tokens = 5
|
||||
|
||||
Reference in New Issue
Block a user