[Kernel][LoRA]Punica prefill kernels fusion (#11234)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: Abatom <abzhonghua@gmail.com> Co-authored-by: Zhonghua Deng <abatom@163.com>
This commit is contained in:
@@ -4,6 +4,8 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
|
||||
whether the corresponding Triton kernel can run normally when tensor parallelism
|
||||
is set to [1, 2, 4, 8, 16, 32, 64].
|
||||
"""
|
||||
from threading import Lock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -11,12 +13,13 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||
ref_torch_groupgemm)
|
||||
from .utils import (assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices, ref_torch_groupgemm)
|
||||
|
||||
HIDDEN_SIZES = [
|
||||
128,
|
||||
@@ -112,14 +115,7 @@ SCALES = [0.5]
|
||||
SEED = [0]
|
||||
CUDA_DEVICES = [f"cuda:{0}"]
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[a.dtype]
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
_dict_lock = Lock()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@@ -127,6 +123,7 @@ def assert_close(a, b):
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("scaling", SCALES)
|
||||
@pytest.mark.parametrize("nslices", [1, 2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@@ -137,6 +134,7 @@ def test_punica_sgmv(
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
scaling: float,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
@@ -148,19 +146,20 @@ def test_punica_sgmv(
|
||||
seq_length = 128
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data(
|
||||
) = generate_data_for_nslices(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
nslices,
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
@@ -172,43 +171,64 @@ def test_punica_sgmv(
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
|
||||
if op_type == "shrink":
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
)
|
||||
# Preventing cache error pointer.
|
||||
with _dict_lock:
|
||||
_LORA_A_PTR_DICT.clear()
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
scaling,
|
||||
)
|
||||
for index in range(nslices):
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[index],
|
||||
inputs_tensor,
|
||||
lora_weights_lst[index],
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling,
|
||||
op_type,
|
||||
)
|
||||
else:
|
||||
sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
with _dict_lock:
|
||||
_LORA_B_PTR_DICT.clear()
|
||||
sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
offset_start=0,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor[index],
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type,
|
||||
)
|
||||
slice_offset += hidden_size
|
||||
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@@ -292,25 +312,22 @@ def test_punica_bgmv(
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_expand_nslices(
|
||||
def test_punica_bgmv_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
|
||||
torch.set_default_device(device)
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
seq_length = 128 if op_type == "sgmv" else 1
|
||||
seq_length = 1
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
@@ -330,41 +347,18 @@ def test_punica_expand_nslices(
|
||||
nslices,
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
token_nums = seq_len_tensor.sum().item()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
if op_type == "sgmv":
|
||||
sgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
token_nums,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor,
|
||||
|
||||
Reference in New Issue
Block a user