diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index b79b668f3..382999bca 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -18,6 +18,7 @@ from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.lora.ops.triton_ops import fused_moe_lora
+from vllm.platforms import current_platform
 from vllm.utils.network_utils import get_open_port
 from vllm.utils.torch_utils import set_random_seed
 
@@ -244,8 +245,9 @@ def use_torch(
     return torch.stack(outputs, dim=0)
 
 
+DEVICE_TYPE = current_platform.device_type
 DTYPES = [torch.float16, torch.bfloat16]
-DEVICES = [f"cuda:{0}"]
+DEVICES = [f"{DEVICE_TYPE}:{0}"]
 SEED = [42]
 
 
diff --git a/tests/lora/test_punica_xpu_ops.py b/tests/lora/test_punica_xpu_ops.py
new file mode 100644
index 000000000..585c97cfa
--- /dev/null
+++ b/tests/lora/test_punica_xpu_ops.py
@@ -0,0 +1,298 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.lora.utils import (
+    PunicaTensors,
+    assert_close,
+    generate_data,
+    generate_data_for_expand_nslices,
+)
+from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.platforms import current_platform
+
+
+def torch_bgmv_expand(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    add_inputs: bool = True,
+):
+    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    limit = output_tensor.shape[0]
+    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
+        limit = 1
+
+    # LoRA adapter and model may add different amounts of padding to output
+    common_len = min(outputs.shape[1], output_tensor.shape[1])
+
+    if add_inputs:
+        output_tensor[:, :common_len] += outputs[:limit, :common_len]
+    else:
+        output_tensor[:, :common_len] = outputs[:limit, :common_len]
+
+
+def torch_bgmv_shrink(
+    inputs: torch.Tensor,
+    lora_a_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+):
+    selected_loras = lora_a_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
+
+
+def torch_bgmv_expand_slice(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+):
+    selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
+    inputs = inputs.to(dtype=output_tensor.dtype)
+    if len(selected_loras.shape) == 4:
+        selected_loras = selected_loras.squeeze(dim=1)
+    outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
+
+    if add_inputs:
+        output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
+    else:
+        output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
+
+
+def check_bgmv_shrink(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    device: str,
+    scaling: float,
+):
+    """
+    Compare vllm.bgmv_shrink against a reference implementation.
+    """
+    seq_length = 1
+    data: PunicaTensors = generate_data(
+        batches,
+        hidden_size,
+        num_loras,
+        rank,
+        seq_length,
+        dtype,
+        "shrink",
+        device,
+    )
+
+    bgmv_shrink(
+        data.inputs_tensor,
+        data.lora_weights,
+        data.our_out_tensor,
+        data.token_lora_mapping,
+        scaling,
+    )
+
+    torch_bgmv_shrink(
+        data.inputs_tensor,
+        data.lora_weights,
+        data.ref_out_tensor,
+        data.token_lora_mapping,
+        scaling,
+    )
+
+    data.ref_out_tensor = data.ref_out_tensor.to(torch.float32)
+    assert_close(data.our_out_tensor, data.ref_out_tensor)
+
+
+def check_bgmv_expand(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    device: str,
+    add_inputs: bool,
+):
+    """
+    Compare vllm.bgmv_expand against a reference implementation.
+    """
+    seq_length = 1
+    data: PunicaTensors = generate_data(
+        batches,
+        hidden_size,
+        num_loras,
+        rank,
+        seq_length,
+        dtype,
+        "expand",
+        device,
+    )
+
+    bgmv_expand(
+        data.inputs_tensor,
+        data.lora_weights,
+        data.our_out_tensor,
+        data.token_lora_mapping,
+        add_inputs=add_inputs,
+    )
+    torch_bgmv_expand(
+        data.inputs_tensor,
+        data.lora_weights,
+        data.ref_out_tensor,
+        data.token_lora_mapping,
+        add_inputs=add_inputs,
+    )
+    assert_close(data.ref_out_tensor, data.our_out_tensor)
+
+
+def check_bgmv_expand_slice(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    add_inputs: bool,
+):
+    """
+    Compare vllm.bgmv_expand_slice against a reference implementation.
+    """
+    seq_length = 1
+    data: PunicaTensors = generate_data_for_expand_nslices(
+        batches,
+        hidden_size,
+        num_loras,
+        rank,
+        seq_length,
+        dtype,
+        nslices,
+        device,
+    )
+
+    slice_offset = 0
+    for index in range(nslices):
+        bgmv_expand_slice(
+            data.inputs_tensor,
+            data.lora_weights[index],
+            data.our_out_tensor,
+            data.token_lora_mapping,
+            slice_offset,
+            slice_size=hidden_size,
+            add_inputs=add_inputs,
+        )
+        torch_bgmv_expand_slice(
+            data.inputs_tensor,
+            data.lora_weights[index],
+            data.ref_out_tensor,
+            data.token_lora_mapping,
+            slice_offset,
+            slice_size=hidden_size,
+            add_inputs=add_inputs,
+        )
+
+        slice_offset += hidden_size
+    assert_close(data.ref_out_tensor, data.our_out_tensor)
+
+
+# General test params that test for variations in all dimensions
+# except hidden_size.
+test_params = {
+    "hidden_sizes": [2049],
+    "batches": [4],
+    "num_loras": [4],
+    "max_ranks": [32],
+}
+
+DTYPES = [torch.float16, torch.bfloat16]
+DEVICES = [f"xpu:{0}"]
+SEED = [0]
+
+
+@pytest.mark.parametrize("batches", test_params["batches"])
+@pytest.mark.parametrize("num_loras", test_params["num_loras"])
+@pytest.mark.parametrize("rank", test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("op_type", ["shrink", "expand"])
+@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
+def test_bgmv(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    device: str,
+    seed: int,
+    op_type: str,
+):
+    if op_type == "shrink":
+        check_bgmv_shrink(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            device=device,
+            scaling=0.5,
+        )
+    else:
+        check_bgmv_expand(
+            batches=batches,
+            num_loras=num_loras,
+            rank=rank,
+            hidden_size=hidden_size,
+            dtype=dtype,
+            device=device,
+            add_inputs=True,
+        )
+
+
+@pytest.mark.parametrize("batches", test_params["batches"])
+@pytest.mark.parametrize("num_loras", test_params["num_loras"])
+@pytest.mark.parametrize("rank", test_params["max_ranks"])
+@pytest.mark.parametrize("hidden_size", test_params["hidden_sizes"])
+@pytest.mark.parametrize("nslices", [2, 3])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
+@pytest.mark.skipif(not current_platform.is_xpu(), reason="skip for non xpu platform")
+def test_bgmv_expand_nslices(
+    batches: int,
+    num_loras: int,
+    rank: int,
+    hidden_size: int,
+    nslices: int,
+    dtype: torch.dtype,
+    device: str,
+    seed: int,
+):
+    check_bgmv_expand_slice(
+        batches=batches,
+        num_loras=num_loras,
+        rank=rank,
+        hidden_size=hidden_size,
+        nslices=nslices,
+        dtype=dtype,
+        device=device,
+        add_inputs=True,
+    )
diff --git a/vllm/lora/ops/ipex_ops/__init__.py b/vllm/lora/ops/xpu_ops/__init__.py
similarity index 66%
rename from vllm/lora/ops/ipex_ops/__init__.py
rename to vllm/lora/ops/xpu_ops/__init__.py
index f5a5e0e6f..f7f16bf23 100644
--- a/vllm/lora/ops/ipex_ops/__init__.py
+++ b/vllm/lora/ops/xpu_ops/__init__.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
 
 __all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
diff --git a/vllm/lora/ops/ipex_ops/lora_ops.py b/vllm/lora/ops/xpu_ops/lora_ops.py
similarity index 74%
rename from vllm/lora/ops/ipex_ops/lora_ops.py
rename to vllm/lora/ops/xpu_ops/lora_ops.py
index 0767f90b2..6d1751c37 100644
--- a/vllm/lora/ops/ipex_ops/lora_ops.py
+++ b/vllm/lora/ops/xpu_ops/lora_ops.py
@@ -7,11 +7,6 @@ from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-try:
-    import intel_extension_for_pytorch as ipex
-except ImportError as e:
-    raise e
-
 
 def bgmv_shrink(
     inputs: torch.Tensor,
@@ -20,8 +15,8 @@ def bgmv_shrink(
     lora_indices_tensor: torch.Tensor,
     scaling: float = 1.0,
 ) -> None:
-    ipex.llm.functional.bgmv_shrink(
-        inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling
+    torch.ops._xpu_C.bgmv_shrink(
+        output_tensor, inputs, lora_a_weights, lora_indices_tensor, scaling
     )
 
 
@@ -32,8 +27,8 @@ def bgmv_expand(
     lora_indices_tensor: torch.Tensor,
     add_inputs: bool = True,
 ) -> None:
-    ipex.llm.functional.bgmv_expand(
-        inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs
+    torch.ops._xpu_C.bgmv_expand(
+        output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs
     )
 
 
@@ -46,10 +41,12 @@ def bgmv_expand_slice(
     slice_size: int,
     add_inputs: bool = True,
 ) -> None:
-    ipex.llm.functional.bgmv_expand_slice(
+    assert slice_size == lora_b_weights.size(-2)
+    assert slice_offset + slice_size <= output_tensor.size(1)
+    torch.ops._xpu_C.bgmv_expand_slice(
+        output_tensor,
         inputs,
         lora_b_weights,
-        output_tensor,
         lora_indices_tensor,
         slice_offset,
         slice_size,
diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py
index 00c007828..f031e1bfa 100644
--- a/vllm/lora/punica_wrapper/punica_xpu.py
+++ b/vllm/lora/punica_wrapper/punica_xpu.py
@@ -11,8 +11,17 @@ from typing import final
 
 import torch
 
+from vllm import _custom_ops as ops
 from vllm.lora.layers import LoRAMapping
-from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+from vllm.triton_utils import HAS_TRITON, triton
+from vllm.utils.math_utils import round_up
+
+if HAS_TRITON:
+    from vllm.lora.ops.triton_ops import (
+        LoRAKernelMeta,
+        fused_moe_lora,
+    )
 
 from .punica_base import PunicaWrapperBase
 
@@ -37,6 +46,12 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
         torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
 
+        self.lora_config = kwargs["lora_config"]
+        self.max_loras = self.lora_config.max_loras
+        self.token_mapping_meta = LoRAKernelMeta.make(
+            self.max_loras, max_num_batched_tokens, device=device
+        )
+
     def update_metadata(
         self,
         mapping: LoRAMapping,
@@ -206,11 +221,9 @@ class PunicaWrapperXPU(PunicaWrapperBase):
 
         if buffer is None:
             r = lora_b_stacked[0].size(-1)
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
             buffer = torch.zeros(  # type: ignore
                 (len(output_slices), x.size(0), r),
-                dtype=torch.float32,
+                dtype=x.dtype,
                 device=x.device,
             )
         self.add_shrink(
@@ -267,10 +280,142 @@ class PunicaWrapperXPU(PunicaWrapperBase):
         x = x.view(-1, x.shape[-1])
         r = lora_b_stacked.size(-1)
         if buffer is None:
-            # We set the buffer to be float32 by default, refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
+            buffer = torch.zeros((x.size(0), r), dtype=x.dtype, device=x.device)
         sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
         bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
         bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
         return y.view_as(y_org)
+
+    def moe_lora_align_block_size(
+        self,
+        topk_ids: torch.Tensor,
+        num_tokens: int,
+        block_size: int,
+        num_experts: int,
+        max_loras: int,
+        adapter_enabled: torch.Tensor,
+        expert_map: torch.Tensor | None = None,
+        pad_sorted_ids: bool = False,
+        naive_block_assignment: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Aligns tokens and experts into block-sized chunks for LoRA-based
+        mixture-of-experts (MoE) execution.
+ """ + (token_lora_mapping, _, _, _, lora_ids, _, _) = ( + self.token_mapping_meta.meta_args( + num_tokens, self.lora_config.specialize_active_lora + ) + ) + if naive_block_assignment: + expert_ids = topk_ids.reshape(-1) + sorted_ids = None + num_tokens_post_pad = None + else: + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) + + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] + + return None, sorted_ids, expert_ids, num_tokens_post_pad + + def add_lora_fused_moe( + self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: tuple[torch.Tensor, ...], + lora_b_stacked: tuple[torch.Tensor, ...], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor | None, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor | None, + max_lora_rank: int, + top_k_num: int, + shrink_config, + expand_config, + adapter_enabled: torch.Tensor, + mul_routed_weight=False, + fully_sharded: bool = False, + offset: int = 0, + token_lora_mapping: torch.Tensor | None = None, + ): + """ + Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. + """ + ( + token_lora_mapping_meta, + _, + _, + _, + lora_ids, + _, + num_active_loras, + ) = self.token_mapping_meta.meta_args( + x.size(0), self.lora_config.specialize_active_lora + ) + if token_lora_mapping is None: + token_lora_mapping = token_lora_mapping_meta + fused_moe_lora( + y, + x, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + shrink_config.get("BLOCK_SIZE_M", 64), + shrink_config.get("BLOCK_SIZE_N", 64), + shrink_config.get("BLOCK_SIZE_K", 32), + shrink_config.get("GROUP_SIZE_M", 8), + shrink_config.get("NUM_WARPS", 4), + shrink_config.get("NUM_STAGES", 3), + shrink_config.get("SPLIT_K", 1), + expand_config.get("BLOCK_SIZE_M", 64), + expand_config.get("BLOCK_SIZE_N", 64), + expand_config.get("BLOCK_SIZE_K", 32), + expand_config.get("GROUP_SIZE_M", 8), + expand_config.get("NUM_WARPS", 4), + expand_config.get("NUM_STAGES", 3), + expand_config.get("SPLIT_K", 1), + mul_routed_weight, + fully_sharded, + offset, + )