[Perf] Optimize additional fill(0) in cutlass moe, 2.9% E2E throughput improvement, 10.8% TTFT improvement (#31754)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -173,7 +173,7 @@ def run_cutlass_moe_fp8(
|
|||||||
|
|
||||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||||
# permuted a1q reuses workspace2
|
# permuted a1q reuses workspace2
|
||||||
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
|
a1q, a1q_scale, expert_first_token_offset, inv_perm, _ = moe_permute(
|
||||||
a1q,
|
a1q,
|
||||||
a1q_scale,
|
a1q_scale,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -182,7 +182,7 @@ def run_cutlass_moe_fp8(
|
|||||||
expert_map,
|
expert_map,
|
||||||
permuted_hidden_states=a1q_perm,
|
permuted_hidden_states=a1q_perm,
|
||||||
)
|
)
|
||||||
expert_offsets = expert_offsets[:-1]
|
expert_offsets = expert_first_token_offset[:-1]
|
||||||
|
|
||||||
ops.get_cutlass_moe_mm_problem_sizes(
|
ops.get_cutlass_moe_mm_problem_sizes(
|
||||||
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
|
local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K
|
||||||
@@ -215,9 +215,6 @@ def run_cutlass_moe_fp8(
|
|||||||
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
||||||
)
|
)
|
||||||
|
|
||||||
if expert_map is not None:
|
|
||||||
mm2_out.fill_(0)
|
|
||||||
|
|
||||||
ops.cutlass_moe_mm(
|
ops.cutlass_moe_mm(
|
||||||
mm2_out,
|
mm2_out,
|
||||||
a2q,
|
a2q,
|
||||||
@@ -243,6 +240,9 @@ def run_cutlass_moe_fp8(
|
|||||||
permuted_hidden_states=mm2_out,
|
permuted_hidden_states=mm2_out,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
inv_permuted_idx=inv_perm,
|
inv_permuted_idx=inv_perm,
|
||||||
|
expert_first_token_offset=(
|
||||||
|
expert_first_token_offset if expert_map is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -988,7 +988,7 @@ def run_cutlass_moe_w4a8_fp8(
|
|||||||
|
|
||||||
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
num_expert = global_num_experts if expert_map is None else expert_map.size(0)
|
||||||
# permuted a1q reuses workspace2
|
# permuted a1q reuses workspace2
|
||||||
a1q, a1q_scale, expert_offsets, inv_perm, _ = moe_permute(
|
a1q, a1q_scale, expert_first_token_offset, inv_perm, _ = moe_permute(
|
||||||
a1q,
|
a1q,
|
||||||
a1q_scale,
|
a1q_scale,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@@ -997,7 +997,7 @@ def run_cutlass_moe_w4a8_fp8(
|
|||||||
expert_map,
|
expert_map,
|
||||||
permuted_hidden_states=a1q_perm,
|
permuted_hidden_states=a1q_perm,
|
||||||
)
|
)
|
||||||
expert_offsets = expert_offsets[:-1]
|
expert_offsets = expert_first_token_offset[:-1]
|
||||||
|
|
||||||
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
|
# For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape)
|
||||||
ops.get_cutlass_moe_mm_problem_sizes(
|
ops.get_cutlass_moe_mm_problem_sizes(
|
||||||
@@ -1032,9 +1032,6 @@ def run_cutlass_moe_w4a8_fp8(
|
|||||||
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
act_out, a2_scale, use_per_token_if_dynamic=per_act_token, output=quant_out
|
||||||
)
|
)
|
||||||
|
|
||||||
if expert_map is not None:
|
|
||||||
mm2_out.fill_(0)
|
|
||||||
|
|
||||||
ops.cutlass_w4a8_moe_mm(
|
ops.cutlass_w4a8_moe_mm(
|
||||||
mm2_out,
|
mm2_out,
|
||||||
a2q,
|
a2q,
|
||||||
@@ -1058,6 +1055,9 @@ def run_cutlass_moe_w4a8_fp8(
|
|||||||
permuted_hidden_states=mm2_out,
|
permuted_hidden_states=mm2_out,
|
||||||
topk_weights=topk_weights,
|
topk_weights=topk_weights,
|
||||||
inv_permuted_idx=inv_perm,
|
inv_permuted_idx=inv_perm,
|
||||||
|
expert_first_token_offset=(
|
||||||
|
expert_first_token_offset if expert_map is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user