[Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
@@ -8,6 +8,10 @@ import torch
 import torch.nn.functional as F

 from vllm.config import LoRAConfig
+from vllm.lora.fully_sharded_layers import (
+    ColumnParallelLinearWithShardedLoRA,
+    MergedColumnParallelLinearWithShardedLoRA,
+    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
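Note: the classes imported above from vllm.lora.fully_sharded_layers mirror the existing LoRA layer wrappers one to one. The mapping below is purely illustrative (not part of this patch) and assumes the non-sharded names are importable from vllm.lora.layers, as the hunk above suggests:

    # Illustrative only: each LoRA layer wrapper paired with its fully sharded
    # counterpart, as exercised by the updated tests in this commit.
    from vllm.lora.fully_sharded_layers import (
        ColumnParallelLinearWithShardedLoRA,
        MergedColumnParallelLinearWithShardedLoRA,
        MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
    from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                                  MergedColumnParallelLinearWithLoRA,
                                  MergedQKVParallelLinearWithLora,
                                  RowParallelLinearWithLoRA)

    SHARDED_VARIANT = {
        RowParallelLinearWithLoRA: RowParallelLinearWithShardedLoRA,
        ColumnParallelLinearWithLoRA: ColumnParallelLinearWithShardedLoRA,
        MergedColumnParallelLinearWithLoRA:
        MergedColumnParallelLinearWithShardedLoRA,
        MergedQKVParallelLinearWithLora: MergedQKVParallelLinearWithShardedLora,
    }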
@@ -524,13 +528,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
+@pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
+def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
+                         device) -> None:

     torch.set_default_device(device)
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
+                             fully_sharded_loras=fully_shard,
                              lora_dtype=torch.float16)

     def create_random_linear_parallel_layer():
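The fully_sharded_loras flag enters the tests through LoRAConfig here. As a standalone sketch (test values only, not a recommendation), the configuration the sharded cases run with looks like:

    import torch

    from vllm.config import LoRAConfig

    # Mirrors the test setup above; fully_sharded_loras=True selects the fully
    # sharded LoRA layer implementations.
    lora_config = LoRAConfig(max_loras=8,
                             max_lora_rank=8,
                             fully_sharded_loras=True,
                             lora_dtype=torch.float16)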
@@ -540,14 +547,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
                                        bias=False,
                                        params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = RowParallelLinearWithLoRA(linear)
+            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
+                           else RowParallelLinearWithShardedLoRA(linear))
         else:
             linear = ColumnParallelLinear(4096,
                                           4096,
                                           bias=False,
                                           params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = ColumnParallelLinearWithLoRA(linear)
+            lora_linear = (ColumnParallelLinearWithLoRA(linear)
+                           if not fully_shard else
+                           ColumnParallelLinearWithShardedLoRA(linear))
         lora_linear.create_lora_weights(max_loras, lora_config)

         return linear, lora_linear
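Read as plain Python, this hunk chooses the wrapper class from fully_shard and then allocates the LoRA slots as before. The helper below is only an illustration of that pattern (hypothetical name, not code from the patch); create_lora_weights is called exactly as in the test:

    from vllm.lora.fully_sharded_layers import (
        ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
    from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                                  RowParallelLinearWithLoRA)

    def wrap_linear_with_lora(linear, orientation, fully_shard, max_loras,
                              lora_config):
        # Same selection the test performs inline: use the sharded wrapper when
        # fully_shard is set, otherwise the existing LoRA wrapper.
        if orientation == "row":
            cls = (RowParallelLinearWithShardedLoRA
                   if fully_shard else RowParallelLinearWithLoRA)
        else:
            cls = (ColumnParallelLinearWithShardedLoRA
                   if fully_shard else ColumnParallelLinearWithLoRA)
        lora_linear = cls(linear)
        lora_linear.create_lora_weights(max_loras, lora_config)
        return lora_linear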
@@ -629,13 +639,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("repeats", [1, 2, 3])
+@pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
+def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
+                                device) -> None:

     torch.set_default_device(device)
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
+                             fully_sharded_loras=fully_shard,
                              lora_dtype=torch.float16)

     def create_column_parallel_packed_layer():
@@ -644,7 +657,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                                 bias=False,
                                                 params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = MergedColumnParallelLinearWithLoRA(linear)
+            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
+                           if not fully_shard else
+                           MergedColumnParallelLinearWithShardedLoRA(linear))
         elif repeats == 3:
             linear = QKVParallelLinear(4096,
                                        64,
@@ -652,7 +667,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                        bias=False,
                                        params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = MergedQKVParallelLinearWithLora(linear)
+            lora_linear = (MergedQKVParallelLinearWithLora(linear)
+                           if not fully_shard else
+                           MergedQKVParallelLinearWithShardedLora(linear))
         else:
             linear = QKVParallelLinear(4096,
                                        64,
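The packed cases follow the same selection pattern for the merged wrappers. A compact, purely illustrative helper (hypothetical name; the repeats == 1 branch of the test is cut off in this view and is deliberately left out):

    from vllm.lora.fully_sharded_layers import (
        MergedColumnParallelLinearWithShardedLoRA,
        MergedQKVParallelLinearWithShardedLora)
    from vllm.lora.layers import (MergedColumnParallelLinearWithLoRA,
                                  MergedQKVParallelLinearWithLora)

    def packed_lora_wrapper_cls(repeats: int, fully_shard: bool):
        # repeats == 2 -> merged column-parallel, repeats == 3 -> merged QKV,
        # matching the two branches visible in the hunks above.
        if repeats == 2:
            return (MergedColumnParallelLinearWithShardedLoRA
                    if fully_shard else MergedColumnParallelLinearWithLoRA)
        if repeats == 3:
            return (MergedQKVParallelLinearWithShardedLora
                    if fully_shard else MergedQKVParallelLinearWithLora)
        raise ValueError("only the repeats in (2, 3) cases are sketched here")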