[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (#31830)
Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
|
||||
assert global_num_experts != -1
|
||||
assert a1q_scale is not None
|
||||
|
||||
if expert_map is not None:
|
||||
"Translate info from expert_map to topk_ids"
|
||||
local_topk_ids = torch.where(
|
||||
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
|
||||
)
|
||||
else:
|
||||
local_topk_ids = topk_ids
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
local_E = w1.size(0)
|
||||
|
||||
if use_batched_format:
|
||||
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
|
||||
# during offset calculations
|
||||
expert_offsets = expert_offsets.to(torch.int64)
|
||||
else:
|
||||
problem_sizes1 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes2 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||
# permuted a1q reuses workspace2
|
||||
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
ops.get_cutlass_moe_mm_problem_sizes(
|
||||
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
|
||||
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
|
||||
swap_ab = a1q.size(0) <= 64
|
||||
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
if not per_act_token and (expert_map is not None or use_batched_format):
|
||||
# this is necessary to avoid imprecise scale calculation caused by
|
||||
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
|
||||
permuted_hidden_states=mm2_out,
|
||||
topk_weights=topk_weights,
|
||||
inv_permuted_idx=inv_perm,
|
||||
expert_first_token_offset=(
|
||||
expert_first_token_offset if expert_map is not None else None
|
||||
),
|
||||
expert_first_token_offset=expert_first_token_offset,
|
||||
)
|
||||
|
||||
|
||||
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
|
||||
)
|
||||
|
||||
# Translate info from expert_map to topk_ids
|
||||
if expert_map is not None:
|
||||
local_topk_ids = torch.where(
|
||||
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
|
||||
)
|
||||
else:
|
||||
local_topk_ids = topk_ids
|
||||
|
||||
topk = local_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
|
||||
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
|
||||
act_out = _resize_cache(workspace2, (M * topk, N))
|
||||
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
)
|
||||
mm2_out = _resize_cache(workspace2, (M * topk, K))
|
||||
|
||||
problem_sizes1 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes2 = torch.empty(
|
||||
(global_num_experts, 3), dtype=torch.int32, device=device
|
||||
)
|
||||
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
|
||||
|
||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||
# permuted a1q reuses workspace2
|
||||
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
expert_map,
|
||||
permuted_hidden_states=a1q_perm,
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
|
||||
ops.get_cutlass_moe_mm_problem_sizes(
|
||||
local_topk_ids,
|
||||
problem_sizes1,
|
||||
problem_sizes2,
|
||||
global_num_experts,
|
||||
N,
|
||||
K,
|
||||
force_swap_ab=True,
|
||||
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
|
||||
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
|
||||
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True
|
||||
)
|
||||
expert_offsets = expert_first_token_offset[:-1]
|
||||
|
||||
ops.cutlass_w4a8_moe_mm(
|
||||
mm1_out,
|
||||
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
|
||||
permuted_hidden_states=mm2_out,
|
||||
topk_weights=topk_weights,
|
||||
inv_permuted_idx=inv_perm,
|
||||
expert_first_token_offset=(
|
||||
expert_first_token_offset if expert_map is not None else None
|
||||
),
|
||||
expert_first_token_offset=expert_first_token_offset,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user