[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)

This commit is contained in:
Omer Ullman Argov
2025-11-30 18:02:40 +02:00
committed by GitHub
parent cd719de5cb
commit 39d28108f4
5 changed files with 98 additions and 22 deletions

View File

@@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
is_valid_flashinfer_cutlass_fused_moe,
)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
create_flashinfer_prepare_finalize,
)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -48,9 +48,10 @@ MNK_FACTORS = [
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
):
current_platform.seed_everything(7)
with set_current_vllm_config(
@@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
is_gated_act = activation == "silu_and_mul"
w1_q, w2_q, quant_config = make_test_quant_config(
e,
@@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
quant_dtype="nvfp4",
block_shape=None,
per_act_token_quant=False,
make_gate=is_gated_act,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
@@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
flashinfer_experts = FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(),
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=fi_activation,
)
# Reference check:
@@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size=quant_blocksize,
)
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
w1_d = torch.empty(
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
)
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
for idx in range(0, e):
@@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
block_size=quant_blocksize,
)
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
torch_output = torch_moe(
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
)
torch.testing.assert_close(
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1