[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-06-11 12:53:10 -04:00
committed by GitHub
parent b2d9be6f7d
commit 29fa5cac1c
15 changed files with 458 additions and 396 deletions

View File

@@ -15,7 +15,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -76,6 +77,13 @@ def test_fused_moe(
else:
e_map = None
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
@@ -103,7 +111,20 @@ def test_fused_moe(
expert_map=e_map,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_triton_output = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(m_triton_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,