[Kernel][LoRA]Punica prefill kernels fusion (#11234)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Abatom <abzhonghua@gmail.com>
Co-authored-by: Zhonghua Deng <abatom@163.com>
This commit is contained in:
Jee Jee Li
2025-01-07 12:01:39 +08:00
committed by GitHub
parent 8ceffbf315
commit b278557935
11 changed files with 710 additions and 767 deletions

View File

@@ -5,7 +5,7 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import Callable, Optional, Tuple, Union, final
from typing import Optional, Tuple, Union, final
import torch
@@ -16,7 +16,6 @@ if HAS_TRITON:
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
from .punica_base import PunicaWrapperBase
@@ -35,11 +34,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
device)
def _shrink_prefill(
def _apply_shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
w_t_all: Tuple[torch.Tensor, ...],
scale: float,
):
#No LoRA request, so return directly
@@ -53,7 +52,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
scale,
)
def _shrink_decode(
def _apply_shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
@@ -62,56 +61,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
def _apply_expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
offset_start: int,
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
offset_start=offset_start,
add_inputs=add_inputs,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_inputs: bool,
):
#No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)
def _expand_slice_decode(
def _apply_expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
@@ -123,43 +94,6 @@ class PunicaWrapperGPU(PunicaWrapperBase):
bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
y_slice_size, add_inputs)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: Optional[int],
y_slice_size: Optional[int],
add_inputs: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (self._expand_slice_prefill
if self.is_prefill else
self._expand_slice_decode)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, scale: float):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (self._shrink_prefill
if self.is_prefill else self._shrink_decode)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...],
scale: float, **kwargs):
@@ -182,10 +116,15 @@ class PunicaWrapperGPU(PunicaWrapperBase):
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
scale)
if self.is_prefill:
# NOTE fused kernel
self._apply_shrink_prefill(y, x, lora_a_stacked, scale)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink_decode(y[slice_idx], x,
lora_a_stacked[slice_idx], scale)
def add_expand(self,
y: torch.Tensor,
@@ -217,20 +156,28 @@ class PunicaWrapperGPU(PunicaWrapperBase):
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
if lora_bias_stacked is not None:
self._apply_bias(self.token_lora_indices, y, output_slices,
lora_bias_stacked)
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
if self.is_prefill:
# NOTE fused kernel
self._apply_expand_prefill(y,
x,
lora_b_stacked,
offset_start,
add_inputs=True)
else:
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand_decode(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(self,
@@ -252,10 +199,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
add_inputs (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun: Callable = (self._expand_prefill
if self.is_prefill else self._expand_decode)
expand_fun(y, x, lora_b_stacked, add_inputs)
if self.is_prefill:
sgmv_expand(
x.unsqueeze(dim=0),
[lora_b_stacked],
y,
*self.prefill_metadata,
offset_start=0,
add_inputs=add_inputs,
)
else:
bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices,
add_inputs)
def add_lora_linear(self,
y: torch.Tensor,
@@ -301,10 +256,11 @@ class PunicaWrapperGPU(PunicaWrapperBase):
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default ,refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = tuple(
torch.zeros(
(x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices)))
buffer = torch.zeros(
(len(output_slices), x.size(0), r),
dtype=torch.float32,
device=x.device,
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(y,
buffer,