Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji
2024-02-02 07:46:39 +08:00
committed by GitHub
parent c410f5d020
commit 96b6f475dd
32 changed files with 343 additions and 292 deletions

View File

@@ -34,6 +34,9 @@ TOLERANCES = {
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def get_random_id_to_index(num_loras: int,
@@ -151,14 +154,10 @@ def create_random_inputs(
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
torch.randint(low=int(low),
high=int(high),
size=input_size,
device="cuda"))
torch.randint(low=int(low), high=int(high), size=input_size))
else:
inputs.append(
torch.rand(size=input_size, dtype=input_type, device="cuda") *
high + low)
torch.rand(size=input_size, dtype=input_type) * high + low)
lora_id = random.choice(active_lora_ids)
index_mapping += [lora_id] * input_size[0]
@@ -169,8 +168,10 @@ def create_random_inputs(
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -259,8 +260,10 @@ def test_embeddings(dist_init, num_loras) -> None:
@torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -305,8 +308,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
# Add empty embeddings_tensors for unoccupied lora slots.
for _ in range(max_loras - len(embeddings_tensors)):
embeddings_tensors.append(
torch.zeros(embeddings_tensors[0].shape, device="cuda"))
embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
@@ -388,8 +390,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_lm_head_sampler(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -432,7 +436,7 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
input_ = torch.rand(20, 1024, device="cuda")
input_ = torch.rand(20, 1024)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
@@ -500,8 +504,10 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
def test_linear_parallel(dist_init, num_loras, orientation) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@@ -597,8 +603,10 @@ def test_linear_parallel(dist_init, num_loras, orientation) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,