[lora/moe] Avoid extra intermediate buffer & Python slicing in expand phase when split_k == 1 (#32774)

Signed-off-by: 陈建华 <1647430658@qq.com>
Author: cwazai
Date: 2026-01-29 00:22:45 +08:00
Committed by: GitHub
parent 392c5af4fe
commit f210f0b7b1


@@ -84,6 +84,7 @@ def _fused_moe_lora_kernel(
     num_slice_c: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
+    ADD_INPUTS: tl.constexpr,
     USE_B_L2_CACHE: tl.constexpr,  # new, enable .ca load for B
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -211,7 +212,11 @@ def _fused_moe_lora_kernel(
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     if SPLIT_K == 1:
-        tl.store(c_ptrs, accumulator, mask=c_mask)
+        if ADD_INPUTS:
+            prev = tl.load(c_ptrs, mask=c_mask, other=0.0)
+            tl.store(c_ptrs, prev + accumulator, mask=c_mask)
+        else:
+            tl.store(c_ptrs, accumulator, mask=c_mask)
     else:
         tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
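With SPLIT_K == 1 exactly one program instance owns each output tile, so the epilogue can safely fold the old Python-side `+=` into a load-add-store; with split-K several instances write the same tile, which is why the relaxed `tl.atomic_add` path is left untouched. A minimal PyTorch stand-in for the two expand behaviors (shapes are illustrative, not the kernel's):

```python
import torch

out = torch.randn(4, 8)  # tile of the final output tensor
acc = torch.randn(4, 8)  # kernel accumulator for that tile

# Old expand path: kernel overwrites a scratch buffer, Python adds it back in.
scratch = acc.clone()
expected = out + scratch

# New expand path (ADD_INPUTS=True, SPLIT_K == 1): the kernel itself performs
# prev = tl.load(c_ptrs); tl.store(c_ptrs, prev + accumulator).
fused = out.clone()
fused += acc

assert torch.allclose(expected, fused)
```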
@@ -305,6 +310,7 @@ def _fused_moe_lora_shrink(
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
+        ADD_INPUTS=False,
         USE_B_L2_CACHE=True,  # new
         IS_PRIMARY=True,
         **shrink_config,
@@ -315,7 +321,6 @@ def _fused_moe_lora_shrink(
 def _fused_moe_lora_expand(
     output: torch.Tensor,  # (num_tokens, top_k_num, N*len(lora_a_stacked),)
     a_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, max_lora_rank)
-    b_intermediate_cache1: torch.Tensor,  # (num_slices, M, top_k_num, output_dim_size)
     lora_b_stacked: list[
         torch.Tensor
     ],  # [(max_loras, num_experts, max_lora_rank, K,),...]
@@ -376,10 +381,15 @@ def _fused_moe_lora_expand(
         ## max_loras + 1 to handle the no-lora case (lora_id == -1)
         lora_b_stacked[0].shape[0] + 1,
     )
+    # Fast path: directly accumulate into the corresponding slice interval of output.
+    out_view = output[:, :, offset : offset + num_slices * N]
+    slice_c_size = N * out_view.stride(2)
     _fused_moe_lora_kernel[grid](
         a_intermediate_cache1,
         b_ptr,
-        b_intermediate_cache1,
+        out_view,
         topk_weights,
         sorted_token_ids,
         expert_ids,
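The fast path rests on two facts checked below: basic slicing yields a view, so `out_view` shares storage with `output` and the kernel's stores land directly in the final tensor; and `slice_c_size = N * out_view.stride(2)` is the element distance between consecutive slices inside that view. A quick sanity check with made-up dimensions:

```python
import torch

num_tokens, top_k_num, hidden = 6, 2, 64
N, num_slices, offset = 16, 3, 4
output = torch.zeros(num_tokens, top_k_num, hidden)

out_view = output[:, :, offset : offset + num_slices * N]

# The slice is a view: same storage, shifted by `offset` elements, so no
# intermediate buffer is allocated and no copy-back is needed.
assert out_view.data_ptr() == output.data_ptr() + offset * output.element_size()

# Element distance between slice i and slice i + 1 along the last dim.
slice_c_size = N * out_view.stride(2)
assert slice_c_size == N  # stride(2) == 1 because the last dim is contiguous
```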
@@ -398,22 +408,21 @@ def _fused_moe_lora_expand(
         w1_lora_b_stacked.stride(1),
         w1_lora_b_stacked.stride(3),
         w1_lora_b_stacked.stride(2),
-        b_intermediate_cache1.stride(2),
-        b_intermediate_cache1.stride(3),
+        out_view.stride(1),
+        out_view.stride(2),
         sorted_token_ids.stride(0),
         expert_ids.stride(0),
         slice_a_size=a_intermediate_cache1.numel() // num_slices,
-        slice_c_size=b_intermediate_cache1.numel() // num_slices,
+        slice_c_size=slice_c_size,
         num_slice_a=num_slices,
         num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
+        ADD_INPUTS=True,
         USE_B_L2_CACHE=True,  # new
         IS_PRIMARY=False,
         **expand_config,
     )
-    for i in range(num_slices):
-        output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
 @torch.inference_mode()
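Taken together, the ADD_INPUTS store and the deleted tail loop compute the same result; what disappears is the `b_intermediate_cache1` round trip plus `num_slices` strided `+=` ops on the Python side. A rough equivalence check, with plain tensor ops standing in for the kernel:

```python
import torch

num_slices, M, top_k_num, N, offset = 3, 6, 2, 16, 4
output_old = torch.randn(M, top_k_num, offset + num_slices * N)
output_new = output_old.clone()
kernel_result = torch.randn(num_slices, M, top_k_num, N)

# Old path: kernel fills b_intermediate_cache1, Python slices it back in.
b_intermediate_cache1 = kernel_result.clone()
for i in range(num_slices):
    output_old[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]

# New path: the kernel accumulates straight into the output view
# (ADD_INPUTS=True), modeled here as one in-place add over all slices.
out_view = output_new[:, :, offset : offset + num_slices * N]
out_view += kernel_result.permute(1, 2, 0, 3).reshape(M, top_k_num, num_slices * N)

assert torch.allclose(output_old, output_new)
```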
@@ -484,11 +493,6 @@ def _fused_moe_lora(
         device=device,
     )
-    b_intermediate_cache1 = torch.zeros(
-        (num_slices, M, top_k_num, w1_output_dim_size),
-        dtype=output.dtype,
-        device=device,
-    )
     use_gdc = supports_pdl(device) and not fully_sharded
     _fused_moe_lora_shrink(
         a_intermediate_cache1,
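Dropping this allocation also drops the zero-fill: `num_slices * M * top_k_num * w1_output_dim_size` elements per call. A back-of-envelope with hypothetical sizes:

```python
# Hypothetical sizes, only to gauge the saved allocation.
num_slices, M, top_k_num, w1_output_dim_size = 2, 256, 8, 14336
bytes_per_elem = 2  # assuming a bf16/fp16 output dtype

saved = num_slices * M * top_k_num * w1_output_dim_size * bytes_per_elem
print(f"{saved / 2**20:.0f} MiB per call")  # 112 MiB for these sizes
```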
@@ -537,7 +541,6 @@ def _fused_moe_lora(
     _fused_moe_lora_expand(
         output,
         a_intermediate_cache1,
-        b_intermediate_cache1,
         lora_b_stacked,
         topk_weights,
         sorted_token_ids,