diff --git a/tests/kernels/moe/test_flashinfer.py b/tests/kernels/moe/test_flashinfer.py index 6a51853c0..ce3a1fcea 100644 --- a/tests/kernels/moe/test_flashinfer.py +++ b/tests/kernels/moe/test_flashinfer.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( - TrtLlmFp8Experts, + TrtLlmFp8ExpertsMonolithic, ) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, @@ -247,7 +247,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph( allow_new_interface=True, use_monolithic=True, ), - TrtLlmFp8Experts( + TrtLlmFp8ExpertsMonolithic( moe_config=td.layer.moe, quant_config=quant_config, ), diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 1ed76f892..1c86702e9 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -11,6 +12,9 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, RoutingMethodType, ) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( activation_to_flashinfer_int, ) @@ -22,10 +26,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform +logger = init_logger(__name__) -class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): + +class TrtLlmFp8ExpertsBase: """ - Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface. + Fp8 TRTLLM-Gen MoE kernels. Shared base for modular and monolithic + interfaces. """ def __init__( @@ -33,8 +40,6 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig, ): - super().__init__(moe_config, quant_config) - self.routing_method_type = moe_config.routing_method self.topk = moe_config.experts_per_token self.intermediate_size_per_partition = ( @@ -44,24 +49,7 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): self.local_num_experts = moe_config.num_local_experts self.ep_rank = moe_config.moe_parallel_config.ep_rank - # Make additional scales for per-tensor interface. - if self.quant_config.is_per_tensor: - w1_scale = self.quant_config.w1_scale - assert w1_scale is not None - a1_scale = self.quant_config.a1_scale - assert a1_scale is not None - w2_scale = self.quant_config.w2_scale - assert w2_scale is not None - a2_scale = self.quant_config.a2_scale - assert a2_scale is not None - - self._g1_alphas = (w1_scale * a1_scale).squeeze() - self._g2_alphas = (w2_scale * a2_scale).squeeze() - self._g1_scale_c = ( - self._g1_alphas / self.quant_config.a2_scale - if moe_config.is_act_and_mul - else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale - ) + self.quant_config = quant_config @staticmethod def activation_format() -> mk.FusedMoEActivationFormat: @@ -79,50 +67,11 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" return True - @staticmethod - def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - """Supports Fp8 per-tensor and Fp8 block.""" - SUPPORTED_W_A = [ - (kFp8Static128BlockSym, kFp8Dynamic128Sym), - (kFp8StaticTensorSym, kFp8StaticTensorSym), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - @staticmethod def _supports_activation(activation: MoEActivation) -> bool: """Supports only SiLU and RELU^2 non-gated activation.""" return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - @staticmethod - def _supports_routing_method( - routing_method: RoutingMethodType, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - ) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(dbari): TopK routing could also be enabled, but need to validate models - # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): - # NOTE(dbari): as above, potentially allow others here. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Llama4, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - else: - raise ValueError("Unsupported quantization scheme.") - @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """Monolithic kernel so only use with naive DP/EP and TP.""" @@ -153,6 +102,178 @@ class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): def supports_expert_map(self) -> bool: return False + +class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsModular): + """ + Fp8 TRTLLM-Gen MoE kernels. Supports modular interface. + """ + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + output = (M, K) + + return (workspace1, workspace2, output) + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + import flashinfer + + # Pack topk_ids and topk_weights into single tensor + # Format: (expert_id << 16) | (weight_bf16.view(int16)) + packed_topk_ids = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view( + torch.int16 + ) + + # trtllm_fp8_block_scale_routed_moe does not support autotuning + # so skip this kernel during dummy run for autotuning. + import vllm.utils.flashinfer as fi_utils + + if fi_utils._is_fi_autotuning: + return + + assert a1q_scale is not None + + # `trtllm_fp8_block_scale_routed_moe` has a bug and does not write to the + # output tensor in-place so we need to manually copy the result to the + # output tensor + # https://github.com/flashinfer-ai/flashinfer/issues/2703 + result = flashinfer.fused_moe.trtllm_fp8_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.t().contiguous(), # type: ignore[union-attr] + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale, + num_experts=global_num_experts, + top_k=self.topk, + n_group=None, + topk_group=None, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=None, + routing_method_type=1, + use_shuffled_weight=False, + weight_layout=0, + # output=output, + ) + output.copy_(result) + + +class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase, mk.FusedMoEExpertsMonolithic): + """ + Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + + # Make additional scales for per-tensor interface. + if self.quant_config.is_per_tensor: + w1_scale = self.quant_config.w1_scale + assert w1_scale is not None + a1_scale = self.quant_config.a1_scale + assert a1_scale is not None + w2_scale = self.quant_config.w2_scale + assert w2_scale is not None + a2_scale = self.quant_config.a2_scale + assert a2_scale is not None + + self._g1_alphas = (w1_scale * a1_scale).squeeze() + self._g2_alphas = (w2_scale * a2_scale).squeeze() + self._g1_scale_c = ( + self._g1_alphas / self.quant_config.a2_scale + if moe_config.is_act_and_mul + else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale + ) + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(dbari): TopK routing could also be enabled, but need to validate models + # NOTE(dbari): Default is not implemented and should not be enabled until it is + if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + # NOTE(rob): potentially allow others here. This is a conservative list. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(dbari): as above, potentially allow others here. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + else: + raise ValueError("Unsupported quantization scheme.") + def _apply_per_block( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 85997468a..48ca03f66 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -104,83 +104,84 @@ def _get_priority_backends( def backend_to_kernel_cls( backend: Fp8MoeBackend, -) -> type[mk.FusedMoEExperts]: +) -> list[type[mk.FusedMoEExperts]]: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501 - TrtLlmFp8Experts, + TrtLlmFp8ExpertsModular, + TrtLlmFp8ExpertsMonolithic, ) - return TrtLlmFp8Experts + return [TrtLlmFp8ExpertsMonolithic, TrtLlmFp8ExpertsModular] elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) - return FlashInferExperts + return [FlashInferExperts] elif backend == Fp8MoeBackend.DEEPGEMM: from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, ) - return TritonOrDeepGemmExperts + return [TritonOrDeepGemmExperts] elif backend == Fp8MoeBackend.BATCHED_DEEPGEMM: from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( BatchedDeepGemmExperts, ) - return BatchedDeepGemmExperts + return [BatchedDeepGemmExperts] elif backend == Fp8MoeBackend.MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( MarlinExperts, ) - return MarlinExperts + return [MarlinExperts] elif backend == Fp8MoeBackend.TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( TritonExperts, ) - return TritonExperts + return [TritonExperts] elif backend == Fp8MoeBackend.BATCHED_TRITON: from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( BatchedTritonExperts, ) - return BatchedTritonExperts + return [BatchedTritonExperts] elif backend == Fp8MoeBackend.AITER: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( AiterExperts, ) - return AiterExperts + return [AiterExperts] elif backend == Fp8MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.triton_cutlass_moe import ( TritonOrCutlassExperts, ) - return TritonOrCutlassExperts + return [TritonOrCutlassExperts] elif backend == Fp8MoeBackend.BATCHED_VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassBatchedExpertsFp8, ) - return CutlassBatchedExpertsFp8 + return [CutlassBatchedExpertsFp8] elif backend == Fp8MoeBackend.XPU: from vllm.model_executor.layers.fused_moe.xpu_fused_moe import ( XPUExpertsFp8, ) - return XPUExpertsFp8 + return [XPUExpertsFp8] else: raise ValueError(f"Unknown FP8 MoE backend: {backend.value}") @@ -215,8 +216,9 @@ def select_fp8_moe_backend( Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ + if config.is_lora_enabled: - return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) + return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)[0] # NOTE: the kernels are selected in the following order. AVAILABLE_BACKENDS = _get_priority_backends(config, weight_key, activation_key) @@ -256,13 +258,13 @@ def select_fp8_moe_backend( activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, k_cls + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls raise ValueError(_make_log_unsupported(backend, reason)) # Handle explicit moe_backend from user. @@ -312,7 +314,7 @@ def select_fp8_moe_backend( raise ValueError( f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE." ) - k_cls = backend_to_kernel_cls(backend) + k_cls = backend_to_kernel_cls(backend)[0] return _return_or_raise( backend, config, weight_key, activation_key, activation_format ) @@ -322,23 +324,23 @@ def select_fp8_moe_backend( Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_CUTLASS, ]: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) - - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, k_cls - else: - logger.debug_once( - _make_log_unsupported(backend, reason), scope="local" + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) + raise NotImplementedError( "Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no " "FlashInfer FP8 MoE backend supports the configuration." @@ -382,20 +384,19 @@ def select_fp8_moe_backend( # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) - - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, k_cls - else: - logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") # TODO(rob): per discussion with TPU team, we need a way to register # MoE backends by OOT plugins, rather than having an explicit list