[Feat] Support non-gated MoE with Marlin, NVFP4 CUTLASS, FP8, INT8, compressed-tensors (#32257)
Signed-off-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@computelab-frontend-8.nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Tomer Natan <tbarnatan@ipp1-1429.ipp1a1.colossus.nvidia.com>
@@ -1079,6 +1079,86 @@ def test_fused_marlin_moe_with_bias(m):
    torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0)


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
    """Test Marlin MoE with non-gated activation (relu2_no_mul).

    Non-gated activations like relu2 don't have the gate-up projection pattern,
    so w1 has shape (e, n, k) instead of (e, 2*n, k).
    """
    torch.cuda.manual_seed(42)

    group_size = 16  # NVFP4 group size
    is_k_full = True
    quant_type = scalar_types.float4_e2m1f
    dtype = torch.bfloat16

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    # Non-gated: w1 shape is (e, n, k) not (e, 2*n, k)
    w1 = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10

    w1_data = MarlinMoEWeightData.make(
        w=w1,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )

    w2_data = MarlinMoEWeightData.make(
        w=w2,
        quant_type=quant_type,
        group_size=group_size,
        act_order=False,
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)

    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(
            a,
            w1_data.w_ref,
            w2_data.w_ref,
            score,
            topk,
            activation="relu2",
        )

    marlin_output = fused_marlin_moe(
        a,
        w1_data.qweight,
        w2_data.qweight,
        None,  # bias1
        None,  # bias2
        w1_data.scales,
        w2_data.scales,
        score,
        topk_weights,
        topk_ids,
        global_num_experts=e,
        expert_map=None,
        global_scale1=w1_data.global_scale,
        global_scale2=w2_data.global_scale,
        g_idx1=w1_data.g_idx,
        g_idx2=w2_data.g_idx,
        sort_indices1=w1_data.sort_indices,
        sort_indices2=w2_data.sort_indices,
        w1_zeros=w1_data.zeros,
        w2_zeros=w2_data.zeros,
        quant_type_id=quant_type.id,
        is_k_full=is_k_full,
        activation="relu2_no_mul",
    )

    torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)


@pytest.mark.parametrize("ep_size", [1, 2])
def test_moe_align_block_size_opcheck(ep_size):
    num_experts = 4
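For context on the gated vs. non-gated distinction the new test's docstring describes: in a gated expert MLP, w1 stacks a gate projection and an up projection (hence the 2*n rows), while a non-gated activation such as relu2 uses a single projection with no elementwise multiply. The standalone sketch below is illustrative only and not part of this commit; the helper names and the relu2 formulation (ReLU followed by squaring) are assumptions made for the example.

# Illustrative sketch (not part of this diff): contrasts per-expert weight
# shapes for a gated MLP (SiLU-and-mul) vs. a non-gated MLP (relu2).
import torch


def gated_expert(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Gated: w1 stacks gate and up projections, so w1 has shape (2*n, k).
    gate_up = x @ w1.t()                 # (m, 2*n)
    gate, up = gate_up.chunk(2, dim=-1)  # each (m, n)
    return (torch.nn.functional.silu(gate) * up) @ w2.t()  # w2: (k, n) -> (m, k)


def non_gated_expert(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # Non-gated: single projection, so w1 has shape (n, k); relu2 here is
    # assumed to mean ReLU followed by squaring, applied without a gate.
    h = x @ w1.t()                        # (m, n)
    return torch.relu(h).pow(2) @ w2.t()  # w2: (k, n) -> (m, k)


m, n, k = 4, 8, 16
x = torch.randn(m, k)
out_gated = gated_expert(x, torch.randn(2 * n, k), torch.randn(k, n))
out_non_gated = non_gated_expert(x, torch.randn(n, k), torch.randn(k, n))
assert out_gated.shape == out_non_gated.shape == (m, k)

Extended per expert, this is why the non-gated path quantizes a w1 of shape (e, n, k) rather than (e, 2*n, k), as exercised by the test above.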