[lora/moe] Avoid extra intermediate buffer & Python slicing in expand phase when split_k == 1 (#32774)
Signed-off-by: 陈建华 <1647430658@qq.com>
This commit is contained in:
@@ -84,6 +84,7 @@ def _fused_moe_lora_kernel(
|
||||
num_slice_c: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
USE_B_L2_CACHE: tl.constexpr, # new, enable .ca load for B
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
@@ -211,7 +212,11 @@ def _fused_moe_lora_kernel(
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
if ADD_INPUTS:
|
||||
prev = tl.load(c_ptrs, mask=c_mask, other=0.0)
|
||||
tl.store(c_ptrs, prev + accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
|
||||
|
||||
@@ -305,6 +310,7 @@ def _fused_moe_lora_shrink(
|
||||
num_slice_c=num_slices,
|
||||
top_k=1 if mul_routed_weight else top_k_num,
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
ADD_INPUTS=False,
|
||||
USE_B_L2_CACHE=True, # new
|
||||
IS_PRIMARY=True,
|
||||
**shrink_config,
|
||||
@@ -315,7 +321,6 @@ def _fused_moe_lora_shrink(
|
||||
def _fused_moe_lora_expand(
|
||||
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
|
||||
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
|
||||
b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size)
|
||||
lora_b_stacked: list[
|
||||
torch.Tensor
|
||||
], # [(max_loras, num_experts, max_lora_rank, K,),...]
|
||||
@@ -376,10 +381,15 @@ def _fused_moe_lora_expand(
|
||||
## max_loras + 1 to handle the no-lora case (lora_id == -1)
|
||||
lora_b_stacked[0].shape[0] + 1,
|
||||
)
|
||||
|
||||
# Fast path: directly accumulate into the corresponding slice interval of output.
|
||||
out_view = output[:, :, offset : offset + num_slices * N]
|
||||
slice_c_size = N * out_view.stride(2)
|
||||
|
||||
_fused_moe_lora_kernel[grid](
|
||||
a_intermediate_cache1,
|
||||
b_ptr,
|
||||
b_intermediate_cache1,
|
||||
out_view,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@@ -398,22 +408,21 @@ def _fused_moe_lora_expand(
|
||||
w1_lora_b_stacked.stride(1),
|
||||
w1_lora_b_stacked.stride(3),
|
||||
w1_lora_b_stacked.stride(2),
|
||||
b_intermediate_cache1.stride(2),
|
||||
b_intermediate_cache1.stride(3),
|
||||
out_view.stride(1),
|
||||
out_view.stride(2),
|
||||
sorted_token_ids.stride(0),
|
||||
expert_ids.stride(0),
|
||||
slice_a_size=a_intermediate_cache1.numel() // num_slices,
|
||||
slice_c_size=b_intermediate_cache1.numel() // num_slices,
|
||||
slice_c_size=slice_c_size,
|
||||
num_slice_a=num_slices,
|
||||
num_slice_c=num_slices,
|
||||
top_k=1,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
ADD_INPUTS=True,
|
||||
USE_B_L2_CACHE=True, # new
|
||||
IS_PRIMARY=False,
|
||||
**expand_config,
|
||||
)
|
||||
for i in range(num_slices):
|
||||
output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i]
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -484,11 +493,6 @@ def _fused_moe_lora(
|
||||
device=device,
|
||||
)
|
||||
|
||||
b_intermediate_cache1 = torch.zeros(
|
||||
(num_slices, M, top_k_num, w1_output_dim_size),
|
||||
dtype=output.dtype,
|
||||
device=device,
|
||||
)
|
||||
use_gdc = supports_pdl(device) and not fully_sharded
|
||||
_fused_moe_lora_shrink(
|
||||
a_intermediate_cache1,
|
||||
@@ -537,7 +541,6 @@ def _fused_moe_lora(
|
||||
_fused_moe_lora_expand(
|
||||
output,
|
||||
a_intermediate_cache1,
|
||||
b_intermediate_cache1,
|
||||
lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
|
||||
Reference in New Issue
Block a user