[Feat] Support non-gated activations in NVFP4 modelopt path (#29004)
This commit is contained in:
committed by
GitHub
parent
cd719de5cb
commit
39d28108f4
@@ -16,11 +16,11 @@ from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
|
||||
FlashInferExperts,
|
||||
is_valid_flashinfer_cutlass_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
|
||||
create_flashinfer_prepare_finalize,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
@@ -48,9 +48,10 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
@@ -59,6 +60,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
is_gated_act = activation == "silu_and_mul"
|
||||
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
@@ -68,6 +70,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
quant_dtype="nvfp4",
|
||||
block_shape=None,
|
||||
per_act_token_quant=False,
|
||||
make_gate=is_gated_act,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
@@ -76,16 +79,19 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
|
||||
|
||||
flashinfer_experts = FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
create_flashinfer_prepare_finalize(use_dp=False, use_nvfp4=True),
|
||||
FlashInferExperts(out_dtype=dtype, quant_config=quant_config),
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=fi_activation,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
@@ -103,7 +109,9 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||
w1_d = torch.empty(
|
||||
(e, (2 if is_gated_act else 1) * n, k), device="cuda", dtype=dtype
|
||||
)
|
||||
w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
|
||||
|
||||
for idx in range(0, e):
|
||||
@@ -124,7 +132,9 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
block_size=quant_blocksize,
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
|
||||
torch_output = torch_moe(
|
||||
a_in_dtype, w1_d, w2_d, score, topk, activation=activation
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
torch_output, flashinfer_output, atol=1e-1, rtol=1e-1
|
||||
|
||||
@@ -264,13 +264,20 @@ def make_test_weights(
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
block_shape: list[int] | None = None,
|
||||
per_out_ch_quant: bool = False,
|
||||
make_gate: bool = True,
|
||||
) -> tuple[
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None],
|
||||
]:
|
||||
return (
|
||||
make_test_weight(
|
||||
e, 2 * n, k, in_dtype, quant_dtype, block_shape, per_out_ch_quant
|
||||
e,
|
||||
(2 if make_gate else 1) * n,
|
||||
k,
|
||||
in_dtype,
|
||||
quant_dtype,
|
||||
block_shape,
|
||||
per_out_ch_quant,
|
||||
),
|
||||
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape, per_out_ch_quant),
|
||||
)
|
||||
@@ -297,6 +304,7 @@ def make_test_quant_config(
|
||||
quant_dtype: torch.dtype | str | None = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: list[int] | None = None,
|
||||
make_gate: bool = True,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, FusedMoEQuantConfig]:
|
||||
(_, w1, w1_s, w1_gs), (_, w2, w2_s, w2_gs) = make_test_weights(
|
||||
e,
|
||||
@@ -306,6 +314,7 @@ def make_test_quant_config(
|
||||
quant_dtype,
|
||||
per_out_ch_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
make_gate=make_gate,
|
||||
)
|
||||
|
||||
# Hacky/trivial scales for nvfp4.
|
||||
|
||||
@@ -14,6 +14,7 @@ from torch._prims_common import TensorLikeType
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
|
||||
from vllm.utils.torch_utils import make_tensor_with_pad
|
||||
@@ -839,6 +840,7 @@ def torch_experts(
|
||||
per_act_token_quant=False,
|
||||
block_shape: list[int] | None = None,
|
||||
apply_router_weights_on_input: bool = False,
|
||||
activation: str = "silu_and_mul",
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
global_num_experts == -1
|
||||
@@ -881,6 +883,8 @@ def torch_experts(
|
||||
|
||||
f32 = torch.float32
|
||||
|
||||
act = CustomOp.op_registry[activation]
|
||||
|
||||
for i in range(num_experts):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
@@ -888,7 +892,7 @@ def torch_experts(
|
||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||
if b_bias1 is not None:
|
||||
tmp1 = tmp1 + b_bias1[i].view(1, -1).to(tmp1.dtype)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2 = act()(tmp1)
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
if b_bias2 is not None:
|
||||
out[mask] = out[mask] + b_bias2[i].view(1, -1).to(tmp1.dtype)
|
||||
@@ -969,6 +973,7 @@ def torch_moe(
|
||||
b_bias2: torch.Tensor | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
activation: str = "silu_and_mul",
|
||||
) -> torch.Tensor:
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
@@ -982,6 +987,7 @@ def torch_moe(
|
||||
b_bias1,
|
||||
b_bias2,
|
||||
expert_map,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user