[MoE Refactor] Mxfp4 oracle rebased (#37128)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: Yongye Zhu
Date: 2026-03-20 22:37:04 -05:00
Committed by: GitHub
Parent: c7f98b4d0a
Commit: 87bd91892f
18 changed files with 1707 additions and 1381 deletions


@@ -6,6 +6,7 @@ import pytest
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels():
@@ -14,6 +15,7 @@ if not has_triton_kernels():
allow_module_level=True,
)
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
from triton_kernels.numerics import InFlexData
@@ -303,6 +305,12 @@ def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
pc2,
) = init_compute_data(M, K, N, E, a_dtype, w_dtype, num_warps=8)
if current_platform.is_device_capability_family(100):
constraints = {
"is_persistent": True,
}
opt_flags.update_opt_flags_constraints(constraints)
if a_dtype == "bf16" and w_dtype == "mx4":
quant_config = mxfp4_w4a16_moe_quant_config(
w1_scale=pc1,

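The hunk above gates the triton_kernels optimization flags on the device's compute-capability family before building the MXFP4 quant config. As a rough, self-contained sketch of that pattern, using only the calls shown in the diff (the surrounding test setup is omitted):

from vllm.platforms import current_platform
import triton_kernels.matmul_ogs_details.opt_flags as opt_flags

# On Blackwell-class GPUs (compute capability family 10.x), pin the opt-flag
# search to persistent matmul kernels; other platforms keep the defaults.
if current_platform.is_device_capability_family(100):
    opt_flags.update_opt_flags_constraints({"is_persistent": True})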

@@ -82,7 +82,7 @@ def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase):
model_case.model_id,
tensor_parallel_size=model_case.tp,
load_format="dummy",
cudagraph_capture_sizes=[16],
compilation_config={"cudagraph_capture_sizes": [16]},
) as llm:
# Disabled as check_model is broken: https://github.com/vllm-project/vllm/pull/18465#issuecomment-3329880562
# def check_model(model):

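This hunk moves the CUDA-graph capture sizes from a top-level vllm_runner keyword argument into compilation_config, matching the current configuration surface. A hedged sketch of the equivalent standalone usage, assuming vllm.LLM accepts compilation_config as a dict; the model id below is an illustrative placeholder, not from the diff:

from vllm import LLM

llm = LLM(
    model="some-org/some-mxfp4-model",  # illustrative placeholder, not from the diff
    tensor_parallel_size=1,
    load_format="dummy",                # skip real weight loading, as in the test
    compilation_config={"cudagraph_capture_sizes": [16]},
)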

@@ -17,89 +17,6 @@ from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend,
Mxfp4MoEMethod,
)
def _make_mock_moe_config(ep_size: int = 1) -> MagicMock:
"""Create a mock FusedMoEConfig with the given EP size."""
parallel_config = MagicMock()
parallel_config.ep_size = ep_size
moe_config = MagicMock()
moe_config.ep_size = ep_size
moe_config.is_lora_enabled = False
moe_config.moe_parallel_config = parallel_config
return moe_config
class TestMxfp4TritonIsMonolithic:
"""Verify that is_monolithic is always True for the TRITON backend,
regardless of EP size, since triton_kernel_moe_forward now handles
expert_map remapping internally."""
@pytest.mark.parametrize(
"backend,ep_size,expected_monolithic",
[
# TRITON is always monolithic (handles EP via expert_map remapping)
(Mxfp4Backend.TRITON, 1, True),
(Mxfp4Backend.TRITON, 2, True),
(Mxfp4Backend.TRITON, 4, True),
# SM100 backends are always monolithic
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 1, True),
(Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM, 2, True),
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 1, True),
(Mxfp4Backend.SM100_FI_MXFP4_BF16, 2, True),
# MARLIN is never monolithic
(Mxfp4Backend.MARLIN, 1, False),
(Mxfp4Backend.MARLIN, 2, False),
],
ids=[
"triton-no-ep",
"triton-ep2",
"triton-ep4",
"sm100-trtllm-no-ep",
"sm100-trtllm-ep2",
"sm100-bf16-no-ep",
"sm100-bf16-ep2",
"marlin-no-ep",
"marlin-ep2",
],
)
@patch(
"vllm.model_executor.layers.quantization.mxfp4.get_mxfp4_backend",
)
@patch(
"vllm.model_executor.layers.quantization.mxfp4.get_current_vllm_config",
)
def test_is_monolithic(
self,
mock_get_config,
mock_get_backend,
backend,
ep_size,
expected_monolithic,
):
"""is_monolithic should be True for TRITON regardless of EP size."""
mock_get_backend.return_value = backend
mock_compilation_config = MagicMock()
mock_compilation_config.max_cudagraph_capture_size = 1024
mock_vllm_config = MagicMock()
mock_vllm_config.compilation_config = mock_compilation_config
mock_get_config.return_value = mock_vllm_config
moe_config = _make_mock_moe_config(ep_size=ep_size)
method = Mxfp4MoEMethod(moe_config)
assert method.is_monolithic == expected_monolithic, (
f"Expected is_monolithic={expected_monolithic} for "
f"backend={backend.name}, ep_size={ep_size}, "
f"but got {method.is_monolithic}."
)
class TestTritonMoeForwardExpertMap:
"""Test that triton_kernel_moe_forward applies expert_map remapping