[XPU][8/N] Fix kernel bugs in XPU LoRA and MOE LORA (#34115)
Signed-off-by: chzhang <chaojun.zhang@intel.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.lora.ops.triton_ops import fused_moe_lora
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
|
||||
@@ -244,8 +245,9 @@ def use_torch(
|
||||
return torch.stack(outputs, dim=0)
|
||||
|
||||
|
||||
DEVICE_TYPE = current_platform.device_type
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
DEVICES = [f"{DEVICE_TYPE}:{0}"]
|
||||
SEED = [42]
|
||||
|
||||
|
||||
|
||||
298
tests/lora/test_punica_xpu_ops.py
Normal file
298
tests/lora/test_punica_xpu_ops.py
Normal file
@@ -0,0 +1,298 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.lora.utils import (
|
||||
PunicaTensors,
|
||||
assert_close,
|
||||
generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
)
|
||||
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def torch_bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
limit = output_tensor.shape[0]
|
||||
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
|
||||
limit = 1
|
||||
|
||||
# LoRA adapter and model may add different amounts of padding to output
|
||||
common_len = min(outputs.shape[1], output_tensor.shape[1])
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, :common_len] += outputs[:limit, :common_len]
|
||||
else:
|
||||
output_tensor[:, :common_len] = outputs[:limit, :common_len]
|
||||
|
||||
|
||||
def torch_bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
|
||||
|
||||
|
||||
def torch_bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
|
||||
inputs = inputs.to(dtype=output_tensor.dtype)
|
||||
if len(selected_loras.shape) == 4:
|
||||
selected_loras = selected_loras.squeeze(dim=1)
|
||||
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
|
||||
|
||||
if add_inputs:
|
||||
output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
|
||||
else:
|
||||
output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
|
||||
|
||||
|
||||
def check_bgmv_shrink(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
scaling: float,
|
||||
):
|
||||
"""
|
||||
Compare vllm.bgmv_shrink against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
"shrink",
|
||||
device,
|
||||
)
|
||||
|
||||
bgmv_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
scaling,
|
||||
)
|
||||
|
||||
torch_bgmv_shrink(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
scaling,
|
||||
)
|
||||
|
||||
data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
|
||||
assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
add_inputs: bool,
|
||||
):
|
||||
"""
|
||||
Compare vllm.bgmv_expand against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
"expand",
|
||||
device,
|
||||
)
|
||||
|
||||
bgmv_expand(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
torch_bgmv_expand(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights,
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
assert_close(data.ref_out_tensor, data.our_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand_slice(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
add_inputs: bool,
|
||||
):
|
||||
"""
|
||||
Compare vllm.bgmv_expand_slice against a reference implementation.
|
||||
"""
|
||||
seq_length = 1
|
||||
data: PunicaTensors = generate_data_for_expand_nslices(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
nslices,
|
||||
device,
|
||||
)
|
||||
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
bgmv_expand_slice(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights[index],
|
||||
data.our_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
torch_bgmv_expand_slice(
|
||||
data.inputs_tensor,
|
||||
data.lora_weights[index],
|
||||
data.ref_out_tensor,
|
||||
data.token_lora_mapping,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=add_inputs,
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
assert_close(data.ref_out_tensor, data.our_out_tensor)
|
||||
|
||||
|
||||
# General tests params that tests for variations in all dimensions
|
||||
# except hidden_size.
|
||||
test_params = {
|
||||
"hidden_sizes": [2049],
|
||||
"batches": [4],
|
||||
"num_loras": [4],
|
||||
"max_ranks": [32],
|
||||
}
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
DEVICES = [f"xpu:{0}"]
|
||||
SEED = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params["batches"])
|
||||
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
|
||||
@pytest.mark.parametrize("rank", test_params["max_ranks"])
|
||||
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
|
||||
def test_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
seed: int,
|
||||
op_type: str,
|
||||
):
|
||||
if op_type == "shrink":
|
||||
check_bgmv_shrink(
|
||||
batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
scaling=0.5,
|
||||
)
|
||||
else:
|
||||
check_bgmv_expand(
|
||||
batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params["batches"])
|
||||
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
|
||||
@pytest.mark.parametrize("rank", test_params["max_ranks"])
|
||||
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
|
||||
def test_bgmv_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
seed: int,
|
||||
):
|
||||
check_bgmv_expand_slice(
|
||||
batches=batches,
|
||||
num_loras=num_loras,
|
||||
rank=rank,
|
||||
hidden_size=hidden_size,
|
||||
nslices=nslices,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
add_inputs=True,
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
|
||||
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
|
||||
@@ -7,11 +7,6 @@ from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
except ImportError as e:
|
||||
raise e
|
||||
|
||||
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
@@ -20,8 +15,8 @@ def bgmv_shrink(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
) -> None:
|
||||
ipex.llm.functional.bgmv_shrink(
|
||||
inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling
|
||||
torch.ops._xpu_C.bgmv_shrink(
|
||||
output_tensor, inputs, lora_a_weights, lora_indices_tensor, scaling
|
||||
)
|
||||
|
||||
|
||||
@@ -32,8 +27,8 @@ def bgmv_expand(
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
ipex.llm.functional.bgmv_expand(
|
||||
inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs
|
||||
torch.ops._xpu_C.bgmv_expand(
|
||||
output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs
|
||||
)
|
||||
|
||||
|
||||
@@ -46,10 +41,12 @@ def bgmv_expand_slice(
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
) -> None:
|
||||
ipex.llm.functional.bgmv_expand_slice(
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert slice_offset + slice_size <= output_tensor.size(1)
|
||||
torch.ops._xpu_C.bgmv_expand_slice(
|
||||
output_tensor,
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
lora_indices_tensor,
|
||||
slice_offset,
|
||||
slice_size,
|
||||
@@ -11,8 +11,17 @@ from typing import final
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.utils.math_utils import round_up
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.triton_ops import (
|
||||
LoRAKernelMeta,
|
||||
fused_moe_lora,
|
||||
)
|
||||
|
||||
from .punica_base import PunicaWrapperBase
|
||||
|
||||
@@ -37,6 +46,12 @@ class PunicaWrapperXPU(PunicaWrapperBase):
|
||||
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
|
||||
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
|
||||
|
||||
self.lora_config = kwargs["lora_config"]
|
||||
self.max_loras = self.lora_config.max_loras
|
||||
self.token_mapping_meta = LoRAKernelMeta.make(
|
||||
self.max_loras, max_num_batched_tokens, device=device
|
||||
)
|
||||
|
||||
def update_metadata(
|
||||
self,
|
||||
mapping: LoRAMapping,
|
||||
@@ -206,11 +221,9 @@ class PunicaWrapperXPU(PunicaWrapperBase):
|
||||
|
||||
if buffer is None:
|
||||
r = lora_b_stacked[0].size(-1)
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
buffer = torch.zeros( # type: ignore
|
||||
(len(output_slices), x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
self.add_shrink(
|
||||
@@ -267,10 +280,142 @@ class PunicaWrapperXPU(PunicaWrapperBase):
|
||||
x = x.view(-1, x.shape[-1])
|
||||
r = lora_b_stacked.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default, refer to:
|
||||
# https://github.com/triton-lang/triton/issues/1387
|
||||
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
|
||||
buffer = torch.zeros((x.size(0), r), dtype=x.dtype, device=x.device)
|
||||
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
|
||||
bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
|
||||
bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
|
||||
return y.view_as(y_org)
|
||||
|
||||
def moe_lora_align_block_size(
|
||||
self,
|
||||
topk_ids: torch.Tensor,
|
||||
num_tokens: int,
|
||||
block_size: int,
|
||||
num_experts: int,
|
||||
max_loras: int,
|
||||
adapter_enabled: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
pad_sorted_ids: bool = False,
|
||||
naive_block_assignment: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Aligns tokens and experts into block-sized chunks for LoRA-based
|
||||
mixture-of-experts (MoE) execution.
|
||||
"""
|
||||
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
|
||||
self.token_mapping_meta.meta_args(
|
||||
num_tokens, self.lora_config.specialize_active_lora
|
||||
)
|
||||
)
|
||||
if naive_block_assignment:
|
||||
expert_ids = topk_ids.reshape(-1)
|
||||
sorted_ids = None
|
||||
num_tokens_post_pad = None
|
||||
else:
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
if pad_sorted_ids:
|
||||
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
|
||||
sorted_ids = torch.empty(
|
||||
(max_loras * max_num_tokens_padded,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||
# Expert ids must be set default to -1 to prevent a blank block
|
||||
expert_ids = torch.empty(
|
||||
(max_loras * max_num_m_blocks,),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
num_tokens_post_pad = torch.empty(
|
||||
(max_loras), dtype=torch.int32, device=topk_ids.device
|
||||
)
|
||||
|
||||
ops.moe_lora_align_block_size(
|
||||
topk_ids,
|
||||
token_lora_mapping,
|
||||
num_experts,
|
||||
block_size,
|
||||
max_loras,
|
||||
max_num_tokens_padded,
|
||||
max_num_m_blocks,
|
||||
sorted_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_pad,
|
||||
adapter_enabled,
|
||||
lora_ids,
|
||||
)
|
||||
if expert_map is not None:
|
||||
expert_ids = expert_map[expert_ids]
|
||||
|
||||
return None, sorted_ids, expert_ids, num_tokens_post_pad
|
||||
|
||||
def add_lora_fused_moe(
|
||||
self,
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: tuple[torch.Tensor, ...],
|
||||
lora_b_stacked: tuple[torch.Tensor, ...],
|
||||
topk_weights: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor | None,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor | None,
|
||||
max_lora_rank: int,
|
||||
top_k_num: int,
|
||||
shrink_config,
|
||||
expand_config,
|
||||
adapter_enabled: torch.Tensor,
|
||||
mul_routed_weight=False,
|
||||
fully_sharded: bool = False,
|
||||
offset: int = 0,
|
||||
token_lora_mapping: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
|
||||
"""
|
||||
(
|
||||
token_lora_mapping_meta,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
lora_ids,
|
||||
_,
|
||||
num_active_loras,
|
||||
) = self.token_mapping_meta.meta_args(
|
||||
x.size(0), self.lora_config.specialize_active_lora
|
||||
)
|
||||
if token_lora_mapping is None:
|
||||
token_lora_mapping = token_lora_mapping_meta
|
||||
fused_moe_lora(
|
||||
y,
|
||||
x,
|
||||
lora_a_stacked,
|
||||
lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
token_lora_mapping,
|
||||
max_lora_rank,
|
||||
top_k_num,
|
||||
lora_ids,
|
||||
num_active_loras,
|
||||
adapter_enabled,
|
||||
shrink_config.get("BLOCK_SIZE_M", 64),
|
||||
shrink_config.get("BLOCK_SIZE_N", 64),
|
||||
shrink_config.get("BLOCK_SIZE_K", 32),
|
||||
shrink_config.get("GROUP_SIZE_M", 8),
|
||||
shrink_config.get("NUM_WARPS", 4),
|
||||
shrink_config.get("NUM_STAGES", 3),
|
||||
shrink_config.get("SPLIT_K", 1),
|
||||
expand_config.get("BLOCK_SIZE_M", 64),
|
||||
expand_config.get("BLOCK_SIZE_N", 64),
|
||||
expand_config.get("BLOCK_SIZE_K", 32),
|
||||
expand_config.get("GROUP_SIZE_M", 8),
|
||||
expand_config.get("NUM_WARPS", 4),
|
||||
expand_config.get("NUM_STAGES", 3),
|
||||
expand_config.get("SPLIT_K", 1),
|
||||
mul_routed_weight,
|
||||
fully_sharded,
|
||||
offset,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user