From f210f0b7b1e4f179cf194c09c8805898134b70d8 Mon Sep 17 00:00:00 2001 From: cwazai <38356712+cwazai@users.noreply.github.com> Date: Thu, 29 Jan 2026 00:22:45 +0800 Subject: [PATCH] [lora/moe] Avoid extra intermediate buffer & Python slicing in expand phase when split_k == 1 (#32774) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 陈建华 <1647430658@qq.com> --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 31 ++++++++++--------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index ff06725cd..58549ee9f 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -84,6 +84,7 @@ def _fused_moe_lora_kernel( num_slice_c: tl.constexpr, top_k: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr, + ADD_INPUTS: tl.constexpr, USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -211,7 +212,11 @@ def _fused_moe_lora_kernel( c_mask = token_mask[:, None] & (offs_cn[None, :] < N) if SPLIT_K == 1: - tl.store(c_ptrs, accumulator, mask=c_mask) + if ADD_INPUTS: + prev = tl.load(c_ptrs, mask=c_mask, other=0.0) + tl.store(c_ptrs, prev + accumulator, mask=c_mask) + else: + tl.store(c_ptrs, accumulator, mask=c_mask) else: tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed") @@ -305,6 +310,7 @@ def _fused_moe_lora_shrink( num_slice_c=num_slices, top_k=1 if mul_routed_weight else top_k_num, MUL_ROUTED_WEIGHT=False, + ADD_INPUTS=False, USE_B_L2_CACHE=True, # new IS_PRIMARY=True, **shrink_config, @@ -315,7 +321,6 @@ def _fused_moe_lora_shrink( def _fused_moe_lora_expand( output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) - b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size) lora_b_stacked: list[ torch.Tensor ], # [(max_loras, num_experts, max_lora_rank, K,),...] @@ -376,10 +381,15 @@ def _fused_moe_lora_expand( ## max_loras + 1 to handle the no-lora case (lora_id == -1) lora_b_stacked[0].shape[0] + 1, ) + + # Fast path: directly accumulate into the corresponding slice interval of output. + out_view = output[:, :, offset : offset + num_slices * N] + slice_c_size = N * out_view.stride(2) + _fused_moe_lora_kernel[grid]( a_intermediate_cache1, b_ptr, - b_intermediate_cache1, + out_view, topk_weights, sorted_token_ids, expert_ids, @@ -398,22 +408,21 @@ def _fused_moe_lora_expand( w1_lora_b_stacked.stride(1), w1_lora_b_stacked.stride(3), w1_lora_b_stacked.stride(2), - b_intermediate_cache1.stride(2), - b_intermediate_cache1.stride(3), + out_view.stride(1), + out_view.stride(2), sorted_token_ids.stride(0), expert_ids.stride(0), slice_a_size=a_intermediate_cache1.numel() // num_slices, - slice_c_size=b_intermediate_cache1.numel() // num_slices, + slice_c_size=slice_c_size, num_slice_a=num_slices, num_slice_c=num_slices, top_k=1, MUL_ROUTED_WEIGHT=mul_routed_weight, + ADD_INPUTS=True, USE_B_L2_CACHE=True, # new IS_PRIMARY=False, **expand_config, ) - for i in range(num_slices): - output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i] @torch.inference_mode() @@ -484,11 +493,6 @@ def _fused_moe_lora( device=device, ) - b_intermediate_cache1 = torch.zeros( - (num_slices, M, top_k_num, w1_output_dim_size), - dtype=output.dtype, - device=device, - ) use_gdc = supports_pdl(device) and not fully_sharded _fused_moe_lora_shrink( a_intermediate_cache1, @@ -537,7 +541,6 @@ def _fused_moe_lora( _fused_moe_lora_expand( output, a_intermediate_cache1, - b_intermediate_cache1, lora_b_stacked, topk_weights, sorted_token_ids,