[Feat] Support non-gated MoE with Marlin, NVFP4 CUTLASS, FP8, INT8, compressed-tensors (#32257)

Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@ipp1-1429.ipp1a1.colossus.nvidia.com>
Authored by TomerBN-Nvidia on 2026-01-16 02:15:05 +02:00, committed by GitHub
parent aca5c51487
commit c277fbdf31
17 changed files with 226 additions and 127 deletions


@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m):
    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
    """Test Marlin MoE with non-gated activation (relu2_no_mul).

    Non-gated activations like relu2 don't have the gate-up projection
    pattern, so w1 has shape (e, n, k) instead of (e, 2*n, k).
    """
    torch.cuda.manual_seed(42)
    group_size = 16  # NVFP4 group size
    is_k_full = True
    quant_type = scalar_types.float4_e2m1f
    dtype = torch.bfloat16

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    # Non-gated: w1 shape is (e, n, k) not (e, 2*n, k)
    w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w1_data = MarlinMoEWeightData.make(
        w=w1,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )
    w2_data = MarlinMoEWeightData.make(
        w=w2,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a,
            w1_data.w_ref,
            w2_data.w_ref,
            score,
            topk,
            activation="relu2",
        )
        marlin_output = fused_marlin_moe(
            a,
            w1_data.qweight,
            w2_data.qweight,
            None,  # bias1
            None,  # bias2
            w1_data.scales,
            w2_data.scales,
            score,
            topk_weights,
            topk_ids,
            global_num_experts=e,
            expert_map=None,
            global_scale1=w1_data.global_scale,
            global_scale2=w2_data.global_scale,
            g_idx1=w1_data.g_idx,
            g_idx2=w2_data.g_idx,
            sort_indices1=w1_data.sort_indices,
            sort_indices2=w2_data.sort_indices,
            w1_zeros=w1_data.zeros,
            w2_zeros=w2_data.zeros,
            quant_type_id=quant_type.id,
            is_k_full=is_k_full,
            activation="relu2_no_mul",
        )

    torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
@pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size):
num_experts = 4
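
For context on the activation naming used above, here is a minimal sketch of the per-expert forward pass for the two conventions the test compares. This is plain PyTorch, not vLLM's fused kernels, and the helper names are hypothetical; it only illustrates why the non-gated path takes w1 of shape (e, n, k) rather than (e, 2*n, k):

    import torch
    import torch.nn.functional as F

    def expert_ffn_gated(x, w1, w2):
        # Gated pattern (e.g. silu_and_mul): w1 fuses the gate and up
        # projections, shape (2*n, k); split, activate, multiply.
        gate, up = (x @ w1.t()).chunk(2, dim=-1)  # each (m, n)
        return (F.silu(gate) * up) @ w2.t()       # (m, k)

    def expert_ffn_relu2(x, w1, w2):
        # Non-gated relu2 (squared ReLU): a single projection of shape
        # (n, k); ReLU then square, with no gate half to multiply
        # against (the "no_mul" in the fused kernel's activation name).
        h = torch.relu(x @ w1.t())                # (m, n)
        return (h * h) @ w2.t()                   # (m, k)

    m, n, k = 4, 8, 16
    x = torch.randn(m, k)
    w1_gated = torch.randn(2 * n, k)  # gated: fused gate+up projection
    w1_plain = torch.randn(n, k)      # non-gated: single projection
    w2 = torch.randn(k, n)

    assert expert_ffn_gated(x, w1_gated, w2).shape == (m, k)
    assert expert_ffn_relu2(x, w1_plain, w2).shape == (m, k)

This also explains why the reference path calls torch_moe(..., activation="relu2") while the fused path passes activation="relu2_no_mul": both apply the same squared-ReLU nonlinearity, and the "no_mul" suffix tells the kernel to skip the gate multiply it would otherwise perform.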