[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

This commit is contained in:
Robert Shaw
2026-01-21 08:22:33 -05:00
committed by GitHub
parent e14467be43
commit 42135d6898
82 changed files with 2710 additions and 1563 deletions

View File

@@ -16,10 +16,16 @@ import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
@@ -194,10 +200,33 @@ def benchmark_config(
block_shape=block_quant_shape,
)
deep_gemm_experts = mk.FusedMoEModularKernel(
prepare_finalize=MoEPrepareAndFinalizeNoEP(),
fused_experts=TritonOrDeepGemmExperts(
moe_config=FusedMoEConfig(
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
num_local_experts=num_experts,
activation="silu",
parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=init_dtype,
routing_method=RoutingMethodType.TopK,
),
quant_config=quant_config,
),
)
with override_config(config):
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, renormalize=not use_deep_gemm
)
if use_deep_gemm:
return deep_gemm_experts(
x, w1, w2, topk_weights, topk_ids, inplace=True
)
return fused_experts(
x,
w1,
@@ -206,7 +235,6 @@ def benchmark_config(
topk_ids,
inplace=True,
quant_config=quant_config,
allow_deep_gemm=use_deep_gemm,
)
# JIT compilation & warmup