[XPU][8/N] Fix kernel bugs in XPU LoRA and MOE LORA (#34115)

Signed-off-by: chzhang <chaojun.zhang@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Author: Chaojun Zhang
Date: 2026-02-26 15:46:44 +08:00
Committed by: GitHub
Parent: a07c4c5939
Commit: 9f9a675b23
5 changed files with 462 additions and 20 deletions

View File

@@ -18,6 +18,7 @@ from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.lora.ops.triton_ops import fused_moe_lora
+from vllm.platforms import current_platform
from vllm.utils.network_utils import get_open_port
from vllm.utils.torch_utils import set_random_seed
@@ -244,8 +245,9 @@ def use_torch(
    return torch.stack(outputs, dim=0)
+DEVICE_TYPE = current_platform.device_type
DTYPES = [torch.float16, torch.bfloat16]
-DEVICES = [f"cuda:{0}"]
+DEVICES = [f"{DEVICE_TYPE}:{0}"]
SEED = [42]
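
The hunk above stops hard-coding "cuda:0" and instead derives the test device from the running platform, so the same parametrization works on CUDA and XPU hosts. A minimal sketch of how that resolves at runtime (assuming a vLLM install where vllm.platforms is importable):

from vllm.platforms import current_platform

# device_type resolves to "cuda", "xpu", "cpu", ... depending on the backend,
# so the DEVICES list follows whatever hardware the test actually runs on.
DEVICE_TYPE = current_platform.device_type
DEVICES = [f"{DEVICE_TYPE}:0"]
print(DEVICES)  # e.g. ["xpu:0"] on an Intel GPU host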

View File

@@ -0,0 +1,298 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.lora.utils import (
    PunicaTensors,
    assert_close,
    generate_data,
    generate_data_for_expand_nslices,
)
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.platforms import current_platform
def torch_bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    limit = output_tensor.shape[0]
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1
    # LoRA adapter and model may add different amounts of padding to output
    common_len = min(outputs.shape[1], output_tensor.shape[1])
    if add_inputs:
        output_tensor[:, :common_len] += outputs[:limit, :common_len]
    else:
        output_tensor[:, :common_len] = outputs[:limit, :common_len]

def torch_bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    inputs = inputs.to(dtype=output_tensor.dtype)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]

def torch_bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
    inputs = inputs.to(dtype=output_tensor.dtype)
    if len(selected_loras.shape) == 4:
        selected_loras = selected_loras.squeeze(dim=1)
    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
    if add_inputs:
        output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
    else:
        output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]

def check_bgmv_shrink(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    scaling: float,
):
    """
    Compare vllm.bgmv_shrink against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        "shrink",
        device,
    )
    bgmv_shrink(
        data.inputs_tensor,
        data.lora_weights,
        data.our_out_tensor,
        data.token_lora_mapping,
        scaling,
    )
    torch_bgmv_shrink(
        data.inputs_tensor,
        data.lora_weights,
        data.ref_out_tensor,
        data.token_lora_mapping,
        scaling,
    )
    data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
    assert_close(data.our_out_tensor, data.ref_out_tensor)

def check_bgmv_expand(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        "expand",
        device,
    )
    bgmv_expand(
        data.inputs_tensor,
        data.lora_weights,
        data.our_out_tensor,
        data.token_lora_mapping,
        add_inputs=add_inputs,
    )
    torch_bgmv_expand(
        data.inputs_tensor,
        data.lora_weights,
        data.ref_out_tensor,
        data.token_lora_mapping,
        add_inputs=add_inputs,
    )
    assert_close(data.ref_out_tensor, data.our_out_tensor)

def check_bgmv_expand_slice(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    add_inputs: bool,
):
    """
    Compare vllm.bgmv_expand_slice against a reference implementation.
    """
    seq_length = 1
    data: PunicaTensors = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )
    slice_offset = 0
    for index in range(nslices):
        bgmv_expand_slice(
            data.inputs_tensor,
            data.lora_weights[index],
            data.our_out_tensor,
            data.token_lora_mapping,
            slice_offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )
        torch_bgmv_expand_slice(
            data.inputs_tensor,
            data.lora_weights[index],
            data.ref_out_tensor,
            data.token_lora_mapping,
            slice_offset,
            slice_size=hidden_size,
            add_inputs=add_inputs,
        )
        slice_offset += hidden_size
    assert_close(data.ref_out_tensor, data.our_out_tensor)

# General test params that test for variations in all dimensions
# except hidden_size.
test_params = {
    "hidden_sizes": [2049],
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = [f"xpu:{0}"]
SEED = [0]
@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
def test_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    if op_type == "shrink":
        check_bgmv_shrink(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            dtype=dtype,
            device=device,
            scaling=0.5,
        )
    else:
        check_bgmv_expand(
            batches=batches,
            num_loras=num_loras,
            rank=rank,
            hidden_size=hidden_size,
            dtype=dtype,
            device=device,
            add_inputs=True,
        )

@pytest.mark.parametrize("batches", test_params["batches"])
@pytest.mark.parametrize("num_loras", test_params["num_loras"])
@pytest.mark.parametrize("rank", test_params["max_ranks"])
@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
def test_bgmv_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
):
    check_bgmv_expand_slice(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        add_inputs=True,
    )
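
The torch_bgmv_* helpers in this new test file serve as eager-mode references: each token gathers its LoRA matrix by index, then a batched matrix-vector product is applied via einsum. A small sketch showing that the einsum form equals an explicit per-token loop (toy shapes, not part of the test file):

import torch

num_tokens, in_dim, out_dim, num_loras = 4, 8, 16, 2
inputs = torch.randn(num_tokens, in_dim)
lora_b = torch.randn(num_loras, out_dim, in_dim)   # one (out, in) weight per LoRA
indices = torch.tensor([0, 1, 1, 0])               # LoRA id assigned to each token

selected = lora_b[indices]                          # gather, as in the reference helpers
out_einsum = torch.einsum("bi, boi -> bo", inputs, selected)
out_loop = torch.stack([lora_b[indices[b]] @ inputs[b] for b in range(num_tokens)])
assert torch.allclose(out_einsum, out_loop, atol=1e-5)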

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
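
With the package renamed, XPU callers import the bgmv kernels from vllm.lora.ops.xpu_ops rather than the old ipex_ops path. A minimal usage sketch of the shrink op (the tensor shapes, dtypes, and stacked-weight layout are illustrative assumptions, and an XPU device must be available):

import torch
from vllm.lora.ops.xpu_ops import bgmv_shrink

x = torch.randn(8, 4096, dtype=torch.float16, device="xpu")              # token activations
lora_a = torch.randn(4, 1, 16, 4096, dtype=torch.float16, device="xpu")  # stacked LoRA-A weights
out = torch.zeros(8, 16, dtype=torch.float32, device="xpu")              # rank-sized output buffer
indices = torch.zeros(8, dtype=torch.long, device="xpu")                 # LoRA id per token

bgmv_shrink(x, lora_a, out, indices, scaling=0.5)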

View File

@@ -7,11 +7,6 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
-try:
-    import intel_extension_for_pytorch as ipex
-except ImportError as e:
-    raise e

def bgmv_shrink(
    inputs: torch.Tensor,
@@ -20,8 +15,8 @@ def bgmv_shrink(
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
-    ipex.llm.functional.bgmv_shrink(
-        inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling
+    torch.ops._xpu_C.bgmv_shrink(
+        output_tensor, inputs, lora_a_weights, lora_indices_tensor, scaling
    )
@@ -32,8 +27,8 @@ def bgmv_expand(
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
-    ipex.llm.functional.bgmv_expand(
-        inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs
+    torch.ops._xpu_C.bgmv_expand(
+        output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs
    )
@@ -46,10 +41,12 @@ def bgmv_expand_slice(
    slice_size: int,
    add_inputs: bool = True,
) -> None:
-    ipex.llm.functional.bgmv_expand_slice(
+    assert slice_size == lora_b_weights.size(-2)
+    assert slice_offset + slice_size <= output_tensor.size(1)
+    torch.ops._xpu_C.bgmv_expand_slice(
+        output_tensor,
        inputs,
        lora_b_weights,
-        output_tensor,
        lora_indices_tensor,
        slice_offset,
        slice_size,
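
The two asserts added to bgmv_expand_slice make the slice bounds explicit: the slice width must match the LoRA-B output dimension, and the slice must fit inside the output tensor. A toy illustration of exactly what they check (plain tensors, no kernel launch):

import torch

output_tensor = torch.zeros(4, 12)        # hidden size 12, e.g. three 4-wide slices
lora_b_weights = torch.randn(2, 1, 4, 8)  # assumed (..., slice_size, rank) layout: size(-2) == 4
slice_offset, slice_size = 8, 4

assert slice_size == lora_b_weights.size(-2)                # slice matches the weight out-dim
assert slice_offset + slice_size <= output_tensor.size(1)   # slice stays inside the output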

View File

@@ -11,8 +11,17 @@ from typing import final
import torch
from vllm import _custom_ops as ops
from vllm.lora.layers import LoRAMapping
-from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up

if HAS_TRITON:
    from vllm.lora.ops.triton_ops import (
        LoRAKernelMeta,
        fused_moe_lora,
    )
from .punica_base import PunicaWrapperBase
@@ -37,6 +46,12 @@ class PunicaWrapperXPU(PunicaWrapperBase):
        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
        self.lora_config = kwargs["lora_config"]
        self.max_loras = self.lora_config.max_loras
        self.token_mapping_meta = LoRAKernelMeta.make(
            self.max_loras, max_num_batched_tokens, device=device
        )

    def update_metadata(
        self,
        mapping: LoRAMapping,
@@ -206,11 +221,9 @@ class PunicaWrapperXPU(PunicaWrapperBase):
        if buffer is None:
            r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
-                dtype=torch.float32,
+                dtype=x.dtype,
                device=x.device,
            )
        self.add_shrink(
@@ -267,10 +280,142 @@ class PunicaWrapperXPU(PunicaWrapperBase):
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+            buffer = torch.zeros((x.size(0), r), dtype=x.dtype, device=x.device)
        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
        bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
        bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
        return y.view_as(y_org)
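
The method above follows the standard two-step LoRA application: bgmv_shrink projects the activations down to the LoRA rank (scaled), and bgmv_expand projects back up and accumulates into the output; with this change the intermediate buffer is allocated in x.dtype rather than float32. A dense-math sketch of what the two calls compute, ignoring per-token LoRA indexing (toy shapes):

import torch

hidden, rank, vocab, scale = 4096, 16, 32000, 0.5
x = torch.randn(8, hidden, dtype=torch.float16)
lora_a = torch.randn(rank, hidden, dtype=torch.float16)   # A: (rank, hidden)
lora_b = torch.randn(vocab, rank, dtype=torch.float16)    # B: (vocab, rank)
y = torch.zeros(8, vocab, dtype=torch.float16)

buffer = scale * (x @ lora_a.T)   # shrink: down-project to the LoRA rank
y += buffer @ lora_b.T            # expand with add_inputs=True: accumulate into y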

    def moe_lora_align_block_size(
        self,
        topk_ids: torch.Tensor,
        num_tokens: int,
        block_size: int,
        num_experts: int,
        max_loras: int,
        adapter_enabled: torch.Tensor,
        expert_map: torch.Tensor | None = None,
        pad_sorted_ids: bool = False,
        naive_block_assignment: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Aligns tokens and experts into block-sized chunks for LoRA-based
        mixture-of-experts (MoE) execution.
        """
        (token_lora_mapping, _, _, _, lora_ids, _, _) = (
            self.token_mapping_meta.meta_args(
                num_tokens, self.lora_config.specialize_active_lora
            )
        )
        if naive_block_assignment:
            expert_ids = topk_ids.reshape(-1)
            sorted_ids = None
            num_tokens_post_pad = None
        else:
            max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
            if pad_sorted_ids:
                max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
            sorted_ids = torch.empty(
                (max_loras * max_num_tokens_padded,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
            # Expert ids must be set default to -1 to prevent a blank block
            expert_ids = torch.empty(
                (max_loras * max_num_m_blocks,),
                dtype=torch.int32,
                device=topk_ids.device,
            )
            num_tokens_post_pad = torch.empty(
                (max_loras), dtype=torch.int32, device=topk_ids.device
            )
            ops.moe_lora_align_block_size(
                topk_ids,
                token_lora_mapping,
                num_experts,
                block_size,
                max_loras,
                max_num_tokens_padded,
                max_num_m_blocks,
                sorted_ids,
                expert_ids,
                num_tokens_post_pad,
                adapter_enabled,
                lora_ids,
            )
        if expert_map is not None:
            expert_ids = expert_map[expert_ids]
        return None, sorted_ids, expert_ids, num_tokens_post_pad
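
The non-naive path sizes its scratch buffers from a worst-case bound: each expert can leave up to block_size - 1 slots unfilled, and the padded token count is then chunked into block_size-sized blocks, replicated per LoRA adapter. A worked example of the sizing arithmetic used above (toy values):

num_tokens, top_k, num_experts, block_size, max_loras = 256, 2, 8, 16, 4

topk_ids_numel = num_tokens * top_k                                       # 512 routed tokens
max_num_tokens_padded = topk_ids_numel + num_experts * (block_size - 1)   # 512 + 8 * 15 = 632
max_num_m_blocks = -(-max_num_tokens_padded // block_size)                # ceil division -> 40

sorted_ids_len = max_loras * max_num_tokens_padded                        # 4 * 632 = 2528
expert_ids_len = max_loras * max_num_m_blocks                             # 4 * 40 = 160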

    def add_lora_fused_moe(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        lora_b_stacked: tuple[torch.Tensor, ...],
        topk_weights: torch.Tensor,
        sorted_token_ids: torch.Tensor | None,
        expert_ids: torch.Tensor,
        num_tokens_post_padded: torch.Tensor | None,
        max_lora_rank: int,
        top_k_num: int,
        shrink_config,
        expand_config,
        adapter_enabled: torch.Tensor,
        mul_routed_weight=False,
        fully_sharded: bool = False,
        offset: int = 0,
        token_lora_mapping: torch.Tensor | None = None,
    ):
        """
        Performs a fused forward computation for the LoRA adapters of a
        Mixture-of-Experts (MoE) layer.
        """
        (
            token_lora_mapping_meta,
            _,
            _,
            _,
            lora_ids,
            _,
            num_active_loras,
        ) = self.token_mapping_meta.meta_args(
            x.size(0), self.lora_config.specialize_active_lora
        )
        if token_lora_mapping is None:
            token_lora_mapping = token_lora_mapping_meta
        fused_moe_lora(
            y,
            x,
            lora_a_stacked,
            lora_b_stacked,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            token_lora_mapping,
            max_lora_rank,
            top_k_num,
            lora_ids,
            num_active_loras,
            adapter_enabled,
            shrink_config.get("BLOCK_SIZE_M", 64),
            shrink_config.get("BLOCK_SIZE_N", 64),
            shrink_config.get("BLOCK_SIZE_K", 32),
            shrink_config.get("GROUP_SIZE_M", 8),
            shrink_config.get("NUM_WARPS", 4),
            shrink_config.get("NUM_STAGES", 3),
            shrink_config.get("SPLIT_K", 1),
            expand_config.get("BLOCK_SIZE_M", 64),
            expand_config.get("BLOCK_SIZE_N", 64),
            expand_config.get("BLOCK_SIZE_K", 32),
            expand_config.get("GROUP_SIZE_M", 8),
            expand_config.get("NUM_WARPS", 4),
            expand_config.get("NUM_STAGES", 3),
            expand_config.get("SPLIT_K", 1),
            mul_routed_weight,
            fully_sharded,
            offset,
        )
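
add_lora_fused_moe forwards two tuning dictionaries into the fused Triton kernel as positional block-size arguments, falling back to the defaults shown in the .get() calls when a key is missing. A sketch of what caller-supplied configs might look like (the key names are the ones read above; the values are illustrative, not tuned):

shrink_config = {
    "BLOCK_SIZE_M": 32,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 64,
    "GROUP_SIZE_M": 8,
    "NUM_WARPS": 4,
    "NUM_STAGES": 2,
    "SPLIT_K": 4,
}
expand_config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}  # unspecified keys use the defaults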