refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -18,6 +18,8 @@ from vllm.v1.attention.backend import CommonAttentionMetadata
 from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 
+DEVICE_TYPE = current_platform.device_type
+
 if not is_flash_attn_varlen_func_available():
     pytest.skip(
         "This test requires flash_attn_varlen_func, but it's not available.",
@@ -170,9 +172,9 @@ def _get_available_reference_backends() -> list[AttentionBackendEnum]:
 
 
 class MockAttentionLayer(torch.nn.Module):
-    _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
-    _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
-    _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
+    _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
+    _k_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
+    _v_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
     layer_name = "mock_layer"
 
     def __init__(self):
||||
Reference in New Issue
Block a user