[Kernel][LoRA]Punica prefill kernels fusion (#11234)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Zhonghua Deng <abatom@163.com>
Author: Jee Jee Li
Date: 2025-01-07 12:01:39 +08:00 (committed via GitHub)
Commit: b278557935, parent: 8ceffbf315
11 changed files with 710 additions and 767 deletions
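What the fusion means for callers, in brief: instead of launching `sgmv_expand_slice` once per LoRA slice, the expand (and shrink) kernels now accept the whole list of slice weights in a single call, with `offset_start` locating the first output column. Below is a hedged before/after sketch of the call pattern, with argument names copied from the updated test in this diff; the authoritative kernel signatures live in `vllm/lora/ops/`, which this excerpt does not show.

```python
# Sketch only -- inferred from the test changes in this diff, not from the
# kernel sources themselves.

# Before: one launch per slice, each writing its own hidden_size-wide block.
slice_offset = 0
for lora_weights in lora_weights_lst:
    sgmv_expand_slice(inputs_tensor, lora_weights, out_tensor,
                      b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
                      batches, max_seq_length, token_nums,
                      slice_offset, hidden_size, add_inputs=True)
    slice_offset += hidden_size

# After: a single fused launch over all slices.
sgmv_expand(inputs_tensor, lora_weights_lst, out_tensor,
            b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
            batches, max_seq_length, token_nums,
            offset_start=0, add_inputs=True)
```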


@@ -4,6 +4,8 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
 whether the corresponding Triton kernel can run normally when tensor parallelism
 is set to [1, 2, 4, 8, 16, 32, 64].
 """
+from threading import Lock
+
 import pytest
 import torch
@@ -11,12 +13,13 @@ from vllm.lora.ops.bgmv_expand import bgmv_expand
 from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
 from vllm.lora.ops.bgmv_shrink import bgmv_shrink
 from vllm.lora.ops.sgmv_expand import sgmv_expand
-from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
 from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
 from vllm.platforms import current_platform

-from .utils import (generate_data, generate_data_for_expand_nslices,
-                    ref_torch_groupgemm)
+from .utils import (assert_close, generate_data,
+                    generate_data_for_expand_nslices,
+                    generate_data_for_nslices, ref_torch_groupgemm)

 HIDDEN_SIZES = [
     128,
@@ -112,14 +115,7 @@ SCALES = [0.5]
 SEED = [0]
 CUDA_DEVICES = [f"cuda:{0}"]


-def assert_close(a, b):
-    rtol, atol = {
-        torch.float16: (6e-2, 6e-2),
-        torch.bfloat16: (6e-2, 6e-2),
-        torch.float32: (1e-2, 1e-2),
-    }[a.dtype]
-    torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
+_dict_lock = Lock()


 @pytest.mark.parametrize("batches", BATCHES)
@@ -127,6 +123,7 @@ def assert_close(a, b):
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@@ -137,6 +134,7 @@ def test_punica_sgmv(
     rank: int,
     hidden_size: int,
     scaling: float,
+    nslices: int,
     dtype: torch.dtype,
     op_type: str,
     seed: int,
@@ -148,19 +146,20 @@ def test_punica_sgmv(
     seq_length = 128
     (
         inputs_tensor,
-        lora_weights,
+        lora_weights_lst,
         our_out_tensor,
         ref_out_tensor,
         b_seq_start_loc,
         lora_indices_tensor,
         seq_len_tensor,
         indices,
-    ) = generate_data(
+    ) = generate_data_for_nslices(
         batches,
         hidden_size,
         num_loras,
         rank,
         seq_length,
+        nslices,
         dtype,
         op_type,
         device,
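`generate_data_for_nslices` replaces `generate_data` here so each case carries one weight tensor per slice. The helper itself lives in `tests/lora/utils.py`, which this diff excerpt does not show; the following hypothetical stub only illustrates the shapes that the test's unpacking and indexing imply (the shrink output gains a leading `nslices` dimension, the expand input does the same):

```python
import torch

def generate_data_for_nslices(batches, hidden_size, num_loras, rank,
                              seq_length, nslices, dtype, op_type, device):
    # Hypothetical stub: shapes inferred from how the test consumes the
    # returned tuple, not from the real helper.
    seq_len_tensor = torch.randint(1, seq_length + 1, (batches,), device=device)
    total_tokens = int(seq_len_tensor.sum())
    # Exclusive prefix sum: start offset of each batch's token block.
    b_seq_start_loc = torch.cat([
        torch.zeros(1, dtype=torch.long, device=device),
        torch.cumsum(seq_len_tensor, dim=0)[:-1],
    ])
    if op_type == "shrink":
        inputs_tensor = torch.rand((total_tokens, hidden_size),
                                   dtype=dtype, device=device)
        lora_weights_lst = [torch.rand((num_loras, rank, hidden_size),
                                       dtype=dtype, device=device)
                            for _ in range(nslices)]
        # Shrink accumulates in float32; one output plane per slice.
        our_out_tensor = torch.zeros((nslices, total_tokens, rank),
                                     dtype=torch.float32, device=device)
    else:
        inputs_tensor = torch.rand((nslices, total_tokens, rank),
                                   dtype=dtype, device=device)
        lora_weights_lst = [torch.rand((num_loras, hidden_size, rank),
                                       dtype=dtype, device=device)
                            for _ in range(nslices)]
        # Expand writes all slices side by side along the hidden dim.
        our_out_tensor = torch.zeros((total_tokens, hidden_size * nslices),
                                     dtype=dtype, device=device)
    ref_out_tensor = our_out_tensor.clone()
    lora_indices_tensor = torch.randint(0, num_loras, (batches,), device=device)
    # Per-token LoRA index, used by the bgmv (decode) kernels.
    indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
    return (inputs_tensor, lora_weights_lst, our_out_tensor, ref_out_tensor,
            b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices)
```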
@@ -172,43 +171,64 @@ def test_punica_sgmv(
     else:
         max_seq_length = max_seq_length.item()

     if op_type == "shrink":
-        sgmv_shrink(
-            inputs_tensor,
-            lora_weights,
-            our_out_tensor,
-            b_seq_start_loc,
-            seq_len_tensor,
-            lora_indices_tensor,
-            batches,
-            max_seq_length,
-            token_nums,
-            scaling,
-        )
+        # Prevent stale cached device pointers.
+        with _dict_lock:
+            _LORA_A_PTR_DICT.clear()
+            sgmv_shrink(
+                inputs_tensor,
+                lora_weights_lst,
+                our_out_tensor,
+                b_seq_start_loc,
+                seq_len_tensor,
+                lora_indices_tensor,
+                batches,
+                max_seq_length,
+                token_nums,
+                scaling,
+            )
+        for index in range(nslices):
+            ref_torch_groupgemm(
+                ref_out_tensor[index],
+                inputs_tensor,
+                lora_weights_lst[index],
+                lora_indices_tensor,
+                seq_len_tensor,
+                batches,
+                scaling,
+                op_type,
+            )
     else:
-        sgmv_expand(
-            inputs_tensor,
-            lora_weights,
-            our_out_tensor,
-            b_seq_start_loc,
-            seq_len_tensor,
-            lora_indices_tensor,
-            batches,
-            max_seq_length,
-            token_nums,
-            add_inputs=True,
-        )
-    ref_torch_groupgemm(
-        ref_out_tensor,
-        inputs_tensor,
-        lora_weights,
-        lora_indices_tensor,
-        seq_len_tensor,
-        batches,
-        scaling if op_type == "shrink" else 1.0,
-        op_type,
-    )
-    if op_type == "shrink":
-        ref_out_tensor = ref_out_tensor.to(torch.float32)
+        with _dict_lock:
+            _LORA_B_PTR_DICT.clear()
+            sgmv_expand(
+                inputs_tensor,
+                lora_weights_lst,
+                our_out_tensor,
+                b_seq_start_loc,
+                seq_len_tensor,
+                lora_indices_tensor,
+                batches,
+                max_seq_length,
+                token_nums,
+                offset_start=0,
+                add_inputs=True,
+            )
+        slice_offset = 0
+        for index in range(nslices):
+            lora_weights = lora_weights_lst[index]
+            ref_torch_groupgemm(
+                ref_out_tensor[:, slice_offset:slice_offset + hidden_size],
+                inputs_tensor[index],
+                lora_weights,
+                lora_indices_tensor,
+                seq_len_tensor,
+                batches,
+                1.0,
+                op_type,
+            )
+            slice_offset += hidden_size
+
     assert_close(our_out_tensor, ref_out_tensor)
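The lock-and-clear dance above exists because `_LORA_A_PTR_DICT` / `_LORA_B_PTR_DICT` cache per-call pointer metadata for the fused kernels, and each parametrized test allocates fresh weight tensors: as the in-diff comment notes, a hit on a stale cache key would hand the kernel dangling device pointers. A minimal, hypothetical sketch of the cache pattern being guarded (only the dict names come from the diff; the real logic is in `vllm/lora/ops/utils.py`):

```python
from typing import Dict, Tuple

import torch

_LORA_A_PTR_DICT: Dict[Tuple[int, ...], torch.Tensor] = {}

def _get_lora_a_ptr(lora_weights, device):
    # Hypothetical sketch: one cached array of raw device pointers per
    # weight list, keyed by tensor identity. If a freed tensor's id (or
    # address) is reused by a new allocation, the stale entry would alias
    # dead memory -- which is why the test clears the dict under a lock
    # before every kernel launch.
    key = tuple(id(w) for w in lora_weights)
    if key not in _LORA_A_PTR_DICT:
        _LORA_A_PTR_DICT[key] = torch.tensor(
            [w.data_ptr() for w in lora_weights],
            dtype=torch.int64, device=device)
    return _LORA_A_PTR_DICT[key]
```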
@@ -292,25 +312,22 @@ def test_punica_bgmv(
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("nslices", [2, 3])
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
 @pytest.mark.parametrize("seed", SEED)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_punica_expand_nslices(
+def test_punica_bgmv_expand_nslices(
     batches: int,
     num_loras: int,
     rank: int,
     hidden_size: int,
     nslices: int,
     dtype: torch.dtype,
-    op_type: str,
     seed: int,
     device: str,
 ):
     torch.set_default_device(device)
     current_platform.seed_everything(seed)
-    seq_length = 128 if op_type == "sgmv" else 1
+    seq_length = 1
     (
         inputs_tensor,
         lora_weights_lst,
@@ -330,41 +347,18 @@ def test_punica_expand_nslices(
         nslices,
         device,
     )
-    max_seq_length = seq_len_tensor.max()
-    token_nums = seq_len_tensor.sum().item()
-    if isinstance(max_seq_length, tuple):
-        max_seq_length = max_seq_length[0].item()
-    else:
-        max_seq_length = max_seq_length.item()
     slice_offset = 0
     for index in range(nslices):
         lora_weights = lora_weights_lst[index]
-        if op_type == "sgmv":
-            sgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                b_seq_start_loc,
-                seq_len_tensor,
-                lora_indices_tensor,
-                batches,
-                max_seq_length,
-                token_nums,
-                slice_offset,
-                hidden_size,
-                add_inputs=True,
-            )
-        else:
-            bgmv_expand_slice(
-                inputs_tensor,
-                lora_weights,
-                our_outputs,
-                indices,
-                slice_offset,
-                slice_size=hidden_size,
-                add_inputs=True,
-            )
+        bgmv_expand_slice(
+            inputs_tensor,
+            lora_weights,
+            our_outputs,
+            indices,
+            slice_offset,
+            slice_size=hidden_size,
+            add_inputs=True,
+        )
         ref_torch_groupgemm(
             ref_outputs[:, slice_offset:slice_offset + hidden_size],
             inputs_tensor,