[Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE (#20762)

Signed-off-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
ElizaWszola
2025-07-17 15:56:44 +02:00
committed by GitHub
parent 2d6a38209b
commit 9fb2d22032
6 changed files with 174 additions and 38 deletions

View File

@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,