[Model]: Fused MoE for nomic-embed-text-v2-moe (#18321)

Signed-off-by: isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author:    Isotr0py
Date:      2025-07-29 18:13:27 +08:00
Committer: GitHub
Parent:    a2480251ec
Commit:    a4528f0cac
Changes:   2 changed files with 140 additions and 111 deletions

vllm/model_executor/layers/fused_moe/fused_moe.py

@@ -7,6 +7,7 @@ import os
 from typing import Any, Callable, Optional

 import torch
+import torch.nn.functional as F

 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -1001,6 +1002,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_weights: torch.Tensor,
                           topk_ids: torch.Tensor,
                           activation: str = "silu",
+                          is_act_and_mul: bool = True,
                           apply_router_weight_on_input: bool = False,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a8: bool = False,
@@ -1018,7 +1020,8 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[list[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       activation, apply_router_weight_on_input, use_fp8_w8a8,
+                       activation, is_act_and_mul,
+                       apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        use_mxfp4_w4a4, per_channel_quant, global_num_experts,
                        expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
@@ -1032,6 +1035,7 @@ def inplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1167,6 +1171,7 @@ def outplace_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1183,13 +1188,12 @@ def outplace_fused_experts(
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[list[int]] = None) -> torch.Tensor:
-    return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, activation, apply_router_weight_on_input,
-                              use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
-                              use_int4_w4a16, use_mxfp4_w4a4,
-                              per_channel_quant, global_num_experts,
-                              expert_map, w1_scale, w2_scale, w1_zp, w2_zp,
-                              a1_scale, a2_scale, block_shape)
+    return fused_experts_impl(
+        hidden_states, w1, w2, topk_weights, topk_ids, False, activation,
+        is_act_and_mul, apply_router_weight_on_input, use_fp8_w8a8,
+        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, use_mxfp4_w4a4,
+        per_channel_quant, global_num_experts, expert_map, w1_scale, w2_scale,
+        w1_zp, w2_zp, a1_scale, a2_scale, block_shape)


 def outplace_fused_experts_fake(
@@ -1199,6 +1203,7 @@ def outplace_fused_experts_fake(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
@@ -1253,6 +1258,7 @@ def fused_experts(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1283,6 +1289,8 @@ def fused_experts(
                             or is_blackwell_deep_gemm_used())
     if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
         assert apply_router_weight_on_input is False
+        assert is_act_and_mul, (
+            "DeepGemm only supports is_act_and_mul=True for now.")
         return deep_gemm_moe_fp8(
             hidden_states=hidden_states,
             w1=w1,
@@ -1319,6 +1327,7 @@ def fused_experts(
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
+           is_act_and_mul=is_act_and_mul,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a8=use_int8_w8a8,
@@ -1345,6 +1354,7 @@ def fused_experts_impl(
         topk_ids: torch.Tensor,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         apply_router_weight_on_input: bool = False,
         use_fp8_w8a8: bool = False,
         use_int8_w8a8: bool = False,
@@ -1503,14 +1513,21 @@ def fused_experts_impl(
             per_channel_quant=per_channel_quant,
             block_shape=block_shape)

-        if activation == "silu":
+        # Activation function with multiplication
+        if activation == "silu" and is_act_and_mul:
             torch.ops._C.silu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
-        elif activation == "gelu":
+        elif activation == "gelu" and is_act_and_mul:
             torch.ops._C.gelu_and_mul(intermediate_cache2,
                                       intermediate_cache1.view(-1, N))
+        # Activation function without multiplication
+        elif activation == "silu":
+            intermediate_cache2 = F.silu(intermediate_cache1.view(-1, N))
+        elif activation == "gelu":
+            intermediate_cache2 = F.gelu(intermediate_cache1.view(-1, N))
         else:
-            raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+            raise ValueError(f"Unsupported FusedMoe activation: {activation}, "
+                             f"with is_act_and_mul={is_act_and_mul}.")

         qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
             A=intermediate_cache2,
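
For reference, a minimal PyTorch sketch of the two activation paths this hunk dispatches between (the reference function names below are illustrative, not part of the commit): with is_act_and_mul=True the kernel applies a self-gated activation to a 2*d-wide up-projection and halves the width, matching what torch.ops._C.silu_and_mul computes; with is_act_and_mul=False it applies the plain activation element-wise and preserves the width.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # Gated path (is_act_and_mul=True): SiLU(gate) * up, 2*d columns -> d.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def silu_ref(x: torch.Tensor) -> torch.Tensor:
    # Ungated path (is_act_and_mul=False): element-wise SiLU, width unchanged.
    return F.silu(x)

x = torch.randn(4, 8)
assert silu_and_mul_ref(x).shape == (4, 4)
assert silu_ref(x).shape == (4, 8)

This width difference is also why the ungated branch reassigns intermediate_cache2 to a fresh tensor instead of writing into the preallocated buffer, which is sized for the halved gated output.
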
@@ -1555,6 +1572,7 @@ def fused_moe(
         renormalize: bool,
         inplace: bool = False,
         activation: str = "silu",
+        is_act_and_mul: bool = True,
         use_grouped_topk: bool = False,
         num_expert_group: Optional[int] = None,
         topk_group: Optional[int] = None,
@@ -1591,6 +1609,9 @@ def fused_moe(
         Defaults to False.
     - activation (str): The activation function to apply after the first
         MoE layer.
+    - is_act_and_mul (bool): If True, apply the gated activation-and-mul
+        function (self-gated activation); otherwise apply the plain
+        activation function (ungated activation).
     - num_expert_group: Optional[int]: additional parameter for grouped_topk
     - topk_group: Optional[int]: additional parameter for grouped_topk
     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
@@ -1627,6 +1648,9 @@ def fused_moe(
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
+    if not is_act_and_mul:
+        assert inplace is False, (
+            "is_act_and_mul=False is not supported with inplace=True")

     if use_grouped_topk:
         assert num_expert_group is not None and topk_group is not None
@@ -1647,6 +1671,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
         activation=activation,
+        is_act_and_mul=is_act_and_mul,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
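
A usage sketch of the updated fused_moe entry point (shapes, dtypes, and the GELU choice are illustrative assumptions, not taken from the commit): for an ungated expert MLP such as the one this commit targets, w1 projects to the full intermediate width N rather than 2*N, since no gate/up split precedes the activation, and inplace must remain False per the assert added above.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

# Illustrative sizes (assumed): M tokens, K hidden, N intermediate, E experts.
M, K, N, E = 16, 64, 128, 8
hidden_states = torch.randn(M, K, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, N, K, dtype=torch.float16, device="cuda")  # full N: no gate/up split
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

out = fused_moe(
    hidden_states, w1, w2, gating_output,
    topk=2,
    renormalize=True,
    inplace=False,         # inplace=True is rejected when is_act_and_mul=False
    activation="gelu",
    is_act_and_mul=False,  # ungated GELU, applied via F.gelu on the fused path
)
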