permute/unpermute kernel for moe optimization (#14568)
Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
@@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
M, K = a.shape
|
||||
N = w2.shape[-1]
|
||||
|
||||
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
topk_weight, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
|
||||
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
|
||||
topk, block_size)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user