[MoE Refactor] Oracle Select FP8+NVFP4 Kernels In Priority (#32414)
@@ -8,7 +8,12 @@ from tests.kernels.quant_utils import per_block_cast_to_int8
 from tests.kernels.quantization.nvfp4_utils import FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig,
+    FusedMoEParallelConfig,
+    FusedMoEQuantConfig,
+    RoutingMethodType,
+)
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedPrepareAndFinalize,
     BatchedTritonExperts,
@@ -20,6 +25,34 @@ from vllm.utils.deep_gemm import per_block_cast_to_fp8
 from vllm.utils.math_utils import round_up
 
 
+def make_dummy_moe_config(
+    num_experts: int = 1,
+    experts_per_token: int = 1,
+    hidden_dim: int = 1,
+    intermediate_size_per_partition: int = 1,
+    in_dtype: torch.dtype = torch.bfloat16,
+) -> FusedMoEConfig:
+    """
+    Build a dummy config for the mk (modular-kernel) constructor
+    interface; most kernels (DeepGEMM, CUTLASSFp4, Triton, MARLIN)
+    do not actually use it.
+
+    CUTLASSFp8 needs it to set some parameters for workshapes.
+    """
+    return FusedMoEConfig(
+        num_experts=num_experts,
+        experts_per_token=experts_per_token,
+        hidden_dim=hidden_dim,
+        intermediate_size_per_partition=intermediate_size_per_partition,
+        num_local_experts=num_experts,
+        moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+        activation="silu",
+        in_dtype=in_dtype,
+        device="cuda",
+        routing_method=RoutingMethodType.TopK,
+    )
+
+
 def triton_moe(
     a: torch.Tensor,
     w1: torch.Tensor,
@@ -81,6 +114,7 @@ def batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
+            moe_config=make_dummy_moe_config(),
         ),
     )
 
@@ -121,6 +155,7 @@ def naive_batched_moe(
             max_num_tokens=max_num_tokens,
             num_dispatchers=1,
             quant_config=quant_config,
+            moe_config=make_dummy_moe_config(),
        ),
    )
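For orientation, here is a minimal usage sketch of the pattern the last two hunks change: the batched experts kernel now receives a FusedMoEConfig alongside its quantization config, and the tests satisfy it with the size-1 dummy. The build_experts wrapper below is hypothetical and assumes make_dummy_moe_config from the diff above is in scope; only the BatchedTritonExperts keyword arguments come from the commit itself.

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)


def build_experts(
    max_num_tokens: int,
    quant_config: FusedMoEQuantConfig,
) -> BatchedTritonExperts:
    # Hypothetical wrapper mirroring the two updated call sites above.
    return BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=1,  # single dispatcher, as in these tests
        quant_config=quant_config,
        # New in this commit: kernels that need shape information for
        # workshapes (CUTLASSFp8) read it from this config; the others
        # ignore it, so size-1 dummy dimensions are enough.
        moe_config=make_dummy_moe_config(),  # helper defined in the diff
    )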