[Kernel] Delegate construction of FusedMoEQuantConfig to FusedMoEMethodBase subclasses (#22537)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-09-17 19:43:31 -04:00
Committed by: GitHub
Parent: e6585ddb45
Commit: 5963b98b46
68 changed files with 2698 additions and 2526 deletions

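The theme of the diff below: per-kernel boolean flags (use_fp8_w8a8, use_int8_w8a8, ...) and loose scale tensors are replaced by a single FusedMoEQuantConfig, built once via FusedMoEQuantConfig.make() and threaded through kernel construction. A minimal sketch of the new construction pattern, using the same factory call the test adopts; the expert count and scale shapes are illustrative assumptions, not values from the commit:

import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

# Illustrative per-expert scale tensors (E experts); the real test derives
# these from the quantized weights.
E = 32
w1_scale = torch.ones(E, dtype=torch.float32)
w2_scale = torch.ones(E, dtype=torch.float32)

# One object now carries the quantized dtype, the quantization mode, and
# the scales that were previously passed around individually.
quant_config = FusedMoEQuantConfig.make(
    torch.float8_e4m3fn,
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    per_act_token_quant=False,
)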

@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import TritonExperts
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
     BatchedTritonExperts)
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -129,11 +130,9 @@ def make_modular_kernel(
     num_local_experts: int,
     q_dtype: Optional[torch.dtype],
     use_fp8_dispatch: bool,
-    per_act_token_quant: bool,
+    quant_config: FusedMoEQuantConfig,
 ) -> FusedMoEModularKernel:
-    is_quantized = q_dtype is not None
 
     ht_args: Optional[DeepEPHTArgs] = None
     ll_args: Optional[DeepEPLLArgs] = None
@@ -159,24 +158,14 @@ def make_modular_kernel(
     num_dispatchers = pgi.world_size // dp_size
 
     if low_latency_mode:
-        assert not per_act_token_quant, "not supported in ll mode"
+        assert not quant_config.per_act_token_quant, "not supported in ll mode"
         fused_experts = BatchedTritonExperts(
             max_num_tokens=MAX_TOKENS_PER_RANK,
             num_dispatchers=num_dispatchers,
-            use_fp8_w8a8=is_quantized,
-            use_int8_w8a8=False,
-            use_int8_w8a16=False,
-            use_int4_w4a16=False,
-            per_act_token_quant=False,
+            quant_config=quant_config,
         )
     else:
-        fused_experts = TritonExperts(
-            use_fp8_w8a8=is_quantized,
-            use_int8_w8a8=False,
-            use_int8_w8a16=False,
-            use_int4_w4a16=False,
-            per_act_token_quant=per_act_token_quant,
-        )
+        fused_experts = TritonExperts(quant_config=quant_config)
 
     mk = FusedMoEModularKernel(prepare_finalize=a2a,
                                fused_experts=fused_experts)
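Both expert implementations now take the config object in place of the flag bundle. A sketch of the two construction paths, assuming quant_config was built as in the example above; the token cap and dispatcher count are stand-in values for the test's MAX_TOKENS_PER_RANK and world_size // dp_size:

from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts)

# High-throughput path: one keyword argument, no per-flag plumbing.
experts = TritonExperts(quant_config=quant_config)

# Low-latency path: batching knobs stay, quantization moves into the config.
ll_experts = BatchedTritonExperts(
    max_num_tokens=128,  # stand-in for MAX_TOKENS_PER_RANK
    num_dispatchers=2,   # world_size // dp_size in the test
    quant_config=quant_config,
)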
@@ -217,11 +206,6 @@ def deep_ep_moe_impl(
     if is_quantized:
         q_dtype = torch.float8_e4m3fn
 
-    # Make modular kernel
-    mk: FusedMoEModularKernel = make_modular_kernel(
-        pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
-        num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
-
     out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
     total_num_tokens = test_tensors.rank_tokens.size(0)
@@ -236,6 +220,19 @@ def deep_ep_moe_impl(
             rank_token_scales_chunk = rank_token_scales_chunk[
                 chunk_start:chunk_end]
 
+        quant_config = FusedMoEQuantConfig.make(
+            q_dtype,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            per_act_token_quant=per_act_token_quant,
+            a1_scale=rank_token_scales_chunk,
+        )
+
+        # Make modular kernel
+        mk: FusedMoEModularKernel = make_modular_kernel(
+            pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
+            num_local_experts, q_dtype, use_fp8_dispatch, quant_config)
+
         out = mk.forward(hidden_states=rank_tokens_chunk,
                          w1=w1,
                          w2=w2,
@@ -245,12 +242,6 @@ def deep_ep_moe_impl(
                          activation="silu",
                          global_num_experts=num_experts,
                          expert_map=build_expert_map(),
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         w1_zp=None,
-                         w2_zp=None,
-                         a1_scale=rank_token_scales_chunk,
-                         a2_scale=None,
                          apply_router_weight_on_input=False)
 
     if not skip_result_store:
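Taken together, the last three hunks restructure the per-chunk flow: the config is now built inside the chunk loop (so a1_scale can point at the current chunk's token scales), the modular kernel is built from it, and mk.forward() drops its six scale/zero-point kwargs. A condensed sketch of the post-change loop body; names come from the surrounding test, while topk_weights and topk_ids sit in the elided middle of the call and are assumed here:

quant_config = FusedMoEQuantConfig.make(
    q_dtype,
    w1_scale=w1_scale,
    w2_scale=w2_scale,
    per_act_token_quant=per_act_token_quant,
    a1_scale=rank_token_scales_chunk,  # per-chunk activation scales
)
mk = make_modular_kernel(pg, pgi, low_latency_mode, hidden_size, dp_size,
                         num_experts, num_local_experts, q_dtype,
                         use_fp8_dispatch, quant_config)
out = mk.forward(hidden_states=rank_tokens_chunk,
                 w1=w1,
                 w2=w2,
                 topk_weights=topk_weights,  # assumed: elided in the diff
                 topk_ids=topk_ids,          # assumed: elided in the diff
                 activation="silu",
                 global_num_experts=num_experts,
                 expert_map=build_expert_map(),
                 apply_router_weight_on_input=False)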
@@ -407,7 +398,7 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("m,n,k", MNKs)
 @pytest.mark.parametrize("num_experts", [32])
 @pytest.mark.parametrize("topk", [6])
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
@@ -416,7 +407,9 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@requires_deep_ep
def test_deep_ep_moe(
dtype: torch.dtype,
mnk: tuple[int, int, int],
m: int,
n: int,
k: int,
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
@@ -424,7 +417,6 @@ def test_deep_ep_moe(
 ):
     low_latency_mode = False
     use_fp8_dispatch = False
-    m, n, k = mnk
 
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
@@ -456,20 +448,24 @@ USE_FP8_DISPATCH = [True, False]
 @pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("mnk", MNKs)
+@pytest.mark.parametrize("m,n,k", MNKs)
 @pytest.mark.parametrize("num_experts", [32])
 @pytest.mark.parametrize("topk", [6])
 @pytest.mark.parametrize("world_dp_size", [(2, 1)])
 @pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH)
 @multi_gpu_test(num_gpus=2)
 @requires_deep_ep
-def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
-                                 num_experts: int, topk: int,
-                                 world_dp_size: tuple[int, int],
-                                 use_fp8_dispatch: bool):
+def test_low_latency_deep_ep_moe(
+    dtype: torch.dtype,
+    m: int,
+    n: int,
+    k: int,
+    num_experts: int,
+    topk: int,
+    world_dp_size: tuple[int, int],
+    use_fp8_dispatch: bool,
+):
 
     low_latency_mode = True
-    m, n, k = mnk
 
     if (low_latency_mode
             and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES):
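Unrelated to quantization but part of the same cleanup: the tests switch from a packed mnk tuple parameter to pytest's comma-separated-names form, which unpacks each tuple in MNKs into separate m, n, k arguments and makes the manual m, n, k = mnk line unnecessary. A standalone sketch of the idiom with made-up sizes standing in for the test file's MNKs:

import pytest

# Hypothetical shape list; the real MNKs values live in the test file.
MNKs = [(8, 128, 2048), (64, 1024, 7168)]

@pytest.mark.parametrize("m,n,k", MNKs)
def test_shapes(m: int, n: int, k: int):
    # pytest splits "m,n,k" on commas and unpacks each tuple into the three
    # arguments, producing readable test ids like "8-128-2048".
    assert m > 0 and n > 0 and k > 0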