[Kernels] LoRA - Retire SGMV and BGMV Kernels (#14685)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-03-18 05:47:53 -04:00
committed by simon-mo
parent 16e9064f84
commit 9e8f089d08
15 changed files with 245 additions and 2092 deletions

View File

@@ -4,18 +4,13 @@ from threading import Lock
import pytest
import torch
import vllm.lora.ops.triton_ops # noqa: F401
import vllm.lora.ops.triton_ops.v1 # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
import vllm.lora.ops.torch_ops as torch_ops
import vllm.lora.ops.triton_ops as triton_ops
from vllm.lora.ops.triton_ops import LoRAKernelMeta
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform
from .utils import (PunicaTensors, assert_close, generate_data,
generate_data_for_expand_nslices,
generate_data_for_nslices)
from .utils import PunicaTensors, assert_close, generate_data_for_nslices
# Utility shrink and expand operations used as reference implementations.
@@ -26,10 +21,10 @@ def sgmv_shrink_for_nslices(
prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
num_tokens: int, scaling: float):
"""
Wrapper around sgmv_shrink that handles any nslices.
Wrapper around torch_ops.sgmv_shrink that handles any nslices.
"""
for index in range(nslices):
sgmv_shrink(
torch_ops.sgmv_shrink(
inputs_tensor,
lora_weights_lst[index],
out_tensor[index],
@@ -53,11 +48,11 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
max_seq_length: int, num_tokens: int,
add_inputs: bool) -> None:
"""
Wrapper around sgmv_expand that handles any nslices.
Wrapper around torch_ops.sgmv_expand that handles any nslices.
"""
if nslices == 1:
# Verify the torch's sgmv_expand op
sgmv_expand(
torch_ops.sgmv_expand(
inputs_tensor[0],
lora_weights_lst[0],
out_tensor,
@@ -73,7 +68,7 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
slice_offset = 0
for index in range(nslices):
lora_weights = lora_weights_lst[index]
sgmv_expand_slice(
torch_ops.sgmv_expand_slice(
inputs_tensor[index],
lora_weights,
out_tensor,
@@ -93,12 +88,13 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock()
def check_shrink_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, scaling: float):
def check_lora_shrink_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str, seq_length: int,
scaling: float):
"""
Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernel against a
reference implementation.
Compare outputs of torch_ops.sgmv_shrink and triton_ops.lora_shrink
kernels.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@@ -118,35 +114,24 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
lora_meta.prepare_tensors(data.token_lora_mapping)
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
out_tensor = data.our_out_tensor.clone()
# Preventing cache error pointer.
with _dict_lock:
# SGMV shrink kernel
# lora_shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.sgmv_shrink(
triton_ops.lora_shrink(
data.inputs_tensor,
data.lora_weights,
sgmv_out_tensor,
*sgmv_meta_args,
scaling,
)
# V1 shrink kernel
_LORA_A_PTR_DICT.clear()
torch.ops.vllm.v1_shrink(
data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
scaling,
)
@@ -160,16 +145,16 @@ def check_shrink_kernels(batches: int, num_loras: int, rank: int,
scaling,
)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
assert_close(out_tensor, ref_out_tensor)
def check_expand_kernels(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, seq_length: int, add_inputs: bool):
def check_lora_expand_kernel(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str, seq_length: int,
add_inputs: bool):
"""
Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
reference implementation.
Compare outputs of torch_ops.sgmv_expand and triton_ops.lora_expand
kernels.
"""
data: PunicaTensors = generate_data_for_nslices(
batches,
@@ -190,37 +175,25 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
data.prompt_lora_mapping, batches, max_seq_length,
token_nums)
# Setup metadata information for the V1 kernel.
v1_meta = V1KernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
v1_meta.prepare_tensors(data.token_lora_mapping)
# Setup metadata information for the LoRA kernel.
lora_meta = LoRAKernelMeta.make(max_loras=num_loras,
max_num_tokens=token_nums,
device='cuda')
lora_meta.prepare_tensors(data.token_lora_mapping)
# Setup output tensors
ref_out_tensor = data.ref_out_tensor
sgmv_out_tensor = data.our_out_tensor
v1_out_tensor = data.our_out_tensor.clone()
out_tensor = data.our_out_tensor.clone()
with _dict_lock:
# SGMV expand kernel
# lora_expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.sgmv_expand(
data.inputs_tensor,
data.lora_weights,
sgmv_out_tensor,
*sgmv_meta_args,
offset_start=0,
add_inputs=add_inputs,
)
# V1 expand kernel
_LORA_B_PTR_DICT.clear()
torch.ops.vllm.v1_expand(data.inputs_tensor,
data.lora_weights,
v1_out_tensor,
*v1_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
triton_ops.lora_expand(data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
offset_start=0,
add_inputs=add_inputs)
# Reference
sgmv_expand_for_nslices(nslices,
@@ -231,124 +204,7 @@ def check_expand_kernels(batches: int, num_loras: int, rank: int,
*sgmv_meta_args,
add_inputs=add_inputs)
assert_close(sgmv_out_tensor, ref_out_tensor)
assert_close(v1_out_tensor, ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
scaling: float):
"""
Compare vllm.bgmv_shrink against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"shrink",
device,
)
torch.ops.vllm.bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
scaling,
)
bgmv_shrink(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
scaling,
)
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
hidden_size: int, dtype: torch.dtype, device: str,
add_inputs: bool):
"""
Compare vllm.bgmv_expand against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
"expand",
device,
)
torch.ops.vllm.bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.our_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
bgmv_expand(
data.inputs_tensor,
data.lora_weights,
data.ref_out_tensor,
data.token_lora_mapping,
add_inputs=add_inputs,
)
assert_close(data.our_out_tensor, data.ref_out_tensor)
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int, dtype: torch.dtype,
device: str, add_inputs: bool):
"""
Compare vllm.bgmv_expand_slice against a reference implementation.
"""
seq_length = 1
data: PunicaTensors = generate_data_for_expand_nslices(
batches,
hidden_size,
num_loras,
rank,
seq_length,
dtype,
nslices,
device,
)
slice_offset = 0
for index in range(nslices):
torch.ops.vllm.bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.our_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
bgmv_expand_slice(
data.inputs_tensor,
data.lora_weights[index],
data.ref_out_tensor,
data.token_lora_mapping,
slice_offset,
slice_size=hidden_size,
add_inputs=add_inputs,
)
slice_offset += hidden_size
assert_close(data.our_out_tensor, data.ref_out_tensor)
assert_close(out_tensor, ref_out_tensor)
# Tests
@@ -490,31 +346,31 @@ def test_kernels(
op_type: str,
):
"""
Tests SGMV and V1 kernels.
Tests LoRA kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_lora_shrink_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
check_lora_expand_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@@ -538,159 +394,28 @@ def test_kernels_hidden_size(
op_type: str,
):
"""
Tests SGMV and V1 kernels.
Tests SGMV and LoRA kernels.
"""
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_shrink_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
check_lora_shrink_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
scaling=0.5)
else:
check_expand_kernels(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
batches: int,
num_loras: int,
rank: int,
hidden_size: int,
dtype: torch.dtype,
device: str,
seed: int,
op_type: str,
):
torch.set_default_device(device)
current_platform.seed_everything(seed)
if op_type == "shrink":
check_bgmv_shrink(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
scaling=0.5)
else:
check_bgmv_expand(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(batches: int, num_loras: int, rank: int,
hidden_size: int, nslices: int,
dtype: torch.dtype, device: str,
seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(batches: int, num_loras: int,
rank: int, hidden_size: int,
nslices: int,
dtype: torch.dtype,
device: str, seed: int):
torch.set_default_device(device)
current_platform.seed_everything(seed)
check_bgmv_expand_slice(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
add_inputs=True)
check_lora_expand_kernel(batches=batches,
num_loras=num_loras,
rank=rank,
hidden_size=hidden_size,
nslices=nslices,
dtype=dtype,
device=device,
seq_length=128,
add_inputs=True)