permute/unpermute kernel for moe optimization (#14568)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du
2025-05-03 02:31:55 +08:00
committed by GitHub
parent 0f87d8f7b2
commit 3e887d2e0c
19 changed files with 1474 additions and 28 deletions

View File

@@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
     M, K = a.shape
     N = w2.shape[-1]
-    topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
+    topk_weight, topk_ids, token_expert_indices = fused_topk(
+        a, score.float(), topk, False)
     block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)