[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
bnellnm
2025-07-02 09:08:27 -04:00
committed by GitHub
parent b95877509b
commit c1909e7e8c
36 changed files with 2698 additions and 1584 deletions

View File

@@ -113,6 +113,7 @@ def bench_run(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
num_repeats: int,
):
for _ in range(num_repeats):
@@ -124,7 +125,8 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)
def run_cutlass_from_graph(
@@ -148,7 +150,8 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)
def run_triton_from_graph(
@@ -227,6 +230,7 @@ def bench_run(
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -287,12 +291,13 @@ def bench_run(
w2_scale,
topk_weights,
topk_ids,
per_act_token,
num_warmup,
)
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,