[Misc][LoRA] Abstract PunicaWrapper (#10955)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
||||
# yapf: enable
|
||||
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
|
||||
PackedLoRALayerWeights)
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.lora.punica_wrapper import get_punica_wrapper
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
@@ -48,11 +48,12 @@ TOLERANCES = {
|
||||
torch.float32: (5e-3, 5e-3),
|
||||
torch.bfloat16: (3e-2, 2e-2),
|
||||
}
|
||||
CUDA_DEVICES = [
|
||||
# TODO: Modify this based on platform
|
||||
DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
# We will launch different triton kernels between the prefill and decode
|
||||
# For GPU, we will launch different triton kernels between the prefill and decode
|
||||
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
|
||||
STAGES = [True, False]
|
||||
|
||||
@@ -192,9 +193,18 @@ def create_random_inputs(
|
||||
return inputs, index_mapping, prompt_mapping
|
||||
|
||||
|
||||
def check_punica_wrapper(punica_wrapper) -> bool:
    """Report whether ``punica_wrapper`` is the wrapper type expected here.

    On a CUDA-alike platform the object must be exactly a
    ``PunicaWrapperGPU`` instance (the comparison is on the concrete type,
    so subclasses do not match). On any other platform the result is
    always ``False``.
    """
    if not current_platform.is_cuda_alike():
        return False

    # This import only runs on the CUDA-alike path.
    from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU

    return type(punica_wrapper) is PunicaWrapperGPU
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
@@ -205,7 +215,8 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
@@ -296,7 +307,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
# @pytest.mark.skip(
|
||||
# reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
@@ -305,7 +316,8 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
@@ -432,7 +444,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
@@ -441,7 +453,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
@@ -563,7 +576,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
@pytest.mark.parametrize("bias_enabled", [True, False])
|
||||
def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
@@ -571,7 +584,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@@ -675,7 +689,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("orientation", ["row", "column"])
|
||||
@pytest.mark.parametrize("fully_shard", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
@pytest.mark.parametrize("bias_enabled", [True, False])
|
||||
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
@@ -683,7 +697,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@@ -797,7 +812,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("repeats", [1, 2, 3])
|
||||
@pytest.mark.parametrize("fully_shard", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
@pytest.mark.parametrize("bias_enabled", [True, False])
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
@@ -805,7 +820,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@@ -963,7 +979,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
||||
seed = 0
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
punica_wrapper = get_punica_wrapper(8192, 256, device)
|
||||
assert check_punica_wrapper(punica_wrapper)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
|
||||
Reference in New Issue
Block a user