[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-17 19:43:31 -04:00
committed by GitHub
parent e6585ddb45
commit 5963b98b46
68 changed files with 2698 additions and 2526 deletions

View File

@@ -13,6 +13,10 @@ import torch.utils.benchmark as benchmark
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
nvfp4_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
from vllm.scalar_type import scalar_types
@@ -140,6 +144,12 @@ def bench_run(
a_fp8_scale: torch.Tensor,
num_repeats: int,
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
for _ in range(num_repeats):
fused_experts(
a,
@@ -147,10 +157,7 @@ def bench_run(
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)
def run_cutlass_moe_fp4(
@@ -172,25 +179,27 @@ def bench_run(
device: torch.device,
num_repeats: int,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
for _ in range(num_repeats):
with nvtx.annotate("cutlass_moe_fp4", color="green"):
cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_gs,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)
def run_cutlass_from_graph(
@@ -211,26 +220,29 @@ def bench_run(
e: int,
device: torch.device,
):
quant_config = nvfp4_moe_quant_config(
a1_gscale=a1_gs,
a2_gscale=a2_gs,
w1_scale=w1_blockscale,
w2_scale=w2_blockscale,
g1_alphas=w1_gs,
g2_alphas=w2_gs,
)
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
return cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
w1_fp4=w1_fp4,
w1_blockscale=w1_blockscale,
w1_alphas=w1_alphas,
a2_gscale=a2_gs,
w2_fp4=w2_fp4,
w2_blockscale=w2_blockscale,
w2_alphas=w2_alphas,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=num_experts,
device=device,
quant_config=quant_config,
)
def run_triton_from_graph(
@@ -246,16 +258,18 @@ def bench_run(
with set_current_vllm_config(
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
):
quant_config = fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
)
return fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a_fp8_scale,
quant_config=quant_config,
)
def replay_graph(graph, num_repeats):