[MoE Refactor][16/N] Apply Refactor to NVFP4 (#31692)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -11,14 +11,20 @@ import nvtx
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
fp8_w8a8_moe_quant_config,
|
||||
nvfp4_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp4,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
@@ -188,19 +194,24 @@ def bench_run(
|
||||
g1_alphas=w1_gs,
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||
CutlassExpertsFp4(
|
||||
out_dtype=dtype,
|
||||
max_experts_per_worker=e,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
for _ in range(num_repeats):
|
||||
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||
cutlass_moe_fp4(
|
||||
a=a,
|
||||
w1_fp4=w1_fp4,
|
||||
w2_fp4=w2_fp4,
|
||||
kernel(
|
||||
hidden_states=a,
|
||||
w1=w1_fp4,
|
||||
w2=w2_fp4,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@@ -230,20 +241,24 @@ def bench_run(
|
||||
g2_alphas=w2_gs,
|
||||
)
|
||||
|
||||
kernel = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(defer_input_quant=True),
|
||||
CutlassExpertsFp4(
|
||||
out_dtype=dtype,
|
||||
max_experts_per_worker=e,
|
||||
quant_config=quant_config,
|
||||
),
|
||||
)
|
||||
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||
):
|
||||
return cutlass_moe_fp4(
|
||||
a=a,
|
||||
w1_fp4=w1_fp4,
|
||||
w2_fp4=w2_fp4,
|
||||
return kernel(
|
||||
hidden_states=a,
|
||||
w1=w1_fp4,
|
||||
w2=w2_fp4,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
m=m,
|
||||
n=n,
|
||||
k=k,
|
||||
e=num_experts,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
|
||||
Reference in New Issue
Block a user