[Performance] Support FP8 flashinfer TRTLLM MOE on Qwen3 and Qwen-3next (#27492)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
jiahanc
2025-11-10 09:34:57 -08:00
committed by GitHub
parent b039bfda8f
commit 34553b9d27
7 changed files with 78 additions and 30 deletions

View File

@@ -43,6 +43,7 @@ from vllm.distributed import (
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
@@ -171,6 +172,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
self.gate = ReplicatedLinear(

View File

@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
)
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.layernorm import (
GemmaRMSNorm as Qwen3NextRMSNorm,
)
@@ -173,6 +174,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel,
routing_method_type=RoutingMethodType.Renormalize,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: