refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -45,7 +45,7 @@ from vllm.v1.worker.utils import AttentionGroup, select_common_block_size
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
DEVICE = current_platform.device_type
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
def initialize_kv_cache(runner: GPUModelRunner):
|
||||
@@ -121,7 +121,7 @@ def model_runner():
|
||||
vllm_config.compilation_config.static_forward_context["layer.0"] = Attention(
|
||||
num_heads, head_size, 0.1
|
||||
)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
initialize_kv_cache(runner)
|
||||
yield runner
|
||||
|
||||
@@ -340,7 +340,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
[1.0, 2.0, 3.0],
|
||||
[3.0, 2.0, 1.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 0}
|
||||
@@ -350,7 +350,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
[1.0, float("nan"), 3.0],
|
||||
[4.0, float("nan"), float("nan")],
|
||||
],
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 1, "req_1": 2}
|
||||
@@ -360,7 +360,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
[1.0, 2.0, 3.0],
|
||||
[4.0, float("nan"), float("nan")],
|
||||
],
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 0, "req_1": 2}
|
||||
@@ -372,7 +372,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
[
|
||||
[1.0, float("nan"), 3.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 1, "req_1": 0}
|
||||
@@ -383,7 +383,7 @@ def test_get_nans_in_logits(model_runner, dist_init):
|
||||
[1.0, 2.0, 3.0],
|
||||
[float("nan"), 2.0, 3.0],
|
||||
],
|
||||
device=DEVICE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
result = model_runner._get_nans_in_logits(logits)
|
||||
assert result == {"req_0": 2, "req_1": 0}
|
||||
@@ -643,7 +643,7 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config):
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 3_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 2
|
||||
assert len(runner.shared_kv_cache_layers) == 0
|
||||
@@ -711,7 +711,7 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config):
|
||||
# Set high context length to test max context length estimation
|
||||
vllm_config.model_config.max_model_len = 3_000_000
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
assert len(kv_cache_spec) == 1
|
||||
assert layer_0 in kv_cache_spec
|
||||
@@ -850,7 +850,7 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
assert fwd_context is not None
|
||||
vllm_ctx = vllm_config.compilation_config.static_forward_context
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
current_platform.update_block_size_for_backend(vllm_config)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
@@ -896,13 +896,13 @@ def test_hybrid_attention_mamba_tensor_shapes():
|
||||
ssm_constant_shape = ssm_shape[1:]
|
||||
|
||||
attn_blocks_constant = torch.full(
|
||||
(test_block_size, *attn_constant_shape), device=DEVICE, fill_value=3.33
|
||||
(test_block_size, *attn_constant_shape), device=DEVICE_TYPE, fill_value=3.33
|
||||
)
|
||||
conv_blocks_constant = torch.full(
|
||||
(test_block_size, *conv_constant_shape), device=DEVICE, fill_value=6.66
|
||||
(test_block_size, *conv_constant_shape), device=DEVICE_TYPE, fill_value=6.66
|
||||
)
|
||||
ssm_blocks_constant = torch.full(
|
||||
(test_block_size, *ssm_constant_shape), device=DEVICE, fill_value=9.99
|
||||
(test_block_size, *ssm_constant_shape), device=DEVICE_TYPE, fill_value=9.99
|
||||
)
|
||||
|
||||
# Fill attention blocks with constants using kv block indices
|
||||
@@ -997,7 +997,7 @@ def test_hybrid_block_table_initialization():
|
||||
max_num_blocks_per_req=max_num_blocks_per_req,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
pin_memory=False,
|
||||
device=torch.device(DEVICE),
|
||||
device=torch.device(DEVICE_TYPE),
|
||||
kernel_block_size=kernel_block_sizes[0],
|
||||
cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
|
||||
)
|
||||
@@ -1036,7 +1036,7 @@ def test_input_batch_with_kernel_block_sizes():
|
||||
max_num_reqs = 10
|
||||
max_model_len = 512
|
||||
max_num_batched_tokens = 512
|
||||
device = torch.device(DEVICE)
|
||||
device = torch.device(DEVICE_TYPE)
|
||||
pin_memory = False
|
||||
vocab_size = 50272
|
||||
|
||||
@@ -1083,7 +1083,7 @@ def test_hybrid_cache_integration(default_vllm_config, dist_init):
|
||||
num_heads, head_size, 0.1
|
||||
)
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
|
||||
# Initialize KV cache with configuration
|
||||
attn_spec = FullAttentionSpec(
|
||||
@@ -1306,7 +1306,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks():
|
||||
)
|
||||
assert fwd_context is not None
|
||||
|
||||
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||
runner = GPUModelRunner(vllm_config, DEVICE_TYPE)
|
||||
current_platform.update_block_size_for_backend(vllm_config)
|
||||
kv_cache_spec = runner.get_kv_cache_spec()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user