refactor hard coded device string in test files under tests/v1 and tests/lora (#37566)
Signed-off-by: Liao, Wei <wei.liao@intel.com>
This commit is contained in:
@@ -9,10 +9,13 @@ import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
|
||||
|
||||
class DummyLoRAManager:
|
||||
def __init__(self, device: torch.device = "cuda:0"):
|
||||
def __init__(self, device: torch.device = f"{DEVICE_TYPE}:0"):
|
||||
super().__init__()
|
||||
self._loras: dict[str, LoRALayerWeights] = {}
|
||||
self._device = device
|
||||
@@ -57,8 +60,8 @@ class DummyLoRAManager:
|
||||
module_name,
|
||||
rank=rank,
|
||||
lora_alpha=1,
|
||||
lora_a=torch.rand([rank, input_dim], device="cuda"),
|
||||
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
|
||||
lora_a=torch.rand([rank, input_dim], device=DEVICE_TYPE),
|
||||
lora_b=torch.rand([output_dim, input_dim], device=DEVICE_TYPE),
|
||||
embeddings_tensor=embeddings_tensor,
|
||||
)
|
||||
self.set_module_lora(module_name, lora)
|
||||
|
||||
Reference in New Issue
Block a user