[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)

commit 29fa5cac1c
parent b2d9be6f7d
Author: bnellnm
Date: 2025-06-11 12:53:10 -04:00
Committed by: GitHub
Signed-off-by: Bill Nell <bnell@redhat.com>

15 changed files with 458 additions and 396 deletions
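The core change, per the commit title, is that FusedMoEModularKernel now splits large activation batches into fixed-size chunks before dispatching to the expert kernels, bounding intermediate workspace memory instead of letting it scale with the token count. A minimal sketch of the idea, assuming a hypothetical per-chunk token cap and an opaque experts_fn; this is illustrative, not the actual kernel code (vLLM exposes a VLLM_FUSED_MOE_CHUNK_SIZE environment variable for this kind of cap):

# Illustrative sketch only -- not the real FusedMoEModularKernel internals.
import torch

CHUNK_SIZE = 32 * 1024  # assumed per-chunk token cap

def apply_chunked(a: torch.Tensor, topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor, experts_fn) -> torch.Tensor:
    M = a.shape[0]
    if M <= CHUNK_SIZE:
        # Small batches take the old single-shot path.
        return experts_fn(a, topk_weights, topk_ids)
    out = torch.empty_like(a)
    for start in range(0, M, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, M)
        # Each chunk reuses the same bounded workspace inside experts_fn.
        out[start:end] = experts_fn(a[start:end],
                                    topk_weights[start:end],
                                    topk_ids[start:end])
    return out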


@@ -13,7 +13,8 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
     _valid_deep_gemm_shape, deep_gemm_moe_fp8)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.fused_moe import (
+    fused_topk, modular_triton_fused_moe)
 from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
     moe_align_block_size)
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -45,7 +46,7 @@ N = [128, 512, 7168, 7748, 13824]
 K = [256, 3884, 4096, 13824, 16384]
 # Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
 # and its hidden size is 7168.
-M_moe = [1, 2, 7, 83, 128, 2048]
+M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
 M_moe_dg = [128, 192, 1335, 2048]
 N_moe = [128, 256, 1024, 4608]  # [13824]
 K_moe = [256, 512, 7168]  # [13824]
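The new M_moe entry is presumably there to exercise the chunked path this commit adds: 1024 * 128 = 131072 tokens exceeds any plausible per-chunk cap, so the modular kernel has to take the multi-chunk branch. For example, with an assumed 32k-token chunk size:

M = 1024 * 128                     # 131072 tokens in the new test case
CHUNK_SIZE = 32 * 1024             # assumed cap (see sketch above)
num_chunks = -(-M // CHUNK_SIZE)   # ceiling division -> 4 chunks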
@@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     score = torch.randn((M, E), dtype=dtype)
+    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
+                                           use_int8_w8a8=False,
+                                           use_int8_w8a16=False,
+                                           use_int4_w4a16=False,
+                                           per_channel_quant=False,
+                                           block_shape=block_size)
+
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
         out = fused_moe(
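modular_triton_fused_moe builds a reusable kernel object up front, specialized by the quantization flags (here fp8 W8A8 with block-wise scales via block_shape), rather than re-selecting a code path on every call. Assuming the flags shown in the diff are the full set the helper accepts, the unquantized variant would presumably be built by clearing them all; a sketch:

# Hedged sketch: construction for the unquantized path, assuming the
# same signature as in the diff above.
from vllm.model_executor.layers.fused_moe.fused_moe import (
    modular_triton_fused_moe)

m_moe_unquant = modular_triton_fused_moe(use_fp8_w8a8=False,
                                         use_int8_w8a8=False,
                                         use_int8_w8a16=False,
                                         use_int4_w4a16=False,
                                         per_channel_quant=False,
                                         block_shape=None)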
@@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                        block_size)
+    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
+    m_out = m_fused_moe(a,
+                        w1,
+                        w2,
+                        topk_weights,
+                        topk_ids,
+                        global_num_experts=E,
+                        w1_scale=w1_s,
+                        w2_scale=w2_s)
+
     #print(f"{out.sum()=}")
     #print(f"{ref_out.sum()=}")
@@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
                 torch.mean(torch.abs(ref_out.to(torch.float32))))
     assert rel_diff < 0.03
+
+    rel_diff = (torch.mean(
+        torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
+                torch.mean(torch.abs(ref_out.to(torch.float32))))
+    assert rel_diff < 0.03
 
 
 def per_block_cast_to_fp8(
     x: torch.Tensor,
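The new assertion applies the test's existing tolerance metric, a mean relative L1 error, to the modular kernel's output against the same reference. An equivalent standalone helper, as an illustrative restatement of the inline expression:

# Mean absolute error normalized by the reference magnitude, matching
# the inline rel_diff expression in the test.
def mean_rel_err(x: torch.Tensor, ref: torch.Tensor) -> float:
    x32, ref32 = x.to(torch.float32), ref.to(torch.float32)
    return (torch.mean(torch.abs(x32 - ref32)) /
            torch.mean(torch.abs(ref32))).item()

assert mean_rel_err(m_out, ref_out) < 0.03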