[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement (#31830)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Wentao Ye
2026-01-09 14:13:43 -05:00
committed by GitHub
parent 28ae32a5d3
commit 308feab33f
6 changed files with 172 additions and 63 deletions

View File

@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8(
assert global_num_experts != -1
assert a1q_scale is not None
if expert_map is not None:
"Translate info from expert_map to topk_ids"
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.size(1)
topk = topk_ids.size(1)
local_E = w1.size(0)
if use_batched_format:
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8(
# during offset calculations
expert_offsets = expert_offsets.to(torch.int64)
else:
problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8(
expert_map,
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_first_token_offset[:-1]
ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
# swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding).
swap_ab = a1q.size(0) <= 64
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab
)
expert_offsets = expert_first_token_offset[:-1]
if not per_act_token and (expert_map is not None or use_batched_format):
# this is necessary to avoid imprecise scale calculation caused by
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8(
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
expert_first_token_offset=expert_first_token_offset,
)
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8(
f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}"
)
# Translate info from expert_map to topk_ids
if expert_map is not None:
local_topk_ids = torch.where(
expert_map[topk_ids] != -1, expert_map[topk_ids], -1
)
else:
local_topk_ids = topk_ids
topk = local_topk_ids.size(1)
topk = topk_ids.size(1)
a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K))
mm1_out = _resize_cache(workspace13, (M * topk, N * 2))
act_out = _resize_cache(workspace2, (M * topk, N))
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8(
)
mm2_out = _resize_cache(workspace2, (M * topk, K))
problem_sizes1 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes2 = torch.empty(
(global_num_experts, 3), dtype=torch.int32, device=device
)
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
# permuted a1q reuses workspace2
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8(
expert_map,
permuted_hidden_states=a1q_perm,
)
expert_offsets = expert_first_token_offset[:-1]
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
ops.get_cutlass_moe_mm_problem_sizes(
local_topk_ids,
problem_sizes1,
problem_sizes2,
global_num_experts,
N,
K,
force_swap_ab=True,
# for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape).
ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets(
expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True
)
expert_offsets = expert_first_token_offset[:-1]
ops.cutlass_w4a8_moe_mm(
mm1_out,
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8(
permuted_hidden_states=mm2_out,
topk_weights=topk_weights,
inv_permuted_idx=inv_perm,
expert_first_token_offset=(
expert_first_token_offset if expert_map is not None else None
),
expert_first_token_offset=expert_first_token_offset,
)