[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)

This commit is contained in:
Robert Shaw
2026-01-21 08:22:33 -05:00
committed by GitHub
parent e14467be43
commit 42135d6898
82 changed files with 2710 additions and 1563 deletions

View File

@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts,
@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
from vllm.utils.math_utils import round_up
def make_dummy_moe_config(
    num_experts: int = 1,
    experts_per_token: int = 1,
    hidden_dim: int = 1,
    intermediate_size_per_partition: int = 1,
    in_dtype: torch.dtype = torch.bfloat16,
) -> FusedMoEConfig:
    """Build a minimal FusedMoEConfig for the mk constructor interface.

    Most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN) never read this
    config; CUTLASSFp8 reads a few of the size fields, which is why the
    dimensions are parameterized at all.
    """
    # Single-process layout: no parallelism, so the local expert count
    # equals the global expert count.
    config_kwargs = dict(
        num_experts=num_experts,
        experts_per_token=experts_per_token,
        hidden_dim=hidden_dim,
        intermediate_size_per_partition=intermediate_size_per_partition,
        num_local_experts=num_experts,
        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
        activation="silu",
        in_dtype=in_dtype,
        device="cuda",
        routing_method=RoutingMethodType.TopK,
    )
    return FusedMoEConfig(**config_kwargs)
def triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -81,6 +114,7 @@ def batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
)
@@ -121,6 +155,7 @@ def naive_batched_moe(
max_num_tokens=max_num_tokens,
num_dispatchers=1,
quant_config=quant_config,
moe_config=make_dummy_moe_config(),
),
)