permute/unpermute kernel for moe optimization (#14568)

Signed-off-by: Caleb_Du <Caleb_Du@zju.edu.cn>
This commit is contained in:
Caleb_Du
2025-05-03 02:31:55 +08:00
committed by GitHub
parent 0f87d8f7b2
commit 3e887d2e0c
19 changed files with 1474 additions and 28 deletions

View File

@@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, renormalize=False)
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,