[MoE Refactor][15/N] Apply Refactor to Fp8 (#31415)

This commit is contained in:
Robert Shaw
2026-01-07 19:42:33 -05:00
committed by GitHub
parent ffc0a2798b
commit 5dcd7ef1f2
38 changed files with 1439 additions and 1528 deletions

View File

@@ -5,14 +5,18 @@ import torch
import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES_MOE
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.worker.workspace import init_workspace_manager
@@ -45,6 +49,7 @@ def bench_run(
per_out_ch: bool,
mkn: tuple[int, int, int],
):
init_workspace_manager(torch.cuda.current_device())
label = "Quant Matmul"
sub_label = (
@@ -82,11 +87,6 @@ def bench_run(
a, score, topk, renormalize=False
)
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -120,10 +120,6 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -135,31 +131,29 @@ def bench_run(
per_act_token_quant=per_act_token,
)
for _ in range(num_repeats):
cutlass_moe_fp8(
a,
w1,
w2,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
)
device=w1.device,
),
)
for _ in range(num_repeats):
fn(a, w1, w2, topk_weights, topk_ids)
def run_cutlass_from_graph(
a: torch.Tensor,
a_scale: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
ab_strides1: torch.Tensor,
ab_strides2: torch.Tensor,
c_strides1: torch.Tensor,
c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -169,21 +163,23 @@ def bench_run(
per_act_token_quant=per_act_token,
)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
out_dtype=a.dtype,
# NOTE(rob): w2 is shaped as [E, hidden, intermediate]
e=w2.shape[0],
n=w2.shape[2],
k=w2.shape[1],
quant_config=quant_config,
device=w1.device,
),
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp8(
a,
w1_q,
w2_q,
topk_weights,
topk_ids,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
quant_config=quant_config,
)
return fn(a, w1, w2, topk_weights, topk_ids)
def run_triton_from_graph(
a: torch.Tensor,
@@ -227,10 +223,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
)
@@ -268,10 +260,6 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
"ab_strides1": ab_strides1,
"ab_strides2": ab_strides2,
"c_strides1": c_strides1,
"c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -330,10 +318,6 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
ab_strides1,
ab_strides2,
c_strides1,
c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -342,7 +326,7 @@ def bench_run(
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,