[MoE Refactor][16/N] Apply Refactor to NVFP4 (#31692)

Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
Author: Robert Shaw
Date: 2026-01-07 22:46:27 -05:00
Committed by: GitHub
Parent: 8dd2419fa9
Commit: 9f6dcb71ae
15 changed files with 777 additions and 681 deletions
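In short, the NVFP4 CUTLASS benchmark path no longer calls the monolithic `cutlass_moe_fp4` helper; it now composes a `FusedMoEModularKernel` from a prepare/finalize stage (`MoEPrepareAndFinalizeNoEP`) and a CUTLASS FP4 experts stage (`CutlassExpertsFp4`). A minimal before/after sketch of the call site, assuming the imports shown in the first hunk below and the benchmark's existing locals (`a`, `w1_fp4`, `w2_fp4`, `topk_weights`, `topk_ids`, `quant_config`, `dtype`, `e`, problem sizes):

```python
# Before: one monolithic entry point carried the problem sizes and quant config.
out = cutlass_moe_fp4(
    a=a,
    w1_fp4=w1_fp4,
    w2_fp4=w2_fp4,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    m=m,
    n=n,
    k=k,
    e=num_experts,
    quant_config=quant_config,
)

# After: build the modular kernel once, then call it with just the tensors.
kernel = mk.FusedMoEModularKernel(
    MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
    CutlassExpertsFp4(
        out_dtype=dtype,
        max_experts_per_worker=e,
        quant_config=quant_config,
    ),
)
out = kernel(
    hidden_states=a,
    w1=w1_fp4,
    w2=w2_fp4,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
)
```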


@@ -11,14 +11,20 @@ import nvtx
 import torch
 import torch.utils.benchmark as benchmark
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe.config import (
     fp8_w8a8_moe_quant_config,
     nvfp4_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    CutlassExpertsFp4,
+)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP,
+)
 from vllm.scalar_type import scalar_types
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.v1.worker.workspace import init_workspace_manager
@@ -188,19 +194,24 @@ def bench_run(
             g1_alphas=w1_gs,
             g2_alphas=w2_gs,
         )
+        kernel = mk.FusedMoEModularKernel(
+            MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+            CutlassExpertsFp4(
+                out_dtype=dtype,
+                max_experts_per_worker=e,
+                quant_config=quant_config,
+            ),
+        )
         for _ in range(num_repeats):
             with nvtx.annotate("cutlass_moe_fp4", color="green"):
-                cutlass_moe_fp4(
-                    a=a,
-                    w1_fp4=w1_fp4,
-                    w2_fp4=w2_fp4,
+                kernel(
+                    hidden_states=a,
+                    w1=w1_fp4,
+                    w2=w2_fp4,
                     topk_weights=topk_weights,
                     topk_ids=topk_ids,
-                    m=m,
-                    n=n,
-                    k=k,
-                    e=num_experts,
-                    quant_config=quant_config,
                 )

     def run_cutlass_from_graph(
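As a usage note (not part of this diff): once the modular kernel is constructed as above, it can also be timed directly with `torch.utils.benchmark`, which this file already imports. A hedged sketch reusing the benchmark's local names (`kernel`, `a`, `w1_fp4`, `w2_fp4`, `topk_weights`, `topk_ids`):

```python
import torch.utils.benchmark as benchmark

# Hypothetical timing snippet; the names in `globals` are assumed to be the
# locals built earlier in bench_run, not new APIs introduced by this commit.
timer = benchmark.Timer(
    stmt=(
        "kernel(hidden_states=a, w1=w1_fp4, w2=w2_fp4, "
        "topk_weights=topk_weights, topk_ids=topk_ids)"
    ),
    globals={
        "kernel": kernel,
        "a": a,
        "w1_fp4": w1_fp4,
        "w2_fp4": w2_fp4,
        "topk_weights": topk_weights,
        "topk_ids": topk_ids,
    },
)
print(timer.timeit(100))
```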
@@ -230,20 +241,24 @@ def bench_run(
             g2_alphas=w2_gs,
         )
+        kernel = mk.FusedMoEModularKernel(
+            MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
+            CutlassExpertsFp4(
+                out_dtype=dtype,
+                max_experts_per_worker=e,
+                quant_config=quant_config,
+            ),
+        )
         with set_current_vllm_config(
             VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
         ):
-            return cutlass_moe_fp4(
-                a=a,
-                w1_fp4=w1_fp4,
-                w2_fp4=w2_fp4,
+            return kernel(
+                hidden_states=a,
+                w1=w1_fp4,
+                w2=w2_fp4,
                 topk_weights=topk_weights,
                 topk_ids=topk_ids,
-                m=m,
-                n=n,
-                k=k,
-                e=num_experts,
-                quant_config=quant_config,
             )

     def run_triton_from_graph(
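One design note on the graph-capture path above: the call stays wrapped in `set_current_vllm_config`, which appears to be there so that vLLM code invoked inside the modular kernel can resolve an active `VllmConfig` while the CUDA graph is built. A hedged standalone sketch of that pattern, using only symbols this file already imports and the benchmark's local names (`kernel`, `a`, `w1_fp4`, `w2_fp4`, `topk_weights`, `topk_ids`):

```python
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config

# Minimal single-stage parallel config, mirroring what the benchmark constructs.
vllm_config = VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))

with set_current_vllm_config(vllm_config):
    # Inside this block, vLLM code paths that query the current config
    # (including the modular MoE kernel) see `vllm_config`.
    out = kernel(
        hidden_states=a,
        w1=w1_fp4,
        w2=w2_fp4,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
    )
```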