diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py
index 6715c9b54..8ca3cf78f 100644
--- a/benchmarks/kernels/benchmark_lora.py
+++ b/benchmarks/kernels/benchmark_lora.py
@@ -842,6 +842,7 @@ class BenchmarkTensors:
             "sorted_token_ids": sorted_token_ids,
             "expert_ids": expert_ids,
             "num_tokens_post_padded": num_tokens_post_padded,
+            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
             "top_k_num": ctx.top_k_num,
             "device": self.input.device,
             "N": lora_rank,
@@ -915,6 +916,7 @@ class BenchmarkTensors:
             "sorted_token_ids": sorted_token_ids,
             "expert_ids": expert_ids,
             "num_tokens_post_padded": num_tokens_post_padded,
+            "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping,
             "top_k_num": ctx.top_k_num,
             "device": self.input.device,
             "N": lora_rank,
diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index a4d314be0..c97421a3f 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -190,6 +190,7 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        token_lora_mapping,
         max_lora_rank,
         top_k_num,
         lora_ids,
@@ -333,6 +334,189 @@ def test_fused_moe_lora_kernel(
     torch.testing.assert_close(output, output2, atol=1e-1, rtol=1e-1)


+def use_fused_moe_lora_kernel_naive(
+    topk_ids,
+    topk_weights,
+    token_lora_mapping,
+    max_lora_rank,
+    top_k_num,
+    lora_a_stacked,
+    lora_b_stacked,
+    hidden_states,
+    output,
+    max_loras,
+    block_size,
+    fully_sharded=False,
+    offset=0,
+):
+    """
+    Test helper for the naive_block_assignment path.
+    Skips moe_lora_align_block_size and uses flattened topk_ids as expert_ids.
+    """
+    config = {
+        "BLOCK_SIZE_M": block_size,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "NUM_WARPS": 4,
+        "NUM_STAGES": 3,
+        "SPLIT_K": 1,
+    }
+
+    mul_routed_weight = False
+
+    # In naive mode:
+    # - expert_ids = topk_ids.view(-1), shape: (num_tokens * top_k,)
+    # - sorted_token_ids = None
+    # - num_tokens_post_padded = None
+    expert_ids = topk_ids.reshape(-1)
+    sorted_token_ids = None
+    num_tokens_post_padded = None
+
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
+
+    fused_moe_lora(
+        output,
+        hidden_states,
+        lora_a_stacked,
+        lora_b_stacked,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        token_lora_mapping,
+        max_lora_rank,
+        top_k_num,
+        lora_ids,
+        adapter_enabled,
+        config["BLOCK_SIZE_M"],
+        config["BLOCK_SIZE_N"],
+        config["BLOCK_SIZE_K"],
+        config["GROUP_SIZE_M"],
+        config["NUM_WARPS"],
+        config["NUM_STAGES"],
+        config["SPLIT_K"],
+        config["BLOCK_SIZE_M"],
+        config["BLOCK_SIZE_N"],
+        config["BLOCK_SIZE_K"],
+        config["GROUP_SIZE_M"],
+        config["NUM_WARPS"],
+        config["NUM_STAGES"],
+        config["SPLIT_K"],
+        mul_routed_weight=mul_routed_weight,
+        fully_sharded=fully_sharded,
+        offset=offset,
+    )
+
+
+@pytest.mark.parametrize("num_tokens", [1, 2, 4, 8])
+@pytest.mark.parametrize("top_k_num", [1, 2])
+@pytest.mark.parametrize("num_experts", [64, 128])
+@pytest.mark.parametrize("max_loras", [4, 8])
+@pytest.mark.parametrize("N", [1408])
+@pytest.mark.parametrize("K", [2048])
+@pytest.mark.parametrize("max_lora_rank", [16, 32])
+@pytest.mark.parametrize("block_size", [16])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("seed", SEED)
+def test_fused_moe_lora_kernel_naive_block_assignment(
+    num_tokens,
+    top_k_num,
+    num_experts,
+    max_loras,
+    N,
+    K,
+    max_lora_rank,
+    block_size,
+    dtype,
+    device,
+    seed,
+):
+    """
+    Test the naive_block_assignment path of the fused_moe_lora kernel.
+    This path is triggered when batch_size * top_k is much smaller than
+    num_experts * max_loras, and skips the moe_lora_align_block_size kernel.
+    """
+    torch.set_default_device(device)
+    set_random_seed(seed)
+
+    # Verify this configuration would trigger naive_block_assignment
+    # (num_tokens * top_k * SPARSITY_FACTOR <= num_experts * max_loras)
+    SPARSITY_FACTOR = 8
+    assert num_tokens * top_k_num * SPARSITY_FACTOR <= num_experts * max_loras, (
+        f"Test configuration doesn't meet naive_block_assignment condition: "
+        f"{num_tokens} * {top_k_num} * {SPARSITY_FACTOR} > {num_experts} * {max_loras}"
+    )
+
+    # the number of randomly generated sequences
+    num_sequences = min(num_tokens, 4)
+    # generate data
+    topk_ids, topk_weights, token_lora_mapping = sample_data(
+        num_tokens, num_sequences, max_loras, num_experts, top_k_num
+    )
+
+    # init lora weights
+    lora_a_stacked = [
+        torch.rand(
+            (
+                max_loras,
+                num_experts,
+                max_lora_rank,
+                K,
+            ),
+            dtype=dtype,
+        )
+    ]
+    lora_b_stacked = [
+        torch.rand(
+            (
+                max_loras,
+                num_experts,
+                N,
+                max_lora_rank,
+            ),
+            dtype=dtype,
+        )
+    ]
+    hidden_states = torch.rand(
+        (
+            num_tokens,
+            K,
+        ),
+        dtype=dtype,
+    )
+
+    # fused_moe_lora_kernel output (naive path)
+    output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype)
+    use_fused_moe_lora_kernel_naive(
+        topk_ids,
+        topk_weights,
+        token_lora_mapping,
+        max_lora_rank,
+        top_k_num,
+        lora_a_stacked,
+        lora_b_stacked,
+        hidden_states,
+        output,
+        max_loras,
+        block_size,
+    )
+
+    # pytorch reference output
+    output_ref = use_torch(
+        hidden_states,
+        token_lora_mapping,
+        topk_ids,
+        lora_a_stacked,
+        lora_b_stacked,
+        top_k_num,
+    )
+
+    torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1)
+
+
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("num_tokens", [100])
 @pytest.mark.parametrize("top_k_num", [6])
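For orientation before the kernel changes below: the sketch that follows is a standalone illustration, not part of the patch, and its sizes and the pure-Python loop are made up. It mimics the work assignment the new test exercises: in the naive path each kernel block handles exactly one (token, expert) pair, expert_ids is simply the flattened routing table, and the LoRA id is looked up per token via token_lora_mapping rather than per grid slice.

import torch

num_tokens, top_k, max_loras, num_experts = 4, 2, 4, 64          # illustrative sizes
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))     # router output
token_lora_mapping = torch.randint(0, max_loras, (num_tokens,))   # LoRA id per token

expert_ids = topk_ids.reshape(-1)   # (num_tokens * top_k,), no sorting or padding

for pid_m in range(num_tokens * top_k):   # one block per (token, expert) pair
    token_idx = pid_m // top_k
    lora_id = int(token_lora_mapping[token_idx])
    expert_id = int(expert_ids[pid_m])
    # block pid_m would apply lora_[ab]_stacked[lora_id, expert_id] to token token_idx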
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index be1fd7cdb..c2b35fbb1 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -190,8 +190,18 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 config_dtype=config_dtype,
             )

+            # SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
+            # activates only a small fraction of total experts * loras.
+            SPARSITY_FACTOR = 8
+            naive_block_assignment = (
+                expert_map is None
+                and num_tokens * top_k * SPARSITY_FACTOR
+                <= self.base_layer.local_num_experts * self.max_loras
+            )
+
             # get the block size of m from customized config or default config
             (
+                token_lora_mapping,
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
@@ -203,6 +213,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 self.max_loras,
                 self.adapter_enabled,
                 expert_map,
+                naive_block_assignment,
             )

             moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
@@ -210,9 +221,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             moe_state_dict["num_tokens_post_padded_lora"] = (
                 num_tokens_post_padded_lora
             )
+            moe_state_dict["token_lora_mapping"] = token_lora_mapping

-            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+            if sorted_token_ids_lora is not None:
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(
+                    self.max_loras, -1
+                )

             #
             self.punica_wrapper.add_lora_fused_moe(
@@ -230,6 +245,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 expand_config,  ## pass the expand config
                 self.adapter_enabled,
                 fully_sharded=self.fully_sharded,
+                token_lora_mapping=token_lora_mapping,
             )
             result = func(*args, **kwargs)
@@ -270,9 +286,13 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
             num_tokens_post_padded_lora = moe_state_dict[
                 "num_tokens_post_padded_lora"
             ]
+            token_lora_mapping = moe_state_dict.get("token_lora_mapping")

-            expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
-            sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
+            if sorted_token_ids_lora is not None:
+                expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
+                sorted_token_ids_lora = sorted_token_ids_lora.view(
+                    self.max_loras, -1
+                )

             intermediate_cache2 = moe_state_dict["intermediate_cache2"]
             intermediate_cache3 = args[0]
@@ -295,6 +315,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
                 True,
                 fully_sharded=self.fully_sharded,
                 offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
+                token_lora_mapping=token_lora_mapping,
             )
             result = func(*args, **kwargs)
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 58549ee9f..9e76d742b 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -12,6 +12,64 @@ from vllm.utils.torch_utils import direct_register_custom_op

 from .utils import supports_pdl

+
+@triton.jit
+def _get_lora_id(
+    lora_ids,
+    token_lora_mapping_ptr,
+    lora_idx,
+    pid_m,
+    top_k_num,
+    naive_block_assignment: tl.constexpr,
+):
+    """Returns lora_id"""
+    if naive_block_assignment:
+        token_idx = pid_m // top_k_num
+        return tl.load(token_lora_mapping_ptr + token_idx)
+    else:
+        return tl.load(lora_ids + lora_idx)
+
+
+@triton.jit
+def _get_expert_id(
+    expert_ids_ptr,
+    lora_id,
+    pid_m,
+    stride_el,
+    max_loras,
+    naive_block_assignment: tl.constexpr,
+):
+    """Returns expert_id"""
+    if naive_block_assignment:
+        return tl.load(expert_ids_ptr + pid_m)
+    else:
+        ind = lora_id * stride_el + pid_m
+        return tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
+
+
+@triton.jit
+def _get_token_offs(
+    sorted_token_ids_ptr,
+    lora_id,
+    pid_m,
+    offs,
+    stride_tl,
+    max_loras,
+    num_valid_tokens,
+    naive_block_assignment: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    """Returns token offsets"""
+    if naive_block_assignment:
+        return tl.where(offs == 0, pid_m, num_valid_tokens)
+    else:
+        offs_token_id = pid_m * BLOCK_SIZE_M + offs
+        token_ind = stride_tl * lora_id + offs_token_id
+        return tl.load(
+            sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
+        )
+
+
 _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -36,6 +94,25 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     return _LORA_PTR_DICT.get(key)


+def _adjust_kernel_inputs(
+    max_loras: int,
+    sorted_token_ids: torch.Tensor | None,
+    expert_ids: torch.Tensor,
+):
+    """
+    Helper function to adjust kernel inputs when sorted_token_ids is None.
+    """
+    if sorted_token_ids is None:
+        stride_tl = 0
+        stride_el = 0
+        grid_lora_dim = 1
+    else:
+        stride_tl = sorted_token_ids.stride(0)
+        stride_el = expert_ids.stride(0)
+        grid_lora_dim = max_loras + 1
+    return grid_lora_dim, stride_tl, stride_el
+
+
 @triton.jit(
     do_not_specialize=[
         "num_valid_tokens",
@@ -54,12 +131,14 @@ def _fused_moe_lora_kernel(
     sorted_token_ids_ptr,
     expert_ids_ptr,
     num_tokens_post_padded_ptr,
+    token_lora_mapping_ptr,
     # Matrix dimensions
     N,
     K,
     EM,
     num_valid_tokens,
     num_experts,
+    top_k_num,
     lora_ids,
     adapter_enabled,
     max_loras,  # <<< PR2: rename, used for masks when grid axis-2 != max_loras
@@ -82,7 +161,11 @@ def _fused_moe_lora_kernel(
     # Meta-parameters
     num_slice_a: tl.constexpr,
     num_slice_c: tl.constexpr,
-    top_k: tl.constexpr,
+    # top_k_num or 1, depending on whether the input tokens
+    # have already been expanded by top_k
+    token_mapping_factor: tl.constexpr,
+    # whether to use naive block assignment
+    naive_block_assignment: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
     ADD_INPUTS: tl.constexpr,
     USE_B_L2_CACHE: tl.constexpr,  # new, enable .ca load for B
@@ -97,26 +180,10 @@ def _fused_moe_lora_kernel(
 ):
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
-    lora_idx = tl.program_id(axis=2)
-    lora_id = tl.load(lora_ids + lora_idx)
-
-    if lora_id == -1:
-        # Early exit for the no-lora case.
-        return
-    moe_enabled = tl.load(adapter_enabled + lora_id)
-    if moe_enabled == 0:
-        # Early exit for the no moe lora case.
-        return
-    # The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
-    # This guard ensures we don't access sorted_token_ids / expert_ids /
-    # num_tokens_post_padded beyond their allocated bounds if an invalid
-    # lora_id somehow appears. Although the caller should pass correct
-    # max_loras, defensive programming prevents accidental out-of-bounds.
-    if lora_id >= max_loras:
-        return

     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
     # calculate pid_m,pid_n
+    lora_idx = tl.program_id(axis=2)
     pid_sk = pid % SPLIT_K
     pid_m_n = pid // SPLIT_K
     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
@@ -129,14 +196,55 @@ def _fused_moe_lora_kernel(
     pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
     pid_n = (pid_m_n % num_pid_in_group) // group_size_m

-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
-    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+    offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+
+    # Get lora_id
+    lora_id = _get_lora_id(
+        lora_ids,
+        token_lora_mapping_ptr,
+        lora_idx,
+        pid_m,
+        top_k_num,
+        naive_block_assignment,
+    )
+    if lora_id == -1:
         return
-    # get the expert_id to process curr shard
-    ind = lora_id * stride_el + pid_m
-    expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if moe_enabled == 0:
+        return
+    if lora_id >= max_loras:
+        return
+
+    # Non-naive only: check num_tokens_post_padded
+    if not naive_block_assignment:
+        num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
+        if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+            return
+
+    # Get expert_id
+    expert_id = _get_expert_id(
+        expert_ids_ptr,
+        lora_id,
+        pid_m,
+        stride_el,
+        max_loras,
+        naive_block_assignment,
+    )
     if expert_id == -1:
         return
+
+    # Get token offsets
+    offs_token = _get_token_offs(
+        sorted_token_ids_ptr,
+        lora_id,
+        pid_m,
+        offs,
+        stride_tl,
+        max_loras,
+        num_valid_tokens,
+        naive_block_assignment,
+        BLOCK_SIZE_M,
+    )

     # get a_ptr,b_ptr,c_ptr
     cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
@@ -145,19 +253,12 @@ def _fused_moe_lora_kernel(
     # remove modulo wrap-around
     offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
-
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
-    token_ind = stride_tl * lora_id + offs_token_id
-    offs_token = tl.load(
-        sorted_token_ids_ptr + token_ind,
-        mask=token_ind < max_loras * stride_tl,
-        other=num_valid_tokens,
-    )
     token_mask = offs_token < num_valid_tokens

     # get a_ptrs,b_ptrs
     a_ptrs = cur_a_ptr + (
-        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+        offs_token[:, None] // token_mapping_factor * stride_am
+        + offs_k[None, :] * stride_ak
     )

     b_ptrs = (
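As a reading aid for the three @triton.jit helpers introduced above, here is a plain-Python paraphrase of their naive_block_assignment branch for a single program instance. It is only a sketch: the function name naive_indices and the list-based block are invented, and ordinary tensors stand in for the raw pointers the kernel actually receives.

def naive_indices(pid_m, top_k_num, token_lora_mapping, expert_ids_flat,
                  num_valid_tokens, block_size_m):
    lora_id = int(token_lora_mapping[pid_m // top_k_num])  # per-token LoRA id
    expert_id = int(expert_ids_flat[pid_m])                # flattened topk_ids
    # Only lane 0 of the block carries a real token; the rest are padded with
    # num_valid_tokens so that token_mask = offs_token < num_valid_tokens drops them.
    offs_token = [pid_m] + [num_valid_tokens] * (block_size_m - 1)
    return lora_id, expert_id, offs_token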
@@ -230,9 +331,10 @@ def _fused_moe_lora_shrink(
         torch.Tensor
     ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
     topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
-    sorted_token_ids: torch.Tensor,  # (max_loras, _)
-    expert_ids: torch.Tensor,  # (max_loras, _ ,)
-    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
+    sorted_token_ids: torch.Tensor | None,  # (max_loras, _)
+    expert_ids: torch.Tensor,  # (max_loras, _ ,) or (num_tokens * top_k,)
+    num_tokens_post_padded: torch.Tensor | None,  # (max_loras, )
+    token_lora_mapping: torch.Tensor,
     top_k_num: int,
     lora_ids: torch.Tensor,
     adapter_enabled: torch.Tensor,
@@ -270,13 +372,15 @@ def _fused_moe_lora_shrink(

     b_ptr = _get_ptr(lora_a_stacked, device)

+    grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
+        w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids
+    )
     grid = lambda META: (
         split_k
         * triton.cdiv(EM, META["BLOCK_SIZE_M"])
         * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_a_stacked),
-        ## max_loras + 1 to handle the no-lora case (lora_id == -1)
-        lora_a_stacked[0].shape[0] + 1,
+        grid_lora_dim,
     )

     _fused_moe_lora_kernel[grid](
         qcurr_hidden_states,
@@ -286,11 +390,13 @@ def _fused_moe_lora_shrink(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        token_lora_mapping,
         N,
         K,
         EM,
         num_tokens,
         num_experts,
+        top_k_num,
         lora_ids,
         adapter_enabled,
         lora_a_stacked[0].shape[0],
@@ -302,13 +408,14 @@ def _fused_moe_lora_shrink(
         w1_lora_a_stacked.stride(2),
         a_intermediate_cache1.stride(2),
         a_intermediate_cache1.stride(3),
-        sorted_token_ids.stride(0),
-        expert_ids.stride(0),
+        stride_tl,
+        stride_el,
         slice_a_size=qcurr_hidden_states.numel(),
         slice_c_size=a_intermediate_cache1.numel() // num_slices,
         num_slice_a=1,
         num_slice_c=num_slices,
-        top_k=1 if mul_routed_weight else top_k_num,
+        token_mapping_factor=1 if mul_routed_weight else top_k_num,
+        naive_block_assignment=sorted_token_ids is None,
         MUL_ROUTED_WEIGHT=False,
         ADD_INPUTS=False,
         USE_B_L2_CACHE=True,  # new
@@ -325,9 +432,10 @@ def _fused_moe_lora_expand(
         torch.Tensor
     ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
     topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
-    sorted_token_ids: torch.Tensor,  # (max_loras, _)
-    expert_ids: torch.Tensor,  # (max_loras, _ ,)
-    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
+    sorted_token_ids: torch.Tensor | None,  # (max_loras, _)
+    expert_ids: torch.Tensor,  # (max_loras, _ ,) or (num_tokens * top_k,)
+    num_tokens_post_padded: torch.Tensor | None,  # (max_loras, )
+    token_lora_mapping: torch.Tensor,
     top_k_num: int,
     lora_ids: torch.Tensor,
     adapter_enabled: torch.Tensor,
@@ -375,11 +483,14 @@ def _fused_moe_lora_expand(
         "launch_pdl": use_gdc,  # triton kernel metadata
     }

+    grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
+        w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids
+    )
+
     grid = lambda META: (
         triton.cdiv(EM, META["BLOCK_SIZE_M"])
         * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         len(lora_b_stacked),
-        ## max_loras + 1 to handle the no-lora case (lora_id == -1)
-        lora_b_stacked[0].shape[0] + 1,
+        grid_lora_dim,
     )

     # Fast path: directly accumulate into the corresponding slice interval of output.
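To make the launch-shape difference concrete, the snippet below restates what _adjust_kernel_inputs computes for the two modes. It is illustrative only; the helper name summarize_launch is invented rather than taken from the patch.

def summarize_launch(max_loras, sorted_token_ids, expert_ids):
    if sorted_token_ids is None:
        # Naive path: each block derives its own lora_id from token_lora_mapping,
        # so one grid slice suffices and the per-LoRA strides are unused.
        return {"grid_lora_dim": 1, "stride_tl": 0, "stride_el": 0}
    # Aligned path: one grid slice per LoRA plus one for the lora_id == -1 sentinel.
    return {
        "grid_lora_dim": max_loras + 1,
        "stride_tl": sorted_token_ids.stride(0),
        "stride_el": expert_ids.stride(0),
    }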
@@ -394,11 +505,13 @@ def _fused_moe_lora_expand(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        token_lora_mapping,
         N,
         K,
         EM,
         num_tokens,
         num_experts,
+        top_k_num,
         lora_ids,
         adapter_enabled,
         lora_b_stacked[0].shape[0],
@@ -410,13 +523,14 @@ def _fused_moe_lora_expand(
         w1_lora_b_stacked.stride(2),
         out_view.stride(1),
         out_view.stride(2),
-        sorted_token_ids.stride(0),
-        expert_ids.stride(0),
+        stride_tl,
+        stride_el,
         slice_a_size=a_intermediate_cache1.numel() // num_slices,
         slice_c_size=slice_c_size,
         num_slice_a=num_slices,
         num_slice_c=num_slices,
-        top_k=1,
+        token_mapping_factor=1,
+        naive_block_assignment=sorted_token_ids is None,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         ADD_INPUTS=True,
         USE_B_L2_CACHE=True,  # new
@@ -436,9 +550,10 @@ def _fused_moe_lora(
         torch.Tensor
     ],  # [(max_loras, num_experts, N, max_lora_rank,),...]
     topk_weights: torch.Tensor,  # (num_tokens, top_k_num)
-    sorted_token_ids: torch.Tensor,  # (max_loras, _)
-    expert_ids: torch.Tensor,  # (max_loras, _ ,)
-    num_tokens_post_padded: torch.Tensor,  # (max_loras, )
+    sorted_token_ids: torch.Tensor | None,  # (max_loras, _)
+    expert_ids: torch.Tensor,  # (max_loras, _ ,) or (num_tokens * top_k,)
+    num_tokens_post_padded: torch.Tensor | None,  # (max_loras, )
+    token_lora_mapping: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
@@ -462,18 +577,24 @@ def _fused_moe_lora(
     offset: int = 0,
 ) -> None:
     assert len(lora_a_stacked) == len(lora_b_stacked) > 0
-    assert (
-        sorted_token_ids.dim()
-        == expert_ids.dim()
-        == topk_weights.dim()
-        == qcurr_hidden_states.dim()
-        == 2
-    )
-    assert (
-        sorted_token_ids.shape[0]
-        == expert_ids.shape[0]
-        == num_tokens_post_padded.shape[0]
-    )
+    assert topk_weights.dim() == qcurr_hidden_states.dim() == 2
+    if sorted_token_ids is None:
+        assert expert_ids.dim() == 1
+    else:
+        assert sorted_token_ids is not None
+        assert num_tokens_post_padded is not None
+        assert (
+            sorted_token_ids.dim()
+            == expert_ids.dim()
+            == topk_weights.dim()
+            == qcurr_hidden_states.dim()
+            == 2
+        )
+        assert (
+            sorted_token_ids.shape[0]
+            == expert_ids.shape[0]
+            == num_tokens_post_padded.shape[0]
+        )
     assert output.shape[0] == topk_weights.shape[0]
     assert top_k_num == topk_weights.shape[1]
     device = qcurr_hidden_states.device
@@ -482,10 +603,15 @@ def _fused_moe_lora(
     num_experts = lora_a_stacked[0].shape[1]
     N = max_lora_rank
     M = topk_weights.shape[0]
-    EM = sorted_token_ids.shape[1]
     K = qcurr_hidden_states.shape[1]
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
+    assert shrink_block_size_m == expand_block_size_m
+    EM = (
+        sorted_token_ids.shape[1]
+        if sorted_token_ids is not None
+        else num_tokens * shrink_block_size_m
+    )

     a_intermediate_cache1 = torch.zeros(
         (num_slices, M, top_k_num, max_lora_rank),
@@ -502,6 +628,7 @@ def _fused_moe_lora(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        token_lora_mapping,
         top_k_num,
         lora_ids,
         adapter_enabled,
@@ -546,6 +673,7 @@ def _fused_moe_lora(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        token_lora_mapping,
         top_k_num,
         lora_ids,
         adapter_enabled,
@@ -579,9 +707,10 @@ def _fused_moe_lora_fake(
     lora_a_stacked: list[torch.Tensor],
     lora_b_stacked: list[torch.Tensor],
     topk_weights: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor | None,
     expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor | None,
+    token_lora_mapping: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
     lora_ids: torch.Tensor,
@@ -610,9 +739,10 @@ def _fused_moe_lora_shrink_fake(
     qcurr_hidden_states: torch.Tensor,
     lora_a_stacked: list[torch.Tensor],
     topk_weights: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor | None,
     expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor | None,
+    token_lora_mapping: torch.Tensor,
     top_k_num: int,
     lora_ids: torch.Tensor,
     adapter_enabled: torch.Tensor,
@@ -642,9 +772,10 @@ def _fused_moe_lora_expand_fake(
     a_intermediate_cache1: torch.Tensor,
     lora_b_stacked: list[torch.Tensor],
     topk_weights: torch.Tensor,
-    sorted_token_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor | None,
     expert_ids: torch.Tensor,
-    num_tokens_post_padded: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor | None,
+    token_lora_mapping: torch.Tensor,
     top_k_num: int,
     lora_ids: torch.Tensor,
     adapter_enabled: torch.Tensor,
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 47c42b095..fdcf6c0cb 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -458,7 +458,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Aligns tokens and experts into block-sized chunks for LoRA-based
         mixture-of-experts (MoE) execution.
@@ -473,9 +473,9 @@ class PunicaWrapperBase(PunicaWrapperABC):
         lora_a_stacked: tuple[torch.Tensor, ...],
         lora_b_stacked: tuple[torch.Tensor, ...],
         topk_weights: torch.Tensor,
-        sorted_token_ids: torch.Tensor,
+        sorted_token_ids: torch.Tensor | None,
         expert_ids: torch.Tensor,
-        num_tokens_post_padded: torch.Tensor,
+        num_tokens_post_padded: torch.Tensor | None,
         max_lora_rank: int,
         top_k_num: int,
         shrink_config,
@@ -484,6 +484,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
         mul_routed_weight=False,
         fully_sharded: bool = False,
         offset: int = 0,
+        token_lora_mapping: torch.Tensor | None = None,
     ):
         """
         Performs a fused forward computation for LoRA of
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index f765df0b3..b704a74c7 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -310,52 +310,57 @@ class PunicaWrapperGPU(PunicaWrapperBase):
         adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        naive_block_assignment: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Aligns tokens and experts into block-sized chunks for LoRA-based
         mixture-of-experts (MoE) execution.
""" - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - if pad_sorted_ids: - max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) - sorted_ids = torch.empty( - (max_loras * max_num_tokens_padded,), - dtype=torch.int32, - device=topk_ids.device, - ) - max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - # Expert ids must be set default to -1 to prevent a blank block - expert_ids = torch.empty( - (max_loras * max_num_m_blocks,), - dtype=torch.int32, - device=topk_ids.device, - ) - num_tokens_post_pad = torch.empty( - (max_loras), dtype=torch.int32, device=topk_ids.device - ) - (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( num_tokens ) + if naive_block_assignment: + expert_ids = topk_ids.reshape(-1) + sorted_ids = None + num_tokens_post_pad = None + else: + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + sorted_ids = torch.empty( + (max_loras * max_num_tokens_padded,), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + # Expert ids must be set default to -1 to prevent a blank block + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), + dtype=torch.int32, + device=topk_ids.device, + ) + num_tokens_post_pad = torch.empty( + (max_loras), dtype=torch.int32, device=topk_ids.device + ) - ops.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - num_experts, - block_size, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_ids, - expert_ids, - num_tokens_post_pad, - adapter_enabled, - lora_ids, - ) - if expert_map is not None: - expert_ids = expert_map[expert_ids] + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_ids, + expert_ids, + num_tokens_post_pad, + adapter_enabled, + lora_ids, + ) + if expert_map is not None: + expert_ids = expert_map[expert_ids] - return sorted_ids, expert_ids, num_tokens_post_pad + return None, sorted_ids, expert_ids, num_tokens_post_pad def add_lora_fused_moe( self, @@ -364,9 +369,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): lora_a_stacked: tuple[torch.Tensor, ...], lora_b_stacked: tuple[torch.Tensor, ...], topk_weights: torch.Tensor, - sorted_token_ids: torch.Tensor, + sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, + num_tokens_post_padded: torch.Tensor | None, max_lora_rank: int, top_k_num: int, shrink_config, @@ -375,11 +380,21 @@ class PunicaWrapperGPU(PunicaWrapperBase): mul_routed_weight=False, fully_sharded: bool = False, offset: int = 0, + token_lora_mapping: torch.Tensor | None = None, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. """ - (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0)) + ( + token_lora_mapping_meta, + _, + _, + _, + lora_ids, + _, + ) = self.token_mapping_meta.meta_args(x.size(0)) + if token_lora_mapping is None: + token_lora_mapping = token_lora_mapping_meta fused_moe_lora( y, x, @@ -389,6 +404,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): sorted_token_ids, expert_ids, num_tokens_post_padded, + token_lora_mapping, max_lora_rank, top_k_num, lora_ids,