[Misc] LoRA - Refactor Punica ops tests (#12970)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com> Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
c320ca8edd
commit
78a141d768
652
tests/lora/test_punica_ops.py
Normal file
652
tests/lora/test_punica_ops.py
Normal file
@@ -0,0 +1,652 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from threading import Lock
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.lora.ops.triton_ops # noqa: F401
|
||||||
|
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||||
|
bgmv_shrink, sgmv_expand,
|
||||||
|
sgmv_expand_slice, sgmv_shrink)
|
||||||
|
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from .utils import (PunicaTensors, assert_close, generate_data,
|
||||||
|
generate_data_for_expand_nslices,
|
||||||
|
generate_data_for_nslices)
|
||||||
|
|
||||||
|
|
||||||
|
# Utility shrink and expand operations used as reference implementations.
|
||||||
|
def sgmv_shrink_for_nslices(
|
||||||
|
nslices: int, inputs_tensor: torch.Tensor,
|
||||||
|
lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
|
||||||
|
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
|
||||||
|
num_tokens: int, scaling: float):
|
||||||
|
"""
|
||||||
|
Wrapper around sgmv_shrink that handles any nslices.
|
||||||
|
"""
|
||||||
|
for index in range(nslices):
|
||||||
|
sgmv_shrink(
|
||||||
|
inputs_tensor,
|
||||||
|
lora_weights_lst[index],
|
||||||
|
out_tensor[index],
|
||||||
|
b_seq_start_loc,
|
||||||
|
seq_len_tensor,
|
||||||
|
prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
num_tokens,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
|
||||||
|
inputs_tensor: torch.Tensor,
|
||||||
|
lora_weights_lst: List[torch.Tensor],
|
||||||
|
out_tensor: torch.Tensor,
|
||||||
|
b_seq_start_loc: torch.Tensor,
|
||||||
|
seq_len_tensor: torch.Tensor,
|
||||||
|
prompt_lora_mapping: torch.Tensor, batches: int,
|
||||||
|
max_seq_length: int, num_tokens: int,
|
||||||
|
add_inputs: bool) -> None:
|
||||||
|
"""
|
||||||
|
Wrapper around sgmv_expand that handles any nslices.
|
||||||
|
"""
|
||||||
|
if nslices == 1:
|
||||||
|
# Verify the torch's sgmv_expand op
|
||||||
|
sgmv_expand(
|
||||||
|
inputs_tensor[0],
|
||||||
|
lora_weights_lst[0],
|
||||||
|
out_tensor,
|
||||||
|
b_seq_start_loc,
|
||||||
|
seq_len_tensor,
|
||||||
|
prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
num_tokens,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
slice_offset = 0
|
||||||
|
for index in range(nslices):
|
||||||
|
lora_weights = lora_weights_lst[index]
|
||||||
|
sgmv_expand_slice(
|
||||||
|
inputs_tensor[index],
|
||||||
|
lora_weights,
|
||||||
|
out_tensor,
|
||||||
|
b_seq_start_loc,
|
||||||
|
seq_len_tensor,
|
||||||
|
prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
num_tokens,
|
||||||
|
slice_offset,
|
||||||
|
hidden_size,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
slice_offset += hidden_size
|
||||||
|
|
||||||
|
|
||||||
|
_dict_lock = Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||||
|
device: str, seq_length: int, scaling: float):
|
||||||
|
"""
|
||||||
|
Compare outputs of vllm.sgmv_shrink kernel against a reference
|
||||||
|
implementation.
|
||||||
|
"""
|
||||||
|
data: PunicaTensors = generate_data_for_nslices(
|
||||||
|
batches,
|
||||||
|
hidden_size,
|
||||||
|
num_loras,
|
||||||
|
rank,
|
||||||
|
seq_length,
|
||||||
|
nslices,
|
||||||
|
dtype,
|
||||||
|
"shrink",
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
max_seq_length, token_nums = data.meta()
|
||||||
|
|
||||||
|
# Preventing cache error pointer.
|
||||||
|
with _dict_lock:
|
||||||
|
_LORA_A_PTR_DICT.clear()
|
||||||
|
torch.ops.vllm.sgmv_shrink(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.our_out_tensor,
|
||||||
|
data.b_seq_start_loc,
|
||||||
|
data.seq_len_tensor,
|
||||||
|
data.prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
token_nums,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
sgmv_shrink_for_nslices(
|
||||||
|
nslices,
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.ref_out_tensor,
|
||||||
|
data.b_seq_start_loc,
|
||||||
|
data.seq_len_tensor,
|
||||||
|
data.prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
token_nums,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||||
|
device: str, seq_length: int, add_inputs: bool):
|
||||||
|
"""
|
||||||
|
Compare outputs of vllm.sgmv_expand kernel against a reference
|
||||||
|
implementation.
|
||||||
|
"""
|
||||||
|
data: PunicaTensors = generate_data_for_nslices(
|
||||||
|
batches,
|
||||||
|
hidden_size,
|
||||||
|
num_loras,
|
||||||
|
rank,
|
||||||
|
seq_length,
|
||||||
|
nslices,
|
||||||
|
dtype,
|
||||||
|
"expand",
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_seq_length, token_nums = data.meta()
|
||||||
|
|
||||||
|
with _dict_lock:
|
||||||
|
_LORA_B_PTR_DICT.clear()
|
||||||
|
torch.ops.vllm.sgmv_expand(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.our_out_tensor,
|
||||||
|
data.b_seq_start_loc,
|
||||||
|
data.seq_len_tensor,
|
||||||
|
data.prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
token_nums,
|
||||||
|
offset_start=0,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
sgmv_expand_for_nslices(nslices,
|
||||||
|
hidden_size,
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.ref_out_tensor,
|
||||||
|
data.b_seq_start_loc,
|
||||||
|
data.seq_len_tensor,
|
||||||
|
data.prompt_lora_mapping,
|
||||||
|
batches,
|
||||||
|
max_seq_length,
|
||||||
|
token_nums,
|
||||||
|
add_inputs=add_inputs)
|
||||||
|
|
||||||
|
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, dtype: torch.dtype, device: str,
|
||||||
|
scaling: float):
|
||||||
|
"""
|
||||||
|
Compare vllm.bgmv_shrink against a reference implementation.
|
||||||
|
"""
|
||||||
|
seq_length = 1
|
||||||
|
data: PunicaTensors = generate_data(
|
||||||
|
batches,
|
||||||
|
hidden_size,
|
||||||
|
num_loras,
|
||||||
|
rank,
|
||||||
|
seq_length,
|
||||||
|
dtype,
|
||||||
|
"shrink",
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.ops.vllm.bgmv_shrink(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.our_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
bgmv_shrink(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.ref_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
scaling,
|
||||||
|
)
|
||||||
|
|
||||||
|
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
|
||||||
|
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, dtype: torch.dtype, device: str,
|
||||||
|
add_inputs: bool):
|
||||||
|
"""
|
||||||
|
Compare vllm.bgmv_expand against a reference implementation.
|
||||||
|
"""
|
||||||
|
seq_length = 1
|
||||||
|
data: PunicaTensors = generate_data(
|
||||||
|
batches,
|
||||||
|
hidden_size,
|
||||||
|
num_loras,
|
||||||
|
rank,
|
||||||
|
seq_length,
|
||||||
|
dtype,
|
||||||
|
"expand",
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.ops.vllm.bgmv_expand(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.our_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
bgmv_expand(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights,
|
||||||
|
data.ref_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, nslices: int, dtype: torch.dtype,
|
||||||
|
device: str, add_inputs: bool):
|
||||||
|
"""
|
||||||
|
Compare vllm.bgmv_expand_slice against a reference implementation.
|
||||||
|
"""
|
||||||
|
seq_length = 1
|
||||||
|
data: PunicaTensors = generate_data_for_expand_nslices(
|
||||||
|
batches,
|
||||||
|
hidden_size,
|
||||||
|
num_loras,
|
||||||
|
rank,
|
||||||
|
seq_length,
|
||||||
|
dtype,
|
||||||
|
nslices,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
|
||||||
|
slice_offset = 0
|
||||||
|
for index in range(nslices):
|
||||||
|
torch.ops.vllm.bgmv_expand_slice(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights[index],
|
||||||
|
data.our_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
slice_offset,
|
||||||
|
slice_size=hidden_size,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
bgmv_expand_slice(
|
||||||
|
data.inputs_tensor,
|
||||||
|
data.lora_weights[index],
|
||||||
|
data.ref_out_tensor,
|
||||||
|
data.token_lora_mapping,
|
||||||
|
slice_offset,
|
||||||
|
slice_size=hidden_size,
|
||||||
|
add_inputs=add_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
slice_offset += hidden_size
|
||||||
|
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
# We test the punica kernels along 2 verticals mainly.
|
||||||
|
# 1. Variations in hidden_dim size
|
||||||
|
# 2. Variations in all other parameters like (batch_size, max_rank, num_loras
|
||||||
|
# etc.)
|
||||||
|
|
||||||
|
# We have collected the hidden_sizes included in the LoRA models
|
||||||
|
# currently supported by vLLM. It tests whether the corresponding Triton
|
||||||
|
# kernel can run normally when tensor parallelism is set to
|
||||||
|
# [1, 2, 4, 8, 16, 32, 64].
|
||||||
|
HIDDEN_SIZES = [
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
896,
|
||||||
|
1024,
|
||||||
|
1152,
|
||||||
|
1216,
|
||||||
|
1280,
|
||||||
|
1536,
|
||||||
|
1664,
|
||||||
|
2048,
|
||||||
|
2240,
|
||||||
|
2304,
|
||||||
|
2368,
|
||||||
|
2432,
|
||||||
|
2560,
|
||||||
|
2752,
|
||||||
|
3072,
|
||||||
|
3328,
|
||||||
|
3456,
|
||||||
|
3584,
|
||||||
|
3712,
|
||||||
|
4096,
|
||||||
|
4480,
|
||||||
|
4608,
|
||||||
|
4736,
|
||||||
|
4864,
|
||||||
|
5120,
|
||||||
|
5504,
|
||||||
|
5632,
|
||||||
|
5888,
|
||||||
|
6144,
|
||||||
|
6400,
|
||||||
|
6848,
|
||||||
|
6912,
|
||||||
|
7168,
|
||||||
|
7424,
|
||||||
|
8192,
|
||||||
|
8960,
|
||||||
|
9216,
|
||||||
|
9472,
|
||||||
|
10240,
|
||||||
|
11008,
|
||||||
|
11264,
|
||||||
|
13824,
|
||||||
|
14336,
|
||||||
|
14784,
|
||||||
|
14848,
|
||||||
|
15360,
|
||||||
|
18944,
|
||||||
|
22016,
|
||||||
|
22528,
|
||||||
|
24576,
|
||||||
|
27392,
|
||||||
|
27648,
|
||||||
|
29568,
|
||||||
|
29696,
|
||||||
|
32000,
|
||||||
|
32256,
|
||||||
|
32512,
|
||||||
|
32768,
|
||||||
|
33024,
|
||||||
|
36864,
|
||||||
|
43264,
|
||||||
|
49152,
|
||||||
|
49408,
|
||||||
|
60544,
|
||||||
|
60672,
|
||||||
|
64000,
|
||||||
|
64256,
|
||||||
|
102400,
|
||||||
|
102656,
|
||||||
|
128000,
|
||||||
|
128256,
|
||||||
|
]
|
||||||
|
#The size of TP
|
||||||
|
divisibility = [1, 2, 8, 16, 64]
|
||||||
|
|
||||||
|
all_hidden_size = []
|
||||||
|
for div in divisibility:
|
||||||
|
for hidden_size in HIDDEN_SIZES:
|
||||||
|
all_hidden_size.append(hidden_size // div)
|
||||||
|
|
||||||
|
HIDDEN_SIZES = list(set(all_hidden_size))
|
||||||
|
|
||||||
|
# Test params that focuses on hidden_size variation.
|
||||||
|
hs_test_params = {
|
||||||
|
"hidden_sizes": HIDDEN_SIZES,
|
||||||
|
"batches": [4],
|
||||||
|
"num_loras": [4],
|
||||||
|
"max_ranks": [32],
|
||||||
|
}
|
||||||
|
|
||||||
|
# General tests params that tests for variations in all dimensions
|
||||||
|
# except hidden_size.
|
||||||
|
test_params = {
|
||||||
|
"hidden_sizes": [2049],
|
||||||
|
"batches": [1, 4, 16, 32],
|
||||||
|
"num_loras": [1, 8, 32, 128],
|
||||||
|
"max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
|
||||||
|
}
|
||||||
|
|
||||||
|
DTYPES = [torch.float16, torch.bfloat16]
|
||||||
|
DEVICES = [f"cuda:{0}"]
|
||||||
|
SEED = [0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("nslices", [1, 2, 3])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||||
|
def test_punica_sgmv(
|
||||||
|
batches: int,
|
||||||
|
num_loras: int,
|
||||||
|
rank: int,
|
||||||
|
hidden_size: int,
|
||||||
|
nslices: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
seed: int,
|
||||||
|
op_type: str,
|
||||||
|
):
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
if op_type == "shrink":
|
||||||
|
check_sgmv_shrink(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
seq_length=128,
|
||||||
|
scaling=0.5)
|
||||||
|
else:
|
||||||
|
check_sgmv_expand(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
seq_length=128,
|
||||||
|
add_inputs=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("nslices", [1, 2, 3])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||||
|
def test_punica_sgmv_hidden_size(
|
||||||
|
batches: int,
|
||||||
|
num_loras: int,
|
||||||
|
rank: int,
|
||||||
|
hidden_size: int,
|
||||||
|
nslices: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
seed: int,
|
||||||
|
op_type: str,
|
||||||
|
):
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
if op_type == "shrink":
|
||||||
|
check_sgmv_shrink(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
seq_length=128,
|
||||||
|
scaling=0.5)
|
||||||
|
else:
|
||||||
|
check_sgmv_expand(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
seq_length=128,
|
||||||
|
add_inputs=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||||
|
def test_punica_bgmv(
|
||||||
|
batches: int,
|
||||||
|
num_loras: int,
|
||||||
|
rank: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
seed: int,
|
||||||
|
op_type: str,
|
||||||
|
):
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
if op_type == "shrink":
|
||||||
|
check_bgmv_shrink(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
scaling=0.5)
|
||||||
|
else:
|
||||||
|
check_bgmv_expand(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
add_inputs=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||||
|
def test_punica_bgmv_hidden_size(
|
||||||
|
batches: int,
|
||||||
|
num_loras: int,
|
||||||
|
rank: int,
|
||||||
|
hidden_size: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str,
|
||||||
|
seed: int,
|
||||||
|
op_type: str,
|
||||||
|
):
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
if op_type == "shrink":
|
||||||
|
check_bgmv_shrink(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
scaling=0.5)
|
||||||
|
else:
|
||||||
|
check_bgmv_expand(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
add_inputs=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("nslices", [2, 3])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
|
||||||
|
hidden_size: int, nslices: int,
|
||||||
|
dtype: torch.dtype, device: str,
|
||||||
|
seed: int):
|
||||||
|
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
check_bgmv_expand_slice(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
add_inputs=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("batches", hs_test_params['batches'])
|
||||||
|
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
|
||||||
|
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
|
||||||
|
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
|
||||||
|
@pytest.mark.parametrize("nslices", [2, 3])
|
||||||
|
@pytest.mark.parametrize("dtype", DTYPES)
|
||||||
|
@pytest.mark.parametrize("device", DEVICES)
|
||||||
|
@pytest.mark.parametrize("seed", SEED)
|
||||||
|
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
|
||||||
|
rank: int, hidden_size: int,
|
||||||
|
nslices: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: str, seed: int):
|
||||||
|
|
||||||
|
torch.set_default_device(device)
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
|
check_bgmv_expand_slice(batches=batches,
|
||||||
|
num_loras=num_loras,
|
||||||
|
rank=rank,
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
nslices=nslices,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
add_inputs=True)
|
||||||
@@ -1,401 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
"""
|
|
||||||
This script is mainly used to tests various hidden_sizes. We have collected the
|
|
||||||
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
|
|
||||||
whether the corresponding Triton kernel can run normally when tensor parallelism
|
|
||||||
is set to [1, 2, 4, 8, 16, 32, 64].
|
|
||||||
"""
|
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import vllm.lora.ops.triton_ops # noqa: F401
|
|
||||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
|
||||||
bgmv_shrink, sgmv_expand,
|
|
||||||
sgmv_expand_slice, sgmv_shrink)
|
|
||||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from .utils import (assert_close, generate_data,
|
|
||||||
generate_data_for_expand_nslices,
|
|
||||||
generate_data_for_nslices)
|
|
||||||
|
|
||||||
HIDDEN_SIZES = [
|
|
||||||
128,
|
|
||||||
256,
|
|
||||||
512,
|
|
||||||
896,
|
|
||||||
1024,
|
|
||||||
1152,
|
|
||||||
1216,
|
|
||||||
1280,
|
|
||||||
1536,
|
|
||||||
1664,
|
|
||||||
2048,
|
|
||||||
2240,
|
|
||||||
2304,
|
|
||||||
2368,
|
|
||||||
2432,
|
|
||||||
2560,
|
|
||||||
2752,
|
|
||||||
3072,
|
|
||||||
3328,
|
|
||||||
3456,
|
|
||||||
3584,
|
|
||||||
3712,
|
|
||||||
4096,
|
|
||||||
4480,
|
|
||||||
4608,
|
|
||||||
4736,
|
|
||||||
4864,
|
|
||||||
5120,
|
|
||||||
5504,
|
|
||||||
5632,
|
|
||||||
5888,
|
|
||||||
6144,
|
|
||||||
6400,
|
|
||||||
6848,
|
|
||||||
6912,
|
|
||||||
7168,
|
|
||||||
7424,
|
|
||||||
8192,
|
|
||||||
8960,
|
|
||||||
9216,
|
|
||||||
9472,
|
|
||||||
10240,
|
|
||||||
11008,
|
|
||||||
11264,
|
|
||||||
13824,
|
|
||||||
14336,
|
|
||||||
14784,
|
|
||||||
14848,
|
|
||||||
15360,
|
|
||||||
18944,
|
|
||||||
22016,
|
|
||||||
22528,
|
|
||||||
24576,
|
|
||||||
27392,
|
|
||||||
27648,
|
|
||||||
29568,
|
|
||||||
29696,
|
|
||||||
32000,
|
|
||||||
32256,
|
|
||||||
32512,
|
|
||||||
32768,
|
|
||||||
33024,
|
|
||||||
36864,
|
|
||||||
43264,
|
|
||||||
49152,
|
|
||||||
49408,
|
|
||||||
60544,
|
|
||||||
60672,
|
|
||||||
64000,
|
|
||||||
64256,
|
|
||||||
102400,
|
|
||||||
102656,
|
|
||||||
128000,
|
|
||||||
128256,
|
|
||||||
]
|
|
||||||
#The size of TP
|
|
||||||
divisibility = [1, 2, 8, 16, 64]
|
|
||||||
|
|
||||||
all_hidden_size = []
|
|
||||||
for div in divisibility:
|
|
||||||
for hidden_size in HIDDEN_SIZES:
|
|
||||||
all_hidden_size.append(hidden_size // div)
|
|
||||||
|
|
||||||
HIDDEN_SIZES = list(set(all_hidden_size))
|
|
||||||
|
|
||||||
BATCHES = [4]
|
|
||||||
NUM_LORA = [4]
|
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
|
||||||
MAX_RANKS = [32]
|
|
||||||
SCALES = [0.5]
|
|
||||||
SEED = [0]
|
|
||||||
DEVICES = [f"cuda:{0}"]
|
|
||||||
|
|
||||||
_dict_lock = Lock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("scaling", SCALES)
|
|
||||||
@pytest.mark.parametrize("nslices", [1, 2, 3])
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_sgmv(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
scaling: float,
|
|
||||||
nslices: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
op_type: str,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 128
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data_for_nslices(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
nslices,
|
|
||||||
dtype,
|
|
||||||
op_type,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
max_seq_length = seq_len_tensor.max()
|
|
||||||
token_nums = seq_len_tensor.sum().item()
|
|
||||||
if isinstance(max_seq_length, tuple):
|
|
||||||
max_seq_length = max_seq_length[0].item()
|
|
||||||
else:
|
|
||||||
max_seq_length = max_seq_length.item()
|
|
||||||
if op_type == "shrink":
|
|
||||||
# Preventing cache error pointer.
|
|
||||||
with _dict_lock:
|
|
||||||
_LORA_A_PTR_DICT.clear()
|
|
||||||
torch.ops.vllm.sgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
for index in range(nslices):
|
|
||||||
sgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst[index],
|
|
||||||
ref_out_tensor[index],
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
with _dict_lock:
|
|
||||||
_LORA_B_PTR_DICT.clear()
|
|
||||||
torch.ops.vllm.sgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
offset_start=0,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
if nslices == 1:
|
|
||||||
# Verify the torch's sgmv_expand op
|
|
||||||
sgmv_expand(
|
|
||||||
inputs_tensor[0],
|
|
||||||
lora_weights_lst[0],
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
slice_offset = 0
|
|
||||||
for index in range(nslices):
|
|
||||||
lora_weights = lora_weights_lst[index]
|
|
||||||
sgmv_expand_slice(
|
|
||||||
inputs_tensor[index],
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
slice_offset,
|
|
||||||
hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
slice_offset += hidden_size
|
|
||||||
|
|
||||||
assert_close(our_out_tensor, ref_out_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("scaling", SCALES)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_bgmv(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
scaling: float,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
op_type: str,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 1
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
dtype,
|
|
||||||
op_type,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
if op_type == "shrink":
|
|
||||||
torch.ops.vllm.bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
torch.ops.vllm.bgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
bgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if op_type == "shrink":
|
|
||||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
|
||||||
assert_close(our_out_tensor, ref_out_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("nslices", [2, 3])
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_bgmv_expand_nslices(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
nslices: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 1
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_outputs,
|
|
||||||
ref_outputs,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data_for_expand_nslices(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
dtype,
|
|
||||||
nslices,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
slice_offset = 0
|
|
||||||
for index in range(nslices):
|
|
||||||
lora_weights = lora_weights_lst[index]
|
|
||||||
torch.ops.vllm.bgmv_expand_slice(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
bgmv_expand_slice(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
slice_offset += hidden_size
|
|
||||||
assert_close(our_outputs, ref_outputs)
|
|
||||||
@@ -1,317 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
"""
|
|
||||||
This script is mainly used to test whether trtion kernels can run normally
|
|
||||||
under different conditions, including various batches, numbers of LoRA , and
|
|
||||||
maximum ranks.
|
|
||||||
"""
|
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Enable custom op register
|
|
||||||
import vllm.lora.ops.triton_ops # noqa: F401
|
|
||||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
|
||||||
bgmv_shrink, sgmv_expand,
|
|
||||||
sgmv_expand_slice, sgmv_shrink)
|
|
||||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
from .utils import (assert_close, generate_data,
|
|
||||||
generate_data_for_expand_nslices,
|
|
||||||
generate_data_for_nslices)
|
|
||||||
|
|
||||||
HIDDEN_SIZES = [2049]
|
|
||||||
|
|
||||||
BATCHES = [1, 4, 16, 32]
|
|
||||||
NUM_LORA = [1, 8, 32, 128]
|
|
||||||
DTYPES = [torch.float16, torch.bfloat16]
|
|
||||||
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
|
|
||||||
SCALES = [0.5]
|
|
||||||
SEED = [0]
|
|
||||||
DEVICES = [f"cuda:{0}"]
|
|
||||||
|
|
||||||
_dict_lock = Lock()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("scaling", SCALES)
|
|
||||||
@pytest.mark.parametrize("nslices", [1, 2, 3])
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_sgmv(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
scaling: float,
|
|
||||||
nslices: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
op_type: str,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 128
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data_for_nslices(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
nslices,
|
|
||||||
dtype,
|
|
||||||
op_type,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
max_seq_length = seq_len_tensor.max()
|
|
||||||
token_nums = seq_len_tensor.sum().item()
|
|
||||||
if isinstance(max_seq_length, tuple):
|
|
||||||
max_seq_length = max_seq_length[0].item()
|
|
||||||
else:
|
|
||||||
max_seq_length = max_seq_length.item()
|
|
||||||
if op_type == "shrink":
|
|
||||||
# Preventing cache error pointer.
|
|
||||||
with _dict_lock:
|
|
||||||
_LORA_A_PTR_DICT.clear()
|
|
||||||
torch.ops.vllm.sgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
for index in range(nslices):
|
|
||||||
sgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst[index],
|
|
||||||
ref_out_tensor[index],
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
with _dict_lock:
|
|
||||||
_LORA_B_PTR_DICT.clear()
|
|
||||||
torch.ops.vllm.sgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
offset_start=0,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
slice_offset = 0
|
|
||||||
if nslices == 1:
|
|
||||||
# Verify the torch's sgmv_expand op
|
|
||||||
sgmv_expand(
|
|
||||||
inputs_tensor[0],
|
|
||||||
lora_weights_lst[0],
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for index in range(nslices):
|
|
||||||
lora_weights = lora_weights_lst[index]
|
|
||||||
sgmv_expand_slice(
|
|
||||||
inputs_tensor[index],
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
seq_len_tensor,
|
|
||||||
lora_indices_tensor,
|
|
||||||
batches,
|
|
||||||
max_seq_length,
|
|
||||||
token_nums,
|
|
||||||
slice_offset,
|
|
||||||
hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
slice_offset += hidden_size
|
|
||||||
|
|
||||||
assert_close(our_out_tensor, ref_out_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("scaling", SCALES)
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_bgmv(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
scaling: float,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
op_type: str,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 1
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
ref_out_tensor,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
dtype,
|
|
||||||
op_type,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
if op_type == "shrink":
|
|
||||||
torch.ops.vllm.bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
torch.ops.vllm.bgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
bgmv_expand(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if op_type == "shrink":
|
|
||||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
|
||||||
assert_close(our_out_tensor, ref_out_tensor)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batches", BATCHES)
|
|
||||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
|
||||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
|
||||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
|
||||||
@pytest.mark.parametrize("nslices", [2, 3])
|
|
||||||
@pytest.mark.parametrize("dtype", DTYPES)
|
|
||||||
@pytest.mark.parametrize("seed", SEED)
|
|
||||||
@pytest.mark.parametrize("device", DEVICES)
|
|
||||||
def test_punica_bgmv_expand_nslices(
|
|
||||||
batches: int,
|
|
||||||
num_loras: int,
|
|
||||||
rank: int,
|
|
||||||
hidden_size: int,
|
|
||||||
nslices: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
seed: int,
|
|
||||||
device: str,
|
|
||||||
):
|
|
||||||
torch.set_default_device(device)
|
|
||||||
current_platform.seed_everything(seed)
|
|
||||||
|
|
||||||
seq_length = 1
|
|
||||||
(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights_lst,
|
|
||||||
our_outputs,
|
|
||||||
ref_outputs,
|
|
||||||
b_seq_start_loc,
|
|
||||||
lora_indices_tensor,
|
|
||||||
seq_len_tensor,
|
|
||||||
indices,
|
|
||||||
) = generate_data_for_expand_nslices(
|
|
||||||
batches,
|
|
||||||
hidden_size,
|
|
||||||
num_loras,
|
|
||||||
rank,
|
|
||||||
seq_length,
|
|
||||||
dtype,
|
|
||||||
nslices,
|
|
||||||
device,
|
|
||||||
)
|
|
||||||
slice_offset = 0
|
|
||||||
for index in range(nslices):
|
|
||||||
lora_weights = lora_weights_lst[index]
|
|
||||||
torch.ops.vllm.bgmv_expand_slice(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
bgmv_expand_slice(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
ref_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
slice_offset += hidden_size
|
|
||||||
assert_close(our_outputs, ref_outputs)
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, List, Optional
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -106,6 +107,31 @@ def assert_close(a, b):
|
|||||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PunicaTensors:
|
||||||
|
inputs_tensor: torch.Tensor
|
||||||
|
lora_weights: Union[torch.Tensor, List[torch.Tensor]]
|
||||||
|
our_out_tensor: torch.Tensor
|
||||||
|
ref_out_tensor: torch.Tensor
|
||||||
|
b_seq_start_loc: torch.Tensor
|
||||||
|
prompt_lora_mapping: torch.Tensor
|
||||||
|
seq_len_tensor: torch.Tensor
|
||||||
|
token_lora_mapping: torch.Tensor
|
||||||
|
|
||||||
|
def meta(self) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Infer max_seq_length and token_nums from the tensors
|
||||||
|
and return them.
|
||||||
|
"""
|
||||||
|
max_seq_length = self.seq_len_tensor.max()
|
||||||
|
token_nums = self.seq_len_tensor.sum().item()
|
||||||
|
if isinstance(max_seq_length, tuple):
|
||||||
|
max_seq_length = max_seq_length[0].item()
|
||||||
|
else:
|
||||||
|
max_seq_length = max_seq_length.item()
|
||||||
|
return max_seq_length, token_nums
|
||||||
|
|
||||||
|
|
||||||
def generate_data(
|
def generate_data(
|
||||||
batches,
|
batches,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -115,7 +141,7 @@ def generate_data(
|
|||||||
dtype,
|
dtype,
|
||||||
op_type,
|
op_type,
|
||||||
device,
|
device,
|
||||||
):
|
) -> PunicaTensors:
|
||||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||||
(batches, )).to(device)
|
(batches, )).to(device)
|
||||||
b_seq_start_loc = torch.cumsum(
|
b_seq_start_loc = torch.cumsum(
|
||||||
@@ -164,7 +190,8 @@ def generate_data(
|
|||||||
indices[current_offset:current_offset +
|
indices[current_offset:current_offset +
|
||||||
seq_len_tensor[b_id]].copy_(lora_index)
|
seq_len_tensor[b_id]].copy_(lora_index)
|
||||||
current_offset += seq_len_tensor[b_id].item()
|
current_offset += seq_len_tensor[b_id].item()
|
||||||
return (
|
|
||||||
|
return PunicaTensors(
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
lora_weights,
|
lora_weights,
|
||||||
our_out_tensor,
|
our_out_tensor,
|
||||||
@@ -185,7 +212,7 @@ def generate_data_for_expand_nslices(
|
|||||||
dtype,
|
dtype,
|
||||||
nslices,
|
nslices,
|
||||||
device,
|
device,
|
||||||
):
|
) -> PunicaTensors:
|
||||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||||
(batches, )).to(device)
|
(batches, )).to(device)
|
||||||
b_seq_start_loc = torch.cumsum(
|
b_seq_start_loc = torch.cumsum(
|
||||||
@@ -222,7 +249,7 @@ def generate_data_for_expand_nslices(
|
|||||||
current_offset += seq_len_tensor[b_id].item()
|
current_offset += seq_len_tensor[b_id].item()
|
||||||
|
|
||||||
lora_indices_tensor = lora_indices_tensor.to(device)
|
lora_indices_tensor = lora_indices_tensor.to(device)
|
||||||
return (
|
return PunicaTensors(
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
lora_weights_lst,
|
lora_weights_lst,
|
||||||
our_out_tensor,
|
our_out_tensor,
|
||||||
@@ -244,7 +271,7 @@ def generate_data_for_nslices(
|
|||||||
dtype,
|
dtype,
|
||||||
op_type,
|
op_type,
|
||||||
device,
|
device,
|
||||||
):
|
) -> PunicaTensors:
|
||||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||||
(batches, )).to(device)
|
(batches, )).to(device)
|
||||||
b_seq_start_loc = torch.cumsum(
|
b_seq_start_loc = torch.cumsum(
|
||||||
@@ -302,7 +329,7 @@ def generate_data_for_nslices(
|
|||||||
current_offset += seq_len_tensor[b_id].item()
|
current_offset += seq_len_tensor[b_id].item()
|
||||||
|
|
||||||
lora_indices_tensor = lora_indices_tensor.to(device)
|
lora_indices_tensor = lora_indices_tensor.to(device)
|
||||||
return (
|
return PunicaTensors(
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
lora_weights_lst,
|
lora_weights_lst,
|
||||||
our_out_tensor,
|
our_out_tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user